Commit 12d73a82 authored by NVShreyas's avatar NVShreyas Committed by GitHub
Browse files

Updates to support DS R1 in TRTLLM example (#301)

parent e584e96f
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
......
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
......
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import argparse import argparse
import json import json
import os import os
from pathlib import Path
from typing import Any, Dict, Tuple from typing import Any, Dict, Tuple
# Define the expected keys for each config # Define the expected keys for each config
...@@ -65,6 +66,8 @@ def _get_llm_args(args_dict): ...@@ -65,6 +66,8 @@ def _get_llm_args(args_dict):
} }
if "model" not in llm_engine_args: if "model" not in llm_engine_args:
raise ValueError("Model name is required in the TRT-LLM engine config.") raise ValueError("Model name is required in the TRT-LLM engine config.")
if os.path.exists(llm_engine_args["model"]):
llm_engine_args["model"] = Path(llm_engine_args["model"])
return (pytorch_config_args, llm_engine_args) return (pytorch_config_args, llm_engine_args)
......
...@@ -13,22 +13,22 @@ ...@@ -13,22 +13,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost hostname: localhost
port: 8000 port: 8000
backend: "pytorch"
context_servers: context_servers:
num_instances: 2 num_instances: 2
gpu_fraction: 0.25 tensor_parallel_size: 2
tp_size: 1 kv_cache_config:
pp_size: 1 free_gpu_memory_fraction: 0.2
urls: urls:
- "localhost:8001" - "localhost:8001"
- "localhost:8002" - "localhost:8002"
generation_servers: generation_servers:
num_instances: 1 num_instances: 2
gpu_fraction: 0.25 tensor_parallel_size: 2
tp_size: 1 kv_cache_config:
pp_size: 1 free_gpu_memory_fraction: 0.2
urls: urls:
- "localhost:8002" - "localhost:8003"
- "localhost:8004"
\ No newline at end of file
...@@ -104,15 +104,17 @@ class TensorrtLLMEngine: ...@@ -104,15 +104,17 @@ class TensorrtLLMEngine:
None, None,
lambda: LLM( lambda: LLM(
**self.llm_engine_args, **self.llm_engine_args,
tensor_parallel_size=self.server_config.other_args["tp_size"], tensor_parallel_size=self.server_config.other_args.get(
pipeline_parallel_size=self.server_config.other_args["pp_size"], "tensor_parallel_size", 1
),
pipeline_parallel_size=self.server_config.other_args.get(
"pipeline_parallel_size", 1
),
gpus_per_node=None, gpus_per_node=None,
trust_remote_code=True, trust_remote_code=True,
_mpi_session=self.mpi_session, _mpi_session=self.mpi_session,
kv_cache_config=KvCacheConfig( kv_cache_config=KvCacheConfig(
free_gpu_memory_fraction=self.server_config.other_args[ **self.server_config.other_args.get("kv_cache_config", {})
"gpu_fraction"
]
), ),
pytorch_backend_config=pytorch_config, pytorch_backend_config=pytorch_config,
backend="pytorch", backend="pytorch",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment