Commit 12d73a82 authored by NVShreyas's avatar NVShreyas Committed by GitHub
Browse files

Updates to support DS R1 in TRTLLM example (#301)

parent e584e96f
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
......
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
......
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
......
......@@ -16,6 +16,7 @@
import argparse
import json
import os
from pathlib import Path
from typing import Any, Dict, Tuple
# Define the expected keys for each config
......@@ -65,6 +66,8 @@ def _get_llm_args(args_dict):
}
if "model" not in llm_engine_args:
raise ValueError("Model name is required in the TRT-LLM engine config.")
if os.path.exists(llm_engine_args["model"]):
llm_engine_args["model"] = Path(llm_engine_args["model"])
return (pytorch_config_args, llm_engine_args)
......
......@@ -13,22 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE(review): this span is a web-scraped diff whose +/- markers were stripped,
# so pre-change and post-change lines are interleaved (both `tp_size`/`pp_size`
# and `tensor_parallel_size` appear, and `num_instances` occurs twice under
# generation_servers). The hunk header (@@ -13,22 +13,22 @@) and the engine
# change elsewhere in this commit (tp_size/pp_size/gpu_fraction lookups replaced
# by tensor_parallel_size and kv_cache_config) suggest which lines are new —
# confirm against the repository before treating this as valid YAML.
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
# Disaggregated serving: "context" (prefill) workers and "generation" (decode)
# workers are configured separately below.
context_servers:
num_instances: 2
# presumably the pre-change keys (removed by this commit) — TODO confirm:
gpu_fraction: 0.25
tp_size: 1
pp_size: 1
# presumably the post-change keys (added by this commit) — TODO confirm:
tensor_parallel_size: 2
kv_cache_config:
free_gpu_memory_fraction: 0.2
urls:
- "localhost:8001"
- "localhost:8002"
generation_servers:
# old value on the next line, new value further down — one of the two
# `num_instances` entries belongs to each side of the diff:
num_instances: 1
gpu_fraction: 0.25
tp_size: 1
pp_size: 1
num_instances: 2
tensor_parallel_size: 2
kv_cache_config:
free_gpu_memory_fraction: 0.2
urls:
# "localhost:8002" appears in both the old generation urls and the new
# context urls above; the new generation urls are presumably 8003/8004.
- "localhost:8002"
- "localhost:8003"
- "localhost:8004"
\ No newline at end of file
......@@ -104,15 +104,17 @@ class TensorrtLLMEngine:
None,
lambda: LLM(
**self.llm_engine_args,
tensor_parallel_size=self.server_config.other_args["tp_size"],
pipeline_parallel_size=self.server_config.other_args["pp_size"],
tensor_parallel_size=self.server_config.other_args.get(
"tensor_parallel_size", 1
),
pipeline_parallel_size=self.server_config.other_args.get(
"pipeline_parallel_size", 1
),
gpus_per_node=None,
trust_remote_code=True,
_mpi_session=self.mpi_session,
kv_cache_config=KvCacheConfig(
free_gpu_memory_fraction=self.server_config.other_args[
"gpu_fraction"
]
**self.server_config.other_args.get("kv_cache_config", {})
),
pytorch_backend_config=pytorch_config,
backend="pytorch",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment