Unverified commit 7e3b3fab, authored by Tanmay Verma, committed by GitHub
Browse files

fix: Add default configs in LLMAPI. Fixes OOM issues (#2198)

parent 625578c3
...@@ -8,8 +8,16 @@ import sys ...@@ -8,8 +8,16 @@ import sys
import uvloop import uvloop
from tensorrt_llm import SamplingParams from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi import (
BuildConfig,
CapacitySchedulerPolicy,
DynamicBatchConfig,
KvCacheConfig,
SchedulerConfig,
)
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from torch.cuda import device_count
from dynamo.llm import ModelType, register_llm from dynamo.llm import ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
...@@ -84,12 +92,51 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -84,12 +92,51 @@ async def init(runtime: DistributedRuntime, config: Config):
# Convert model path to Path object if it's a local path, otherwise keep as string # Convert model path to Path object if it's a local path, otherwise keep as string
model_path = str(config.model_path) model_path = str(config.model_path)
if config.gpus_per_node is None:
gpus_per_node = device_count()
if gpus_per_node == 0:
raise ValueError("No GPU devices found on the node")
else:
gpus_per_node = config.gpus_per_node
build_config = BuildConfig(
max_batch_size=config.max_batch_size,
max_num_tokens=config.max_num_tokens,
max_beam_width=config.max_beam_width,
max_seq_len=config.max_seq_len,
)
kv_cache_config = KvCacheConfig(
free_gpu_memory_fraction=config.free_gpu_memory_fraction
)
dynamic_batch_config = DynamicBatchConfig(
enable_batch_size_tuning=True,
enable_max_num_tokens_tuning=False,
dynamic_batch_moving_average_window=128,
)
scheduler_config = SchedulerConfig(
capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
dynamic_batch_config=dynamic_batch_config,
)
arg_map = { arg_map = {
"model": model_path, "model": model_path,
"scheduler_config": scheduler_config,
"tensor_parallel_size": config.tensor_parallel_size, "tensor_parallel_size": config.tensor_parallel_size,
"pipeline_parallel_size": config.pipeline_parallel_size,
"moe_expert_parallel_size": config.expert_parallel_size,
"backend": "pytorch", "backend": "pytorch",
"skip_tokenizer_init": True, "skip_tokenizer_init": True,
"build_config": build_config,
"kv_cache_config": kv_cache_config,
"gpus_per_node": gpus_per_node,
"max_num_tokens": config.max_num_tokens,
"max_seq_len": config.max_seq_len,
"max_beam_width": config.max_beam_width,
"max_batch_size": config.max_batch_size,
} }
if config.extra_engine_args != "": if config.extra_engine_args != "":
# TODO: Support extra engine args from json file as well. # TODO: Support extra engine args from json file as well.
arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args) arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args)
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
import argparse import argparse
from typing import Optional from typing import Optional
from tensorrt_llm.llmapi import BuildConfig
from dynamo.trtllm.request_handlers.handler_base import ( from dynamo.trtllm.request_handlers.handler_base import (
DisaggregationMode, DisaggregationMode,
DisaggregationStrategy, DisaggregationStrategy,
...@@ -27,8 +29,16 @@ class Config: ...@@ -27,8 +29,16 @@ class Config:
self.model_path: str = "" self.model_path: str = ""
self.served_model_name: Optional[str] = None self.served_model_name: Optional[str] = None
self.tensor_parallel_size: int = 1 self.tensor_parallel_size: int = 1
self.pipeline_parallel_size: int = 1
self.expert_parallel_size: Optional[int] = None
self.kv_block_size: int = 32 self.kv_block_size: int = 32
self.migration_limit: int = 0 self.migration_limit: int = 0
self.gpus_per_node: Optional[int] = None
self.max_batch_size: int = BuildConfig.max_batch_size
self.max_num_tokens: int = BuildConfig.max_num_tokens
self.max_seq_len: int = BuildConfig.max_seq_len
self.max_beam_width: int = BuildConfig.max_beam_width
self.free_gpu_memory_fraction: Optional[float] = None
self.extra_engine_args: str = "" self.extra_engine_args: str = ""
self.publish_events_and_metrics: bool = False self.publish_events_and_metrics: bool = False
self.disaggregation_mode: DisaggregationMode = DEFAULT_DISAGGREGATION_MODE self.disaggregation_mode: DisaggregationMode = DEFAULT_DISAGGREGATION_MODE
...@@ -45,7 +55,15 @@ class Config: ...@@ -45,7 +55,15 @@ class Config:
f"model_path={self.model_path}, " f"model_path={self.model_path}, "
f"served_model_name={self.served_model_name}, " f"served_model_name={self.served_model_name}, "
f"tensor_parallel_size={self.tensor_parallel_size}, " f"tensor_parallel_size={self.tensor_parallel_size}, "
f"pipeline_parallel_size={self.pipeline_parallel_size}, "
f"expert_parallel_size={self.expert_parallel_size}, "
f"kv_block_size={self.kv_block_size}, " f"kv_block_size={self.kv_block_size}, "
f"gpus_per_node={self.gpus_per_node}, "
f"max_batch_size={self.max_batch_size}, "
f"max_num_tokens={self.max_num_tokens}, "
f"max_seq_len={self.max_seq_len}, "
f"max_beam_width={self.max_beam_width}, "
f"free_gpu_memory_fraction={self.free_gpu_memory_fraction}, "
f"extra_engine_args={self.extra_engine_args}, " f"extra_engine_args={self.extra_engine_args}, "
f"migration_limit={self.migration_limit}, " f"migration_limit={self.migration_limit}, "
f"publish_events_and_metrics={self.publish_events_and_metrics}, " f"publish_events_and_metrics={self.publish_events_and_metrics}, "
...@@ -108,8 +126,21 @@ def cmd_line_args(): ...@@ -108,8 +126,21 @@ def cmd_line_args():
help="Name to serve the model under. Defaults to deriving it from model path.", help="Name to serve the model under. Defaults to deriving it from model path.",
) )
parser.add_argument( parser.add_argument(
"--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use." "--tensor-parallel-size", type=int, default=1, help="Tensor parallelism size."
)
parser.add_argument(
"--pipeline-parallel-size",
type=int,
default=None,
help="Pipeline parallelism size.",
)
parser.add_argument(
"--expert-parallel-size",
type=int,
default=None,
help="expert parallelism size.",
) )
# IMPORTANT: We should ideally not expose this to users. We should be able to # IMPORTANT: We should ideally not expose this to users. We should be able to
# query the block size from the TRTLLM engine. # query the block size from the TRTLLM engine.
parser.add_argument( parser.add_argument(
...@@ -121,6 +152,43 @@ def cmd_line_args(): ...@@ -121,6 +152,43 @@ def cmd_line_args():
default=0, default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.", help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
) )
parser.add_argument(
"--gpus-per-node",
type=int,
default=None,
help="Number of GPUs per node. If not provided, will be inferred from the environment.",
)
parser.add_argument(
"--max-batch-size",
type=int,
default=BuildConfig.max_batch_size,
help="Maximum number of requests that the engine can schedule.",
)
parser.add_argument(
"--max-num-tokens",
type=int,
default=BuildConfig.max_num_tokens,
help="Maximum number of batched input tokens after padding is removed in each batch.",
)
parser.add_argument(
"--max-seq-len",
type=int,
default=BuildConfig.max_seq_len,
help="Maximum total length of one request, including prompt and outputs. "
"If unspecified, the value is deduced from the model config.",
)
parser.add_argument(
"--max-beam-width",
type=int,
default=BuildConfig.max_beam_width,
help="Maximum number of beams for beam search decoding.",
)
parser.add_argument(
"--free-gpu-memory-fraction",
type=float,
default=None,
help="Free GPU memory fraction reserved for KV Cache, after allocating model weights and buffers.",
)
parser.add_argument( parser.add_argument(
"--extra-engine-args", "--extra-engine-args",
...@@ -195,6 +263,18 @@ def cmd_line_args(): ...@@ -195,6 +263,18 @@ def cmd_line_args():
config.next_endpoint = args.next_endpoint config.next_endpoint = args.next_endpoint
config.tensor_parallel_size = args.tensor_parallel_size config.tensor_parallel_size = args.tensor_parallel_size
if args.pipeline_parallel_size is not None:
config.pipeline_parallel_size = args.pipeline_parallel_size
if args.expert_parallel_size is not None:
config.expert_parallel_size = args.expert_parallel_size
if args.gpus_per_node is not None:
config.gpus_per_node = args.gpus_per_node
if args.free_gpu_memory_fraction is not None:
config.free_gpu_memory_fraction = args.free_gpu_memory_fraction
config.max_batch_size = args.max_batch_size
config.max_num_tokens = args.max_num_tokens
config.max_seq_len = args.max_seq_len
config.max_beam_width = args.max_beam_width
config.kv_block_size = args.kv_block_size config.kv_block_size = args.kv_block_size
config.migration_limit = args.migration_limit config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args config.extra_engine_args = args.extra_engine_args
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment