# parser.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse


def parse_tensorrt_llm_args(
    config_args,
) -> argparse.Namespace:
    """Parse command-line style arguments for a TensorRT-LLM worker.

    Args:
        config_args: Sequence of argument strings to parse (e.g.
            ``["--model-path", "foo", "--enable-disagg"]``). It is passed
            straight to ``ArgumentParser.parse_args``, so ``None`` means
            "read ``sys.argv[1:]``". An empty list parses to the defaults.

    Returns:
        An ``argparse.Namespace`` with the attributes ``extra_engine_args``,
        ``model_path``, ``served_model_name``, ``router``, ``kv_block_size``
        and ``enable_disagg``.

    Raises:
        SystemExit: raised by argparse on unknown options, a missing option
            value, an invalid ``--router`` choice or a non-integer
            ``--kv-block-size``.
    """
    parser = argparse.ArgumentParser(description="A TensorRT-LLM Worker parser")
    parser.add_argument(
        "--extra-engine-args",
        type=str,
        default="",
        help="Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default=None,
        help="Path to disk model or HuggingFace model identifier to load.",
    )
    parser.add_argument(
        # The hyphenated spelling matches every other flag in this parser.
        # The underscore spelling is kept for backward compatibility. An
        # explicit dest keeps the attribute name `served_model_name`.
        "--served-model-name",
        "--served_model_name",
        dest="served_model_name",
        type=str,
        help="Name to serve the model under.",
    )
    parser.add_argument(
        "--router",
        type=str,
        choices=["random", "round-robin", "kv"],
        default="random",
        help="Router type to use for scheduling requests to workers",
    )
    parser.add_argument(
        "--kv-block-size",
        type=int,
        default=32,
        help="Number of tokens per KV block in TRTLLM worker. Default is 32 for pytorch backend.",
    )
    parser.add_argument(
        "--enable-disagg",
        action="store_true",
        help="Enable remote prefill for the worker",
    )

    args = parser.parse_args(config_args)
    return args