Unverified Commit bb0e8a32 authored by Lianmin Zheng, committed by GitHub

Clean up server args (#8161)

parent 1b427dae
/3rdparty/amd @HaiShaw
/docker @zhyncs @HaiShaw @ByronHsu
/docs @zhaochenyang20
/python/sglang/lang @merrymercy @Ying1123 @hnyls2002
/python/sglang/srt @merrymercy @Ying1123 @hnyls2002 @zhyncs @ispobock @ByronHsu
/python/sglang/srt/constrained @hnyls2002
/python/sglang/srt/disaggregation @ByronHsu @hnyls2002
/python/sglang/srt/distributed @yizhang2077
/python/sglang/srt/entrypoints @zhaochenyang20 @CatherineSue
/python/sglang/srt/eplb @fzyzcjy
/python/sglang/srt/function_call @CatherineSue
/python/sglang/srt/layers @merrymercy @Ying1123 @zhyncs @ispobock @HaiShaw @ch-wan @BBuf
/python/sglang/srt/lora @Ying1123 @Fridge003
/python/sglang/srt/managers @merrymercy @Ying1123 @hnyls2002 @xiezhq-hermann
/python/sglang/srt/mem_cache @merrymercy @Ying1123 @hnyls2002 @xiezhq-hermann
/python/sglang/srt/model_executor @merrymercy @Ying1123 @hnyls2002 @zhyncs @ispobock
/python/sglang/srt/models @zhyncs @ispobock @ByronHsu @zhaochenyang20
/python/sglang/srt/multimodal @mickqian @JustinTong0323
/python/sglang/srt/sampling @hnyls2002
/python/sglang/srt/speculative @Ying1123 @merrymercy @rkooo567 @kssteven418
/test/lang @merrymercy @Ying1123
/test/srt @merrymercy @Ying1123 @zhyncs
/sgl-router @ByronHsu @slin1237
/sgl-kernel @zhyncs @ispobock @HandH1998 @BBuf @yizhang2077 @merrymercy @yinfan98 @HaiShaw
@@ -51,7 +51,7 @@ You can find all arguments by `python3 -m sglang.launch_server --help`
Please consult the documentation below and [server_args.py](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py) to learn more about the arguments you may provide when launching a server.
## Model and tokenizer
| Arguments | Description | Defaults |
|-----------|-------------|----------|
@@ -61,20 +61,30 @@
| `--skip-tokenizer-init` | If set, skip init tokenizer and pass input_ids in generate request. | False |
| `--load-format` | The format of the model weights to load. 'auto' will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available. 'pt' will load the weights in the pytorch bin format. 'safetensors' will load the weights in the safetensors format. 'npcache' will load the weights in pytorch format and store a numpy cache to speed up the loading. 'dummy' will initialize the weights with random values, which is mainly for profiling. 'gguf' will load the weights in the gguf format. 'bitsandbytes' will load the weights using bitsandbytes quantization. 'layered' loads weights layer by layer so that one can quantize a layer before loading another to make the peak memory envelope smaller. | auto |
| `--trust-remote-code` | Whether or not to allow for custom models defined on the Hub in their own modeling files. | False |
| `--context-length` | The model's maximum context length. Defaults to None (will use the value from the model's config.json instead). | None |
| `--is-embedding` | Whether to use a CausalLM as an embedding model. | False |
| `--enable-multimodal` | Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen. | None |
| `--revision` | The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. | None |
| `--model-impl` | Which implementation of the model to use. 'auto' will try to use the SGLang implementation if it exists and fall back to the Transformers implementation if no SGLang implementation is available. 'sglang' will use the SGLang model implementation. 'transformers' will use the Transformers model implementation. | auto |
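To make the mapping from these flags to code concrete, here is a minimal sketch (assuming `sglang` is installed) that registers the CLI arguments through the `ServerArgs.add_cli_args` helper shown later in this diff and inspects a few parsed values; the model id `my-org/my-model` is a placeholder, not a recommendation.

```python
# Minimal sketch, not part of this PR: parse a few of the model/tokenizer flags above.
# Assumes sglang is installed; "my-org/my-model" is a hypothetical model id.
import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)  # registers --model-path, --model-impl, and the rest
args = parser.parse_args(
    ["--model-path", "my-org/my-model", "--trust-remote-code", "--model-impl", "auto"]
)
print(args.model_path, args.model_impl, args.trust_remote_code)
```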
## HTTP server
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--host` | The host address for the server. | 127.0.0.1 |
| `--port` | The port number for the server. | 30000 |
| `--skip-server-warmup` | If set, skip the server warmup process. | False |
| `--warmups` | Specify custom warmup functions (csv), e.g., `--warmups=warmup_name1,warmup_name2` will run the functions `warmup_name1` and `warmup_name2` defined in warmup.py before the server starts listening for requests. | None |
| `--nccl-port` | The port for the NCCL distributed environment setup. Defaults to a random port. | None |
## Quantization and data type
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--dtype` | Data type for model weights and activations. 'auto' will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. 'half' for FP16. Recommended for AWQ quantization. 'float16' is the same as 'half'. 'bfloat16' for a balance between precision and range. 'float' is shorthand for FP32 precision. 'float32' for FP32 precision. | auto |
| `--quantization` | The quantization method. | None |
| `--quantization-param-path` | Path to the JSON file containing the KV cache scaling factors. This should generally be supplied when the KV cache dtype is FP8. Otherwise, the KV cache scaling factors default to 1.0, which may cause accuracy issues. | None |
| `--kv-cache-dtype` | Data type for kv cache storage. 'auto' will use the model data type. 'fp8_e5m2' and 'fp8_e4m3' are supported for CUDA 11.8+. | auto |
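For instance, a deployment that keeps weights in BF16 while storing the KV cache in FP8 might combine these flags as `--dtype bfloat16 --kv-cache-dtype fp8_e5m2`; as noted above, the FP8 KV cache formats require CUDA 11.8+.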
## Memory and scheduling
@@ -90,13 +100,13 @@
| `--cpu-offload-gb` | How many GBs of RAM to reserve for CPU offloading. | 0 |
| `--page-size` | The number of tokens in a page. | 1 |
## Runtime options
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--device` | The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified. | None |
| `--tp-size` or `--tensor-parallel-size` | The tensor parallelism size. | 1 |
| `--pp-size` or `--pipeline-parallel-size` | The pipeline parallelism size. | 1 |
| `--max-micro-batch-size` | The maximum micro batch size in pipeline parallelism. | None |
| `--stream-interval` | The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher. | 1 |
| `--stream-output` | Whether to output as a sequence of disjoint segments. | False |
@@ -132,6 +142,9 @@
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--api-key` | Set API key of the server. It is also used in the OpenAI API compatible server. | None |
| `--served-model-name` | Override the model name returned by the v1/models endpoint in OpenAI API server. | None |
| `--chat-template` | The builtin chat template name or the path of the chat template file. This is only used for the OpenAI-compatible API server. | None |
| `--completion-template` | The builtin completion template name or the path of the completion template file. This is only used for the OpenAI-compatible API server, and currently only for code completion. | None |
| `--file-storage-path` | The path of the file storage in backend. | sglang_storage |
| `--enable-cache-report` | Return the number of cached tokens in usage.prompt_tokens_details for each OpenAI request. | False |
| `--reasoning-parser` | Specify the parser for reasoning models; see `ReasoningParser.DetectorMap` in the code for the supported parsers. | None |
@@ -141,10 +154,9 @@
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--dp-size` or `--data-parallel-size` | The data parallelism size. | 1 |
| `--load-balance-method` | The load balancing strategy for data parallelism. | round_robin |
## Multi-node distributed serving
| Arguments | Description | Defaults |
@@ -153,7 +165,7 @@
| `--nnodes` | The number of nodes. | 1 |
| `--node-rank` | The node rank. | 0 |
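For example, a two-node deployment typically launches the same server command on both machines with `--nnodes 2`, passing `--node-rank 0` on the first node and `--node-rank 1` on the second.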
## Model override args in JSON
| Arguments | Description | Defaults |
|-----------|-------------|----------|
@@ -164,11 +176,11 @@
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None |
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. | None |
| `--lora-paths` | The list of LoRA adapters. You can provide a list of either plain paths or renamed paths in the format {name}={path}. | None |
| `--max-loras-per-batch` | Maximum number of adapters for a running batch, including base-only requests. | 8 |
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
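To illustrate the `{name}={path}` form accepted by `--lora-paths`, the sketch below shows one plausible way such entries can be split into a name-to-path mapping; the adapter names and paths are placeholders, and the naming rule for plain paths is an assumption rather than documented behavior.

```python
# Hedged sketch: interpreting --lora-paths entries of the form "{name}={path}" or plain paths.
# The adapter names and paths are placeholders; the plain-path naming rule is an assumption.
lora_paths = ["sql-lora=/models/adapters/sql", "/models/adapters/chat"]

adapters = {}
for item in lora_paths:
    if "=" in item:
        name, path = item.split("=", 1)  # renamed form: {name}={path}
    else:
        name = path = item               # plain path; assume the path doubles as the name
    adapters[name] = path

print(adapters)
```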
## Kernel backend
@@ -196,9 +208,10 @@
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--ep-size` or `--expert-parallel-size` | The expert parallelism size. | 1 |
| `--enable-ep-moe` | Enable expert parallelism for MoE. The EP size is equal to the TP size. | False |
| `--enable-deepep-moe` | Enable the DeepEP MoE implementation for EP MoE. | False |
| `--enable-flashinfer-moe` | Enable the FlashInfer MoE implementation. | False |
| `--deepep-mode` | Select the mode when DeepEP MoE is enabled: `normal`, `low_latency`, or `auto`. The default `auto` uses `low_latency` for decode batches and `normal` for prefill batches. | auto |
| `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 |
| `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in expert parallel. | None |
@@ -213,6 +226,17 @@
| `--deepep-config` | Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path. | None |
| `--moe-dense-tp-size` | TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports. | None |
## Hierarchical cache
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--enable-hierarchical-cache` | Enable hierarchical cache. | False |
| `--hicache-ratio` | The ratio of the size of host KV cache memory pool to the size of device pool. | 2.0 |
| `--hicache-size` | The size of the host KV cache memory pool in gigabytes; overrides `--hicache-ratio` if set. | 0 |
| `--hicache-write-policy` | The write policy for the hierarchical cache. | write_through_selective |
| `--hicache-io-backend` | The IO backend for KV cache transfer between CPU and GPU (`direct` or `kernel`). | |
| `--hicache-storage-backend` | The storage backend for the hierarchical KV cache (currently `file`). | None |
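As a hedged illustration, enabling the cache with a larger host pool might look like `--enable-hierarchical-cache --hicache-ratio 3.0 --hicache-write-policy write_through`, with `--hicache-size` (in gigabytes) overriding the ratio when set.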
## Optimization/debug options
| Arguments | Description | Defaults |
@@ -229,7 +253,6 @@
| `--disable-custom-all-reduce` | Disable the custom all-reduce kernel and fall back to NCCL. | False |
| `--enable-mscclpp` | Enable mscclpp for small-message all-reduce kernels, falling back to NCCL otherwise. | False |
| `--disable-overlap-schedule` | Disable the overlap scheduler, which overlaps the CPU scheduler with the GPU model worker. | False |
| `--disable-overlap-cg-plan` | Disable the overlap optimization for cudagraph preparation in eagle verify. | False |
| `--enable-mixed-chunk` | Enable mixing prefill and decode in a batch when using chunked prefill. | False |
| `--enable-dp-attention` | Enable data parallelism for attention and tensor parallelism for FFN. The DP size should be equal to the TP size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported. | False |
| `--enable-dp-lm-head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | False |
@@ -246,24 +269,43 @@
| `--enable-memory-saver` | Allow saving memory using release_memory_occupation and resume_memory_occupation. | False |
| `--allow-auto-truncate` | Allow automatically truncating requests that exceed the maximum input length instead of returning an error. | False |
| `--enable-custom-logit-processor` | Enable users to pass custom logit processors to the server (disabled by default for security). | False |
| `--flashinfer-mla-disable-ragged` | Do not use the ragged prefill wrapper when running FlashInfer MLA. | False |
| `--disable-shared-experts-fusion` | Disable the shared experts fusion optimization for DeepSeek V3/R1. | False |
| `--disable-chunked-prefix-cache` | Disable the chunked prefix cache feature for DeepSeek, which saves overhead for short sequences. | False |
| `--disable-fast-image-processor` | Use the base image processor instead of the fast image processor. | False |
| `--enable-return-hidden-states` | Enable returning hidden states with responses. | False |
| `--enable-triton-kernel-moe` | Enable the Triton kernel for MoE. | False |
## Debug tensor dumps
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--debug-tensor-dump-output-folder` | The output folder for debug tensor dumps. | None |
| `--debug-tensor-dump-input-file` | The input file for debug tensor dumps. | None |
| `--debug-tensor-dump-inject` | Enable injection of debug tensor dumps. | False |
| `--debug-tensor-dump-prefill-only` | Enable prefill-only mode for debug tensor dumps. | False |
## PD disaggregation
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--disaggregation-mode` | PD disaggregation mode: "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only). | null |
| `--disaggregation-transfer-backend` | The transfer backend for PD disaggregation. | mooncake |
| `--disaggregation-bootstrap-port` | The bootstrap port for PD disaggregation (the bootstrap server runs on the prefill server). | 8998 |
| `--disaggregation-decode-tp` | The decode TP for PD disaggregation. | None |
| `--disaggregation-decode-dp` | The decode DP for PD disaggregation. | None |
| `--disaggregation-prefill-pp` | The prefill PP for PD disaggregation. | 1 |
| `--disaggregation-ib-device` | The InfiniBand devices for disaggregation transfer. Accepts a single device (e.g., `mlx5_0`) or multiple comma-separated devices (e.g., `mlx5_0,mlx5_1`). Defaults to None, which triggers automatic device detection when the mooncake backend is enabled. | None |
| `--num-reserved-decode-tokens` | Number of decode tokens that will have memory reserved when adding a new request to the running batch. | 512 |
| `--pdlb-url` | The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer. | None |
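As a hedged example, a disaggregated pair could run one server with `--disaggregation-mode prefill --disaggregation-bootstrap-port 8998` and a second with `--disaggregation-mode decode` pointing at the same bootstrap port, with the default `--disaggregation-transfer-backend mooncake` handling KV transfer between them.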
## Model weight update
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--custom-weight-loader` | Custom weight loader paths. | None |
| `--weight-loader-disable-mmap` | Disable mmap for weight loader. | False |
## PD-Multiplexing
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--enable-pdmux` | Enable PD-Multiplexing. | False |
| `--sm-group-num` | Number of SM groups for PD-Multiplexing. | 3 |
@@ -53,7 +53,7 @@ class ModelConfig:
trust_remote_code: bool = True,
revision: Optional[str] = None,
context_length: Optional[int] = None,
model_override_args: str = "{}",
is_embedding: Optional[bool] = None,
enable_multimodal: Optional[bool] = None,
dtype: str = "auto",
@@ -61,13 +61,13 @@ class ModelConfig:
override_config_file: Optional[str] = None,
is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[float] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
) -> None:
self.model_path = model_path
self.revision = revision
self.quantization = quantization
self.model_impl = model_impl
# Parse args
self.maybe_pull_model_tokenizer_from_remote()
@@ -286,7 +286,7 @@ class ModelConfig:
dtype=server_args.dtype,
quantization=server_args.quantization,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
model_impl=server_args.model_impl,
**kwargs,
)
......
@@ -1389,8 +1389,6 @@ class Scheduler(
f += f"#running-req: {running_bs}, "
f += f"#queue-req: {len(self.waiting_queue)}, "
f += f"timestamp: {datetime.datetime.now().isoformat()}"
logger.info(f)
if self.enable_metrics:
@@ -1471,7 +1469,6 @@ class Scheduler(
f"cuda graph: {can_run_cuda_graph}, "
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}, "
f"timestamp: {datetime.datetime.now().isoformat()}"
)
logger.info(msg)
......
@@ -56,14 +56,14 @@ def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str
"if the model is custom)."
)
model_module = auto_modules["AutoModel"]
if model_config.model_impl == ModelImpl.TRANSFORMERS:
if not model_module.is_backend_compatible():
raise ValueError(
f"The Transformers implementation of {arch} is not "
"compatible with SGLang."
)
architectures[i] = "TransformersForCausalLM"
if model_config.model_impl == ModelImpl.AUTO:
if not model_module.is_backend_compatible():
raise ValueError(
f"{arch} has no SGLang implementation and the Transformers "
@@ -97,7 +97,7 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
supported_archs = ModelRegistry.get_supported_archs()
is_native_supported = any(arch in supported_archs for arch in architectures)
if not is_native_supported or model_config.model_impl == ModelImpl.TRANSFORMERS:
architectures = resolve_transformers_arch(model_config, architectures)
return ModelRegistry.resolve_model_cls(architectures)
......
@@ -20,6 +20,7 @@ import logging
import os
import random
import tempfile
from typing import List, Literal, Optional, Union
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
@@ -46,31 +47,28 @@ class ServerArgs:
tokenizer_path: Optional[str] = None
tokenizer_mode: str = "auto"
skip_tokenizer_init: bool = False
load_format: str = "auto"
model_loader_extra_config: str = "{}"
trust_remote_code: bool = False
context_length: Optional[int] = None
is_embedding: bool = False
enable_multimodal: Optional[bool] = None
revision: Optional[str] = None
model_impl: str = "auto"
# HTTP server
host: str = "127.0.0.1"
port: int = 30000
skip_server_warmup: bool = False
warmups: Optional[str] = None
nccl_port: Optional[int] = None
# Quantization and data type
dtype: str = "auto"
quantization: Optional[str] = None
quantization_param_path: Optional[str] = None
kv_cache_dtype: str = "auto"
# Memory and scheduling
mem_fraction_static: Optional[float] = None
max_running_requests: Optional[int] = None
@@ -81,8 +79,12 @@
schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0
page_size: int = 1
hybrid_kvcache_ratio: Optional[float] = None
swa_full_tokens_ratio: float = 0.8
disable_hybrid_swa_memory: bool = False
# Runtime options
device: Optional[str] = None
tp_size: int = 1
pp_size: int = 1
max_micro_batch_size: Optional[int] = None
...@@ -107,8 +109,8 @@ class ServerArgs: ...@@ -107,8 +109,8 @@ class ServerArgs:
enable_metrics: bool = False enable_metrics: bool = False
enable_metrics_for_all_schedulers: bool = False enable_metrics_for_all_schedulers: bool = False
bucket_time_to_first_token: Optional[List[float]] = None
bucket_inter_token_latency: Optional[List[float]] = None
bucket_e2e_request_latency: Optional[List[float]] = None
collect_tokens_histogram: bool = False collect_tokens_histogram: bool = False
decode_log_interval: int = 40 decode_log_interval: int = 40
enable_request_time_stats_logging: bool = False enable_request_time_stats_logging: bool = False
...@@ -116,6 +118,9 @@ class ServerArgs: ...@@ -116,6 +118,9 @@ class ServerArgs:
# API related # API related
api_key: Optional[str] = None api_key: Optional[str] = None
served_model_name: Optional[str] = None
chat_template: Optional[str] = None
completion_template: Optional[str] = None
file_storage_path: str = "sglang_storage" file_storage_path: str = "sglang_storage"
enable_cache_report: bool = False enable_cache_report: bool = False
reasoning_parser: Optional[str] = None reasoning_parser: Optional[str] = None
...@@ -179,6 +184,14 @@ class ServerArgs: ...@@ -179,6 +184,14 @@ class ServerArgs:
deepep_config: Optional[str] = None deepep_config: Optional[str] = None
moe_dense_tp_size: Optional[int] = None moe_dense_tp_size: Optional[int] = None
# Hierarchical cache
enable_hierarchical_cache: bool = False
hicache_ratio: float = 2.0
hicache_size: int = 0
hicache_write_policy: str = "write_through_selective"
hicache_io_backend: str = ""
hicache_storage_backend: Optional[str] = None
# Double Sparsity # Double Sparsity
enable_double_sparsity: bool = False enable_double_sparsity: bool = False
ds_channel_config_path: Optional[str] = None ds_channel_config_path: Optional[str] = None
...@@ -200,7 +213,6 @@ class ServerArgs: ...@@ -200,7 +213,6 @@ class ServerArgs:
disable_custom_all_reduce: bool = False disable_custom_all_reduce: bool = False
enable_mscclpp: bool = False enable_mscclpp: bool = False
disable_overlap_schedule: bool = False disable_overlap_schedule: bool = False
disable_overlap_cg_plan: bool = False
enable_mixed_chunk: bool = False enable_mixed_chunk: bool = False
enable_dp_attention: bool = False enable_dp_attention: bool = False
enable_dp_lm_head: bool = False enable_dp_lm_head: bool = False
...@@ -217,20 +229,12 @@ class ServerArgs: ...@@ -217,20 +229,12 @@ class ServerArgs:
enable_memory_saver: bool = False enable_memory_saver: bool = False
allow_auto_truncate: bool = False allow_auto_truncate: bool = False
enable_custom_logit_processor: bool = False enable_custom_logit_processor: bool = False
flashinfer_mla_disable_ragged: bool = False flashinfer_mla_disable_ragged: bool = False
disable_shared_experts_fusion: bool = False disable_shared_experts_fusion: bool = False
disable_chunked_prefix_cache: bool = False disable_chunked_prefix_cache: bool = False
disable_fast_image_processor: bool = False disable_fast_image_processor: bool = False
enable_return_hidden_states: bool = False enable_return_hidden_states: bool = False
enable_triton_kernel_moe: bool = False enable_triton_kernel_moe: bool = False
# Debug tensor dumps # Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None debug_tensor_dump_output_folder: Optional[str] = None
...@@ -238,7 +242,7 @@ class ServerArgs: ...@@ -238,7 +242,7 @@ class ServerArgs:
debug_tensor_dump_inject: bool = False debug_tensor_dump_inject: bool = False
debug_tensor_dump_prefill_only: bool = False debug_tensor_dump_prefill_only: bool = False
# PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
disaggregation_mode: str = "null" disaggregation_mode: str = "null"
disaggregation_transfer_backend: str = "mooncake" disaggregation_transfer_backend: str = "mooncake"
disaggregation_bootstrap_port: int = 8998 disaggregation_bootstrap_port: int = 8998
...@@ -273,6 +277,7 @@ class ServerArgs: ...@@ -273,6 +277,7 @@ class ServerArgs:
logger.warning( logger.warning(
f"Flashinfer MoE is enabled. Shared expert fusion is disabled." f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
) )
# Set missing default values # Set missing default values
if self.tokenizer_path is None: if self.tokenizer_path is None:
self.tokenizer_path = self.model_path self.tokenizer_path = self.model_path
...@@ -333,56 +338,12 @@ class ServerArgs: ...@@ -333,56 +338,12 @@ class ServerArgs:
self.mem_fraction_static = 0.88
# Lazy init to avoid circular import
from sglang.srt.configs.model_config import ModelConfig
# Multimodal models need more memory for the image processor
model_config = ModelConfig.from_server_args(self)
if model_config.is_multimodal:
self.adjust_mem_fraction_for_vlm(model_config)
if model_config.is_multimodal and vision_config:
# roughly reduce the mem_fraction_static base on params of Vit
original_server_arg_mem_fraction = self.mem_fraction_static
# a base mem_fraction_static factor for regular Vit
base_mem_fraction_reduction_ratio = 0.95
vit_num_layers = getattr(vision_config, "num_hidden_layers", 24)
vit_hidden_size = getattr(vision_config, "hidden_size", 1024)
# baseline ViT params (ViT-L/14)
baseline_vit_layers = 24
baseline_vit_hidden_size = 1024
# weight params count
current_complexity_score = vit_num_layers * (vit_hidden_size**2)
baseline_complexity_score = baseline_vit_layers * (
baseline_vit_hidden_size**2
)
complexity_ratio = (
current_complexity_score / baseline_complexity_score
if baseline_complexity_score > 0
else 1.0
)
# every time the complexity grows 100%, adjust final factor for 10%
sensitivity_scale = 0.1
dynamic_adjustment_factor = 1.0 - sensitivity_scale * (
complexity_ratio - 1.0
)
dynamic_adjustment_factor = max(
0.8, min(1.05, dynamic_adjustment_factor)
)
final_overall_factor = (
base_mem_fraction_reduction_ratio * dynamic_adjustment_factor
)
self.mem_fraction_static = (
original_server_arg_mem_fraction * final_overall_factor
)
logger.warning(
f"Multimodal model: Dynamically adjusted --mem-fraction-static "
f"from: {original_server_arg_mem_fraction:.3f} to: {self.mem_fraction_static:.3f}."
)
# Set chunked prefill size, which depends on the gpu memory capacity # Set chunked prefill size, which depends on the gpu memory capacity
if self.chunked_prefill_size is None: if self.chunked_prefill_size is None:
...@@ -406,23 +367,6 @@ class ServerArgs: ...@@ -406,23 +367,6 @@ class ServerArgs:
else: else:
self.cuda_graph_max_bs = 80 self.cuda_graph_max_bs = 80
assert self.moe_dense_tp_size in {
1,
None,
}, "moe_dense_tp_size only support 1 and None currently"
if self.attention_backend == "flashmla":
logger.warning(
"FlashMLA only supports a page_size of 64, change page_size to 64."
)
self.page_size = 64
if self.attention_backend == "cutlass_mla":
logger.warning(
"Cutlass MLA only supports a page_size of 128, change page_size to 128."
)
self.page_size = 128
# Set kernel backends for hpu device # Set kernel backends for hpu device
if self.device == "hpu": if self.device == "hpu":
self.attention_backend = "torch_native" self.attention_backend = "torch_native"
...@@ -451,6 +395,18 @@ class ServerArgs: ...@@ -451,6 +395,18 @@ class ServerArgs:
) )
self.page_size = 128 self.page_size = 128
if self.attention_backend == "flashmla":
logger.warning(
"FlashMLA only supports a page_size of 64, change page_size to 64."
)
self.page_size = 64
if self.attention_backend == "cutlass_mla":
logger.warning(
"Cutlass MLA only supports a page_size of 128, change page_size to 128."
)
self.page_size = 128
# Choose grammar backend # Choose grammar backend
if self.grammar_backend is None: if self.grammar_backend is None:
self.grammar_backend = "xgrammar" self.grammar_backend = "xgrammar"
...@@ -482,12 +438,6 @@ class ServerArgs: ...@@ -482,12 +438,6 @@ class ServerArgs:
f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
) )
if self.pp_size > 1:
self.disable_overlap_schedule = True
logger.warning(
"Pipeline parallelism is incompatible with overlap schedule."
)
if self.enable_eplb and (self.expert_distribution_recorder_mode is None): if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
self.expert_distribution_recorder_mode = "stat" self.expert_distribution_recorder_mode = "stat"
logger.info( logger.info(
...@@ -513,6 +463,13 @@ class ServerArgs: ...@@ -513,6 +463,13 @@ class ServerArgs:
elif self.expert_distribution_recorder_mode is not None: elif self.expert_distribution_recorder_mode is not None:
self.expert_distribution_recorder_buffer_size = 1000 self.expert_distribution_recorder_buffer_size = 1000
# Pipeline parallelism
if self.pp_size > 1:
self.disable_overlap_schedule = True
logger.warning(
"Pipeline parallelism is incompatible with overlap schedule."
)
# Speculative Decoding # Speculative Decoding
if self.speculative_algorithm == "NEXTN": if self.speculative_algorithm == "NEXTN":
# NEXTN shares the same implementation of EAGLE # NEXTN shares the same implementation of EAGLE
...@@ -533,8 +490,7 @@ class ServerArgs: ...@@ -533,8 +490,7 @@ class ServerArgs:
"eagle speculative decoding." "eagle speculative decoding."
) )
model_arch = self.get_hf_config().architectures[0]
if model_arch == "DeepseekV3ForCausalLM": if model_arch == "DeepseekV3ForCausalLM":
# Auto set draft_model_path DeepSeek-V3/R1 # Auto set draft_model_path DeepSeek-V3/R1
if self.speculative_draft_model_path is None: if self.speculative_draft_model_path is None:
...@@ -624,17 +580,9 @@ class ServerArgs: ...@@ -624,17 +580,9 @@ class ServerArgs:
if self.custom_weight_loader is None: if self.custom_weight_loader is None:
self.custom_weight_loader = [] self.custom_weight_loader = []
def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
larger_tp = max(decode_tp, prefill_tp)
smaller_tp = min(decode_tp, prefill_tp)
assert larger_tp % smaller_tp == 0, (
"Different tp size is supported only when one tp is multiple of the other. "
f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
)
@staticmethod @staticmethod
def add_cli_args(parser: argparse.ArgumentParser): def add_cli_args(parser: argparse.ArgumentParser):
# Model and tokenizer
parser.add_argument( parser.add_argument(
"--model-path", "--model-path",
"--model", "--model",
...@@ -648,24 +596,6 @@ class ServerArgs: ...@@ -648,24 +596,6 @@ class ServerArgs:
default=ServerArgs.tokenizer_path, default=ServerArgs.tokenizer_path,
help="The path of the tokenizer.", help="The path of the tokenizer.",
) )
parser.add_argument( parser.add_argument(
"--tokenizer-mode", "--tokenizer-mode",
type=str, type=str,
...@@ -680,11 +610,6 @@ class ServerArgs: ...@@ -680,11 +610,6 @@ class ServerArgs:
action="store_true", action="store_true",
help="If set, skip init tokenizer and pass input_ids in generate request.", help="If set, skip init tokenizer and pass input_ids in generate request.",
) )
parser.add_argument( parser.add_argument(
"--load-format", "--load-format",
type=str, type=str,
...@@ -730,6 +655,77 @@ class ServerArgs: ...@@ -730,6 +655,77 @@ class ServerArgs:
action="store_true", action="store_true",
help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
) )
parser.add_argument(
"--context-length",
type=int,
default=ServerArgs.context_length,
help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
)
parser.add_argument(
"--is-embedding",
action="store_true",
help="Whether to use a CausalLM as an embedding model.",
)
parser.add_argument(
"--enable-multimodal",
default=ServerArgs.enable_multimodal,
action="store_true",
help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
)
parser.add_argument(
"--revision",
type=str,
default=None,
help="The specific model version to use. It can be a branch "
"name, a tag name, or a commit id. If unspecified, will use "
"the default version.",
)
parser.add_argument(
"--model-impl",
type=str,
default=ServerArgs.model_impl,
help="Which implementation of the model to use.\n\n"
'* "auto" will try to use the SGLang implementation if it exists '
"and fall back to the Transformers implementation if no SGLang "
"implementation is available.\n"
'* "sglang" will use the SGLang model implementation.\n'
'* "transformers" will use the Transformers model '
"implementation.\n",
)
# HTTP server
parser.add_argument(
"--host",
type=str,
default=ServerArgs.host,
help="The host of the HTTP server.",
)
parser.add_argument(
"--port",
type=int,
default=ServerArgs.port,
help="The port of the HTTP server.",
)
parser.add_argument(
"--skip-server-warmup",
action="store_true",
help="If set, skip warmup.",
)
parser.add_argument(
"--warmups",
type=str,
required=False,
help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
"will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
)
parser.add_argument(
"--nccl-port",
type=int,
default=ServerArgs.nccl_port,
help="The port for NCCL distributed environment setup. Defaults to a random port.",
)
# Quantization and data type
parser.add_argument(
"--dtype",
type=str,
@@ -744,13 +740,6 @@
'* "float" is shorthand for FP32 precision.\n'
'* "float32" for FP32 precision.',
)
parser.add_argument(
"--quantization",
type=str,
@@ -785,65 +774,11 @@
"default to 1.0, which may cause accuracy issues. ",
)
parser.add_argument(
"--kv-cache-dtype",
type=str,
default=ServerArgs.kv_cache_dtype,
choices=["auto", "fp8_e5m2", "fp8_e4m3"],
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" are supported for CUDA 11.8+.',
)
# Memory and scheduling # Memory and scheduling
...@@ -928,7 +863,13 @@ class ServerArgs: ...@@ -928,7 +863,13 @@ class ServerArgs:
help="Disable the hybrid SWA memory.", help="Disable the hybrid SWA memory.",
) )
# Runtime options
parser.add_argument(
"--device",
type=str,
default=ServerArgs.device,
help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.",
)
parser.add_argument( parser.add_argument(
"--tensor-parallel-size", "--tensor-parallel-size",
"--tp-size", "--tp-size",
...@@ -970,7 +911,7 @@ class ServerArgs: ...@@ -970,7 +911,7 @@ class ServerArgs:
"--constrained-json-whitespace-pattern", "--constrained-json-whitespace-pattern",
type=str, type=str,
default=ServerArgs.constrained_json_whitespace_pattern, default=ServerArgs.constrained_json_whitespace_pattern,
help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*", help="(outlines backend only) Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
) )
parser.add_argument( parser.add_argument(
"--watchdog-timeout", "--watchdog-timeout",
...@@ -1083,12 +1024,6 @@ class ServerArgs: ...@@ -1083,12 +1024,6 @@ class ServerArgs:
default=ServerArgs.collect_tokens_histogram, default=ServerArgs.collect_tokens_histogram,
help="Collect prompt/generation tokens histogram.", help="Collect prompt/generation tokens histogram.",
) )
parser.add_argument( parser.add_argument(
"--decode-log-interval", "--decode-log-interval",
type=int, type=int,
...@@ -1101,6 +1036,12 @@ class ServerArgs: ...@@ -1101,6 +1036,12 @@ class ServerArgs:
default=ServerArgs.enable_request_time_stats_logging, default=ServerArgs.enable_request_time_stats_logging,
help="Enable per request time stats logging", help="Enable per request time stats logging",
) )
parser.add_argument(
"--kv-events-config",
type=str,
default=None,
help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
)
# API related # API related
parser.add_argument( parser.add_argument(
...@@ -1109,6 +1050,24 @@ class ServerArgs: ...@@ -1109,6 +1050,24 @@ class ServerArgs:
default=ServerArgs.api_key, default=ServerArgs.api_key,
help="Set API key of the server. It is also used in the OpenAI API compatible server.", help="Set API key of the server. It is also used in the OpenAI API compatible server.",
) )
parser.add_argument(
"--served-model-name",
type=str,
default=ServerArgs.served_model_name,
help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
)
parser.add_argument(
"--chat-template",
type=str,
default=ServerArgs.chat_template,
help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
)
parser.add_argument(
"--completion-template",
type=str,
default=ServerArgs.completion_template,
help="The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently.",
)
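The three OpenAI-facing options above (`--served-model-name`, `--chat-template`, `--completion-template`) only shape the HTTP API layer. A minimal sketch of how they parse, using the `prepare_server_args` helper shown later in this diff; the model path and template path are placeholders:

```python
# Sketch: API-facing flags parsed from the CLI into ServerArgs.
# "my-org/my-model" and "./chat_template.jinja" are placeholders, not real assets.
from sglang.srt.server_args import prepare_server_args

server_args = prepare_server_args(
    [
        "--model-path", "my-org/my-model",
        "--served-model-name", "my-public-model-name",
        "--chat-template", "./chat_template.jinja",
    ]
)
print(server_args.served_model_name)  # name reported by the v1/models endpoint
print(server_args.chat_template)      # builtin template name or path to a template file
```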
parser.add_argument( parser.add_argument(
"--file-storage-path", "--file-storage-path",
type=str, type=str,
...@@ -1427,6 +1386,46 @@ class ServerArgs: ...@@ -1427,6 +1386,46 @@ class ServerArgs:
help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.", help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
) )
# Hierarchical cache
parser.add_argument(
"--enable-hierarchical-cache",
action="store_true",
help="Enable hierarchical cache",
)
parser.add_argument(
"--hicache-ratio",
type=float,
default=ServerArgs.hicache_ratio,
help="The ratio of the size of host KV cache memory pool to the size of device pool.",
)
parser.add_argument(
"--hicache-size",
type=int,
default=ServerArgs.hicache_size,
help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
)
parser.add_argument(
"--hicache-write-policy",
type=str,
choices=["write_back", "write_through", "write_through_selective"],
default=ServerArgs.hicache_write_policy,
help="The write policy of hierarchical cache.",
)
parser.add_argument(
"--hicache-io-backend",
type=str,
choices=["direct", "kernel"],
default=ServerArgs.hicache_io_backend,
help="The IO backend for KV cache transfer between CPU and GPU",
)
parser.add_argument(
"--hicache-storage-backend",
type=str,
choices=["file"], # todo, mooncake
default=ServerArgs.hicache_storage_backend,
help="The storage backend for hierarchical KV cache.",
)
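A minimal sketch of enabling the hierarchical cache from the CLI (placeholder model path; per the help text above, `--hicache-size` takes precedence over `--hicache-ratio` when both are set):

```python
# Sketch: hierarchical-cache flags parsed into ServerArgs; values are illustrative only.
from sglang.srt.server_args import prepare_server_args

server_args = prepare_server_args(
    [
        "--model-path", "my-org/my-model",  # placeholder
        "--enable-hierarchical-cache",
        "--hicache-ratio", "2.0",           # host KV pool = 2x the device pool
        "--hicache-size", "64",             # 64 GB host pool; overrides the ratio per the help text
        "--hicache-write-policy", "write_through_selective",
    ]
)
print(server_args.enable_hierarchical_cache, server_args.hicache_size)
```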
# Double Sparsity # Double Sparsity
parser.add_argument( parser.add_argument(
"--enable-double-sparsity", "--enable-double-sparsity",
...@@ -1619,44 +1618,6 @@ class ServerArgs: ...@@ -1619,44 +1618,6 @@ class ServerArgs:
action="store_true", action="store_true",
help="Enable users to pass custom logit processors to the server (disabled by default for security)", help="Enable users to pass custom logit processors to the server (disabled by default for security)",
) )
parser.add_argument(
"--enable-hierarchical-cache",
action="store_true",
help="Enable hierarchical cache",
)
parser.add_argument(
"--hicache-ratio",
type=float,
default=ServerArgs.hicache_ratio,
help="The ratio of the size of host KV cache memory pool to the size of device pool.",
)
parser.add_argument(
"--hicache-size",
type=int,
default=ServerArgs.hicache_size,
help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
)
parser.add_argument(
"--hicache-write-policy",
type=str,
choices=["write_back", "write_through", "write_through_selective"],
default=ServerArgs.hicache_write_policy,
help="The write policy of hierarchical cache.",
)
parser.add_argument(
"--hicache-io-backend",
type=str,
choices=["direct", "kernel"],
default=ServerArgs.hicache_io_backend,
help="The IO backend for KV cache transfer between CPU and GPU",
)
parser.add_argument(
"--hicache-storage-backend",
type=str,
choices=["file"], # todo, mooncacke
default=ServerArgs.hicache_storage_backend,
help="The storage backend for hierarchical KV cache.",
)
parser.add_argument( parser.add_argument(
"--flashinfer-mla-disable-ragged", "--flashinfer-mla-disable-ragged",
action="store_true", action="store_true",
...@@ -1687,13 +1648,6 @@ class ServerArgs: ...@@ -1687,13 +1648,6 @@ class ServerArgs:
action="store_true", action="store_true",
help="Use triton moe grouped gemm kernel.", help="Use triton moe grouped gemm kernel.",
) )
parser.add_argument(
"--warmups",
type=str,
required=False,
help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
"will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
)
# Debug tensor dumps # Debug tensor dumps
parser.add_argument( parser.add_argument(
...@@ -1720,7 +1674,7 @@ class ServerArgs: ...@@ -1720,7 +1674,7 @@ class ServerArgs:
help="Only dump the tensors for prefill requests (i.e. batch size > 1).", help="Only dump the tensors for prefill requests (i.e. batch size > 1).",
) )
# Disaggregation # PD disaggregation
parser.add_argument( parser.add_argument(
"--disaggregation-mode", "--disaggregation-mode",
type=str, type=str,
...@@ -1779,6 +1733,8 @@ class ServerArgs: ...@@ -1779,6 +1733,8 @@ class ServerArgs:
default=None, default=None,
help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.", help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
) )
# Custom weight loader
parser.add_argument( parser.add_argument(
"--custom-weight-loader", "--custom-weight-loader",
type=str, type=str,
...@@ -1791,6 +1747,8 @@ class ServerArgs: ...@@ -1791,6 +1747,8 @@ class ServerArgs:
action="store_true", action="store_true",
help="Enable PD-Multiplexing, PD running on greenctx stream.", help="Enable PD-Multiplexing, PD running on greenctx stream.",
) )
# For PD-Multiplexing
parser.add_argument( parser.add_argument(
"--sm-group-num", "--sm-group-num",
type=int, type=int,
...@@ -1818,6 +1776,17 @@ class ServerArgs: ...@@ -1818,6 +1776,17 @@ class ServerArgs:
else: else:
return f"http://{self.host}:{self.port}" return f"http://{self.host}:{self.port}"
def get_hf_config(self):
kwargs = {}
hf_config = get_config(
self.model_path,
trust_remote_code=self.trust_remote_code,
revision=self.revision,
model_override_args=json.loads(self.json_model_override_args),
**kwargs,
)
return hf_config
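The new `get_hf_config` helper centralizes the Hugging Face config lookup that previously lived in the removed module-level `get_model_arch` (see below) and inline in `auto_choose_speculative_params`. A minimal usage sketch with a placeholder model path (resolving the config requires the model to be reachable locally or on the Hub):

```python
# Sketch: resolve the model architecture through the new ServerArgs.get_hf_config helper.
from sglang.srt.server_args import prepare_server_args

server_args = prepare_server_args(["--model-path", "my-org/my-model"])  # placeholder path
hf_config = server_args.get_hf_config()
print(hf_config.architectures[0])  # e.g. "LlamaForCausalLM"
```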
def check_server_args(self): def check_server_args(self):
assert ( assert (
self.tp_size * self.pp_size self.tp_size * self.pp_size
...@@ -1842,6 +1811,11 @@ class ServerArgs: ...@@ -1842,6 +1811,11 @@ class ServerArgs:
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative" assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
assert self.gpu_id_step >= 1, "gpu_id_step must be positive" assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
assert self.moe_dense_tp_size in {
1,
None,
}, "moe_dense_tp_size only support 1 and None currently"
if isinstance(self.lora_paths, list): if isinstance(self.lora_paths, list):
lora_paths = self.lora_paths lora_paths = self.lora_paths
self.lora_paths = {} self.lora_paths = {}
...@@ -1852,6 +1826,56 @@ class ServerArgs: ...@@ -1852,6 +1826,56 @@ class ServerArgs:
else: else:
self.lora_paths[lora_path] = lora_path self.lora_paths[lora_path] = lora_path
def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
larger_tp = max(decode_tp, prefill_tp)
smaller_tp = min(decode_tp, prefill_tp)
assert larger_tp % smaller_tp == 0, (
"Different tp size is supported only when one tp is multiple of the other. "
f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
)
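`validate_disagg_tp_size` only requires that the larger of the two TP sizes be an integer multiple of the smaller one; a self-contained sketch of which prefill/decode combinations pass that check:

```python
# Sketch: the divisibility rule enforced by validate_disagg_tp_size.
for prefill_tp, decode_tp in [(2, 8), (8, 2), (4, 4), (4, 6)]:
    larger, smaller = max(prefill_tp, decode_tp), min(prefill_tp, decode_tp)
    ok = larger % smaller == 0
    # (2, 8), (8, 2), and (4, 4) pass; (4, 6) is rejected because 6 % 4 != 0
    print(f"prefill_tp={prefill_tp}, decode_tp={decode_tp} -> {'ok' if ok else 'rejected'}")
```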
def adjust_mem_fraction_for_vlm(self, model_config):
vision_config = getattr(model_config.hf_config, "vision_config", None)
if vision_config is None:
return
# Roughly reduce mem_fraction_static based on the parameter count of the ViT.
original_server_arg_mem_fraction = self.mem_fraction_static
# A base mem_fraction_static reduction factor for a regular ViT.
base_mem_fraction_reduction_ratio = 0.95
vit_num_layers = getattr(vision_config, "num_hidden_layers", 24)
vit_hidden_size = getattr(vision_config, "hidden_size", 1024)
# baseline ViT params (ViT-L/14)
baseline_vit_layers = 24
baseline_vit_hidden_size = 1024
# weight params count
current_complexity_score = vit_num_layers * (vit_hidden_size**2)
baseline_complexity_score = baseline_vit_layers * (baseline_vit_hidden_size**2)
complexity_ratio = (
current_complexity_score / baseline_complexity_score
if baseline_complexity_score > 0
else 1.0
)
# Every time the complexity grows by 100%, adjust the final factor by 10%.
sensitivity_scale = 0.1
dynamic_adjustment_factor = 1.0 - sensitivity_scale * (complexity_ratio - 1.0)
dynamic_adjustment_factor = max(0.8, min(1.05, dynamic_adjustment_factor))
final_overall_factor = (
base_mem_fraction_reduction_ratio * dynamic_adjustment_factor
)
self.mem_fraction_static = (
original_server_arg_mem_fraction * final_overall_factor
)
logger.warning(
f"Multimodal model: Dynamically adjusted --mem-fraction-static "
f"from: {original_server_arg_mem_fraction:.3f} to: {self.mem_fraction_static:.3f}."
)
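A worked example of the adjustment above, with illustrative (not model-specific) ViT numbers: a 32-layer, 1280-hidden ViT gives a complexity ratio of 32·1280² / (24·1024²) ≈ 2.08 versus the ViT-L/14 baseline, so the dynamic factor is 1 − 0.1·(2.08 − 1) ≈ 0.89 and the overall reduction factor is 0.95·0.89 ≈ 0.85:

```python
# Sketch: the same arithmetic as adjust_mem_fraction_for_vlm, with hypothetical ViT numbers.
vit_num_layers, vit_hidden_size = 32, 1280    # hypothetical ViT config
baseline_layers, baseline_hidden = 24, 1024   # ViT-L/14 baseline used above
complexity_ratio = (vit_num_layers * vit_hidden_size**2) / (baseline_layers * baseline_hidden**2)
dynamic_factor = max(0.8, min(1.05, 1.0 - 0.1 * (complexity_ratio - 1.0)))
overall_factor = 0.95 * dynamic_factor
print(round(complexity_ratio, 2), round(overall_factor, 3))  # ~2.08, ~0.847
# e.g. an initial --mem-fraction-static of 0.88 would be lowered to ~0.88 * 0.847 ≈ 0.745
```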
def prepare_server_args(argv: List[str]) -> ServerArgs: def prepare_server_args(argv: List[str]) -> ServerArgs:
""" """
...@@ -1895,16 +1919,16 @@ class PortArgs: ...@@ -1895,16 +1919,16 @@ class PortArgs:
@staticmethod @staticmethod
def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs": def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
if server_args.nccl_port is None: if server_args.nccl_port is None:
port = server_args.port + random.randint(100, 1000) nccl_port = server_args.port + random.randint(100, 1000)
while True: while True:
if is_port_available(port): if is_port_available(nccl_port):
break break
if port < 60000: if nccl_port < 60000:
port += 42 nccl_port += 42
else: else:
port -= 43 nccl_port -= 43
else: else:
port = server_args.nccl_port nccl_port = server_args.nccl_port
if not server_args.enable_dp_attention: if not server_args.enable_dp_attention:
# Normal case, use IPC within a single node # Normal case, use IPC within a single node
...@@ -1912,7 +1936,7 @@ class PortArgs: ...@@ -1912,7 +1936,7 @@ class PortArgs:
tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
nccl_port=port, nccl_port=nccl_port,
rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
) )
...@@ -1942,7 +1966,7 @@ class PortArgs: ...@@ -1942,7 +1966,7 @@ class PortArgs:
tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}", tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}", scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}", detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
nccl_port=port, nccl_port=nccl_port,
rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}", rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
metrics_ipc_name=f"tcp://{dist_init_host}:{port_base + 3}", metrics_ipc_name=f"tcp://{dist_init_host}:{port_base + 3}",
) )
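For a single-node launch without `--enable-dp-attention`, `PortArgs.init_new` hands out IPC socket names and probes for a free NCCL port starting from a random offset above the server port. A minimal sketch (placeholder model path; the exact values differ per run):

```python
# Sketch: derive the inter-process communication endpoints from parsed server args.
from sglang.srt.server_args import PortArgs, prepare_server_args

server_args = prepare_server_args(["--model-path", "my-org/my-model"])  # placeholder path
port_args = PortArgs.init_new(server_args)
print(port_args.nccl_port)           # a probed free port derived from the server port
print(port_args.tokenizer_ipc_name)  # an ipc:// endpoint backed by a temporary file
```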
...@@ -1969,31 +1993,13 @@ class DeprecatedAction(argparse.Action): ...@@ -1969,31 +1993,13 @@ class DeprecatedAction(argparse.Action):
raise ValueError(self.help) raise ValueError(self.help)
def get_model_arch(args: ServerArgs):
hf_config = get_config(
args.model_path,
trust_remote_code=args.trust_remote_code,
revision=args.revision,
model_override_args=json.loads(args.json_model_override_args),
)
return hf_config.architectures[0]
def auto_choose_speculative_params(self: ServerArgs): def auto_choose_speculative_params(self: ServerArgs):
""" """
Automatically choose the parameters for speculative decoding. Automatically choose the parameters for speculative decoding.
You can tune them on your own models and prompts with scripts/playground/bench_speculative.py You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
""" """
kwargs = {} hf_config = self.get_hf_config()
hf_config = get_config(
self.model_path,
trust_remote_code=self.trust_remote_code,
revision=self.revision,
model_override_args=json.loads(self.json_model_override_args),
**kwargs,
)
arch = hf_config.architectures[0] arch = hf_config.architectures[0]
if arch in ["LlamaForCausalLM"]: if arch in ["LlamaForCausalLM"]:
......
...@@ -481,7 +481,7 @@ class SRTRunner: ...@@ -481,7 +481,7 @@ class SRTRunner:
torch_dtype: torch.dtype, torch_dtype: torch.dtype,
model_type: str, model_type: str,
tp_size: int = 1, tp_size: int = 1,
impl: str = "auto", model_impl: str = "auto",
port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER, port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
lora_paths: List[str] = None, lora_paths: List[str] = None,
max_loras_per_batch: int = 4, max_loras_per_batch: int = 4,
...@@ -525,7 +525,7 @@ class SRTRunner: ...@@ -525,7 +525,7 @@ class SRTRunner:
tp_size=tp_size, tp_size=tp_size,
dtype=get_dtype_str(torch_dtype), dtype=get_dtype_str(torch_dtype),
port=port, port=port,
impl=impl, model_impl=model_impl,
torchao_config=torchao_config, torchao_config=torchao_config,
mem_fraction_static=mem_fraction_static, mem_fraction_static=mem_fraction_static,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
......
...@@ -27,7 +27,7 @@ class TestTransformersFallbackEndpoint(CustomTestCase): ...@@ -27,7 +27,7 @@ class TestTransformersFallbackEndpoint(CustomTestCase):
cls.model, cls.model,
cls.base_url, cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--impl", "transformers"], other_args=["--model-impl", "transformers"],
) )
cls.mmlu_lower_bound = 0.65 cls.mmlu_lower_bound = 0.65
cls.gsm8k_lower_bound = 0.65 cls.gsm8k_lower_bound = 0.65
...@@ -76,7 +76,7 @@ class TestTransformersFallbackTorchAO(TestTransformersFallbackEndpoint): ...@@ -76,7 +76,7 @@ class TestTransformersFallbackTorchAO(TestTransformersFallbackEndpoint):
cls.base_url, cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[ other_args=[
"--impl", "--model-impl",
"transformers", "transformers",
"--torchao-config", "--torchao-config",
"int4wo-128", "int4wo-128",
...@@ -127,7 +127,7 @@ class TestTransformersFallbackEngine(CustomTestCase): ...@@ -127,7 +127,7 @@ class TestTransformersFallbackEngine(CustomTestCase):
tp_size=model_case.tp_size, tp_size=model_case.tp_size,
torch_dtype=model_case.torch_dtype, torch_dtype=model_case.torch_dtype,
model_type="generation", model_type="generation",
impl="transformers", model_impl="transformers",
trust_remote_code=model_case.trust_remote_code, trust_remote_code=model_case.trust_remote_code,
torchao_config=model_case.torchao_config, torchao_config=model_case.torchao_config,
) as srt_runner: ) as srt_runner:
......