Commit ac238727 authored by Lianmin Zheng

Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
parent 0194948f
...@@ -30,11 +30,20 @@ def get_model_config(model_name: str, tp_size: int): ...@@ -30,11 +30,20 @@ def get_model_config(model_name: str, tp_size: int):
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = config.n_routed_experts E = config.n_routed_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // tp_size
else: else:
# Default: Mixtral # Default: Mixtral
E = config.num_local_experts E = config.num_local_experts
......
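The three hunks above and below apply the same change to different benchmark scripts: the Grok-1 architectures reuse the Qwen2-MoE-style expert fields. The sketch that follows is not the benchmark code itself, only a minimal illustration of the shared dispatch pattern with a stand-in config object.

```python
# Minimal sketch (not the benchmark script) of the MoE shape dispatch these
# hunks extend; SimpleNamespace stands in for a HuggingFace config object.
from types import SimpleNamespace

def moe_shape(config, tp_size: int):
    arch = config.architectures[0]
    if arch in ("Grok1ForCausalLM", "Grok1ImgGen", "Grok1AForCausalLM"):
        E = config.num_local_experts
        intermediate_size = config.moe_intermediate_size
    elif arch in ("DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"):
        E = config.n_routed_experts
        intermediate_size = config.intermediate_size
    else:  # default: Mixtral-style config
        E = config.num_local_experts
        intermediate_size = config.intermediate_size
    topk = config.num_experts_per_tok
    # Gate and up projections are fused, hence the factor of 2 before sharding.
    shard_intermediate_size = 2 * intermediate_size // tp_size
    return E, topk, shard_intermediate_size

cfg = SimpleNamespace(
    architectures=["Grok1ForCausalLM"],
    num_local_experts=8,
    num_experts_per_tok=2,
    moe_intermediate_size=32768,
)
print(moe_shape(cfg, tp_size=8))  # (8, 2, 8192)
```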
...@@ -35,6 +35,15 @@ def get_model_config(model_name: str, tp_size: int): ...@@ -35,6 +35,15 @@ def get_model_config(model_name: str, tp_size: int):
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
else: else:
# Default: Mixtral # Default: Mixtral
E = config.num_local_experts E = config.num_local_experts
......
...@@ -397,6 +397,15 @@ def main(args: argparse.Namespace): ...@@ -397,6 +397,15 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else: else:
# Default: Mixtral # Default: Mixtral
E = config.num_local_experts E = config.num_local_experts
......
...@@ -210,8 +210,7 @@ ...@@ -210,8 +210,7 @@
"response = requests.post(url, json=data)\n", "response = requests.post(url, json=data)\n",
"print_highlight(response.text)\n", "print_highlight(response.text)\n",
"assert response.json()[\"success\"] is True\n", "assert response.json()[\"success\"] is True\n",
"assert response.json()[\"message\"] == \"Succeeded to update model weights.\"\n", "assert response.json()[\"message\"] == \"Succeeded to update model weights.\""
"assert response.json().keys() == {\"success\", \"message\"}"
] ]
}, },
{ {
...@@ -411,7 +410,7 @@ ...@@ -411,7 +410,7 @@
" },\n", " },\n",
")\n", ")\n",
"output = response.json()\n", "output = response.json()\n",
"output_tokens = output[\"token_ids\"]\n", "output_tokens = output[\"output_ids\"]\n",
"\n", "\n",
"output_text = tokenizer.decode(output_tokens, skip_special_tokens=False)\n", "output_text = tokenizer.decode(output_tokens, skip_special_tokens=False)\n",
"print_highlight(f\"Tokenized Output: {output_tokens}\")\n", "print_highlight(f\"Tokenized Output: {output_tokens}\")\n",
......
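The notebook cell above now reads the generated ids from the "output_ids" field of the /generate response instead of "token_ids". Below is a hedged, standalone sketch of the same round trip; the server URL and model name are placeholders, and the request shape follows the notebook (pre-tokenized input).

```python
# Hedged sketch of the notebook round trip above: send pre-tokenized input to a
# locally running server and decode the ids returned under "output_ids".
# URL and model name are placeholders, not part of the original diff.
import requests
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_ids = tokenizer.encode("The capital of France is")

response = requests.post(
    "http://localhost:30000/generate",
    json={"input_ids": input_ids, "sampling_params": {"max_new_tokens": 16}},
)
output = response.json()
output_tokens = output["output_ids"]  # previously exposed as "token_ids"
print(f"Tokenized Output: {output_tokens}")
print(tokenizer.decode(output_tokens, skip_special_tokens=False))
```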
...@@ -96,7 +96,6 @@ Please consult the documentation below to learn more about the parameters you ma ...@@ -96,7 +96,6 @@ Please consult the documentation below to learn more about the parameters you ma
* `schedule_policy`: The scheduling policy to control the processing order of waiting prefill requests in a single engine. * `schedule_policy`: The scheduling policy to control the processing order of waiting prefill requests in a single engine.
* `schedule_conservativeness`: Can be used to decrease/increase the conservativeness of the server when taking new requests. Highly conservative behavior leads to starvation, but low conservativeness leads to slowed-down performance. * `schedule_conservativeness`: Can be used to decrease/increase the conservativeness of the server when taking new requests. Highly conservative behavior leads to starvation, but low conservativeness leads to slowed-down performance.
* `cpu_offload_gb`: Reserve this amount of RAM in GB for offloading of model parameters to the CPU. * `cpu_offload_gb`: Reserve this amount of RAM in GB for offloading of model parameters to the CPU.
* `prefill_only_one_req`: When this flag is turned on, the engine prefills only one request at a time.
## Other runtime options ## Other runtime options
......
...@@ -96,7 +96,10 @@ dev_cpu = ["sglang[all_cpu]", "sglang[test]"] ...@@ -96,7 +96,10 @@ dev_cpu = ["sglang[all_cpu]", "sglang[test]"]
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues" "Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
[tool.setuptools.package-data] [tool.setuptools.package-data]
"sglang" = ["srt/layers/moe/fused_moe_triton/configs/*.json", "srt/layers/quantization/configs/*.json"] "sglang" = [
"srt/layers/moe/fused_moe_triton/configs/*.json",
"srt/layers/quantization/configs/*.json",
]
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
exclude = [ exclude = [
......
...@@ -8,8 +8,10 @@ ...@@ -8,8 +8,10 @@
- `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server. - `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server.
- `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server. - `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server.
- `bench_serving.py`: Benchmark online serving with dynamic requests. - `bench_serving.py`: Benchmark online serving with dynamic requests.
- `check_env.py`: Check the environment variables. - `check_env.py`: Check the environment variables and dependencies.
- `global_config.py`: The global configs and constants. - `global_config.py`: The global configs and constants.
- `launch_server.py`: The entry point for launching the local server. - `launch_server.py`: The entry point for launching the local server.
- `llama3_eval.py`: Evaluation of Llama 3 using the Meta Llama dataset. - `llama3_eval.py`: Evaluation of Llama 3 using the Meta Llama dataset.
- `profiler.py`: Profile a running server.
- `utils.py`: Common utilities. - `utils.py`: Common utilities.
- `version.py`: Version info.
...@@ -56,6 +56,7 @@ class BenchArgs: ...@@ -56,6 +56,7 @@ class BenchArgs:
profile: bool = False profile: bool = False
skip_warmup: bool = False skip_warmup: bool = False
do_not_exit: bool = False do_not_exit: bool = False
prompt_suffix: str = ""
@staticmethod @staticmethod
def add_cli_args(parser: argparse.ArgumentParser): def add_cli_args(parser: argparse.ArgumentParser):
...@@ -177,6 +178,12 @@ class BenchArgs: ...@@ -177,6 +178,12 @@ class BenchArgs:
action="store_true", action="store_true",
help="Do not exit the program. This is useful for nsys profile with --duration and --delay.", help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
) )
parser.add_argument(
"--prompt-suffix",
type=str,
default="",
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
...@@ -216,6 +223,10 @@ def throughput_test_once( ...@@ -216,6 +223,10 @@ def throughput_test_once(
] ]
if profile: if profile:
assert (
"SGLANG_TORCH_PROFILER_DIR" in os.environ
), "Please set SGLANG_TORCH_PROFILER_DIR."
os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True)
backend.start_profile() backend.start_profile()
st = time.perf_counter() st = time.perf_counter()
...@@ -229,6 +240,8 @@ def throughput_test_once( ...@@ -229,6 +240,8 @@ def throughput_test_once(
if backend_name == "runtime": if backend_name == "runtime":
gen_out = json.loads(gen_out) gen_out = json.loads(gen_out)
server_info = backend.get_server_info()
measurement_results["total_latency"] = latency measurement_results["total_latency"] = latency
measurement_results["total_output_tokens"] = sum( measurement_results["total_output_tokens"] = sum(
o["meta_info"]["completion_tokens"] for o in gen_out o["meta_info"]["completion_tokens"] for o in gen_out
...@@ -246,6 +259,7 @@ def throughput_test_once( ...@@ -246,6 +259,7 @@ def throughput_test_once(
measurement_results["total_input_tokens"] measurement_results["total_input_tokens"]
+ measurement_results["total_output_tokens"] + measurement_results["total_output_tokens"]
) / latency ) / latency
measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]
return measurement_results return measurement_results
...@@ -361,6 +375,11 @@ def throughput_test( ...@@ -361,6 +375,11 @@ def throughput_test(
print( print(
"{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"]) "{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
) )
print(
"{:<40} {:<10.2f}".format(
"Last generation throughput (tok/s):", result["last_gen_throughput"]
)
)
print( print(
"{:<40} {:<10.2f}".format( "{:<40} {:<10.2f}".format(
"Request throughput (req/s):", result["request_throughput"] "Request throughput (req/s):", result["request_throughput"]
......
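The offline throughput benchmark above now requires SGLANG_TORCH_PROFILER_DIR before profiling (creating the directory if needed) and reports the server's last_gen_throughput. A hedged usage sketch follows, assuming this file is invoked as sglang.bench_offline_throughput; the model and trace directory are placeholders.

```python
# Hedged usage sketch: set the profiler output directory before passing
# --profile, as required by the new assertion above. Module name, model, and
# directory are assumptions for illustration.
import os
import subprocess

os.environ["SGLANG_TORCH_PROFILER_DIR"] = "/tmp/sglang_traces"
subprocess.run(
    [
        "python3", "-m", "sglang.bench_offline_throughput",
        "--model-path", "meta-llama/Llama-3.1-8B-Instruct",
        "--num-prompts", "64",
        "--profile",
    ],
    check=True,
)
```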
...@@ -8,7 +8,6 @@ Usage: ...@@ -8,7 +8,6 @@ Usage:
python3 -m sglang.bench_serving --backend sglang --num-prompt 10 python3 -m sglang.bench_serving --backend sglang --num-prompt 10
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5 python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi
""" """
import argparse import argparse
...@@ -71,6 +70,10 @@ def remove_prefix(text: str, prefix: str) -> str: ...@@ -71,6 +70,10 @@ def remove_prefix(text: str, prefix: str) -> str:
return text[len(prefix) :] if text.startswith(prefix) else text return text[len(prefix) :] if text.startswith(prefix) else text
def remove_suffix(text: str, suffix: str) -> str:
return text[: -len(suffix)] if text.endswith(suffix) else text
def get_auth_headers() -> Dict[str, str]: def get_auth_headers() -> Dict[str, str]:
api_key = os.environ.get("OPENAI_API_KEY") api_key = os.environ.get("OPENAI_API_KEY")
if api_key: if api_key:
...@@ -79,7 +82,7 @@ def get_auth_headers() -> Dict[str, str]: ...@@ -79,7 +82,7 @@ def get_auth_headers() -> Dict[str, str]:
return {} return {}
# trt llm not support ignore_eos # trt llm does not support ignore_eos
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505 # https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
async def async_request_trt_llm( async def async_request_trt_llm(
request_func_input: RequestFuncInput, request_func_input: RequestFuncInput,
...@@ -179,6 +182,7 @@ async def async_request_openai_completions( ...@@ -179,6 +182,7 @@ async def async_request_openai_completions(
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
generated_text = "" generated_text = ""
output_len = request_func_input.output_len
ttft = 0.0 ttft = 0.0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
...@@ -215,11 +219,14 @@ async def async_request_openai_completions( ...@@ -215,11 +219,14 @@ async def async_request_openai_completions(
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
generated_text += data["choices"][0]["text"] generated_text += data["choices"][0]["text"]
output_len = data.get("usage", {}).get(
"completion_tokens", output_len
)
output.generated_text = generated_text output.generated_text = generated_text
output.success = True output.success = True
output.latency = latency output.latency = latency
output.output_len = request_func_input.output_len output.output_len = output_len
else: else:
output.error = response.reason or "" output.error = response.reason or ""
output.success = False output.success = False
...@@ -339,9 +346,11 @@ async def async_request_sglang_generate( ...@@ -339,9 +346,11 @@ async def async_request_sglang_generate(
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
generated_text = "" generated_text = ""
output_len = request_func_input.output_len
ttft = 0.0 ttft = 0.0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
last_output_len = 0
try: try:
async with session.post( async with session.post(
url=api_url, json=payload, headers=headers url=api_url, json=payload, headers=headers
...@@ -365,6 +374,9 @@ async def async_request_sglang_generate( ...@@ -365,6 +374,9 @@ async def async_request_sglang_generate(
# want to check a token was generated # want to check a token was generated
if data["text"]: if data["text"]:
timestamp = time.perf_counter() timestamp = time.perf_counter()
generated_text = data["text"]
output_len = data["meta_info"]["completion_tokens"]
# First token # First token
if ttft == 0.0: if ttft == 0.0:
ttft = time.perf_counter() - st ttft = time.perf_counter() - st
...@@ -372,7 +384,13 @@ async def async_request_sglang_generate( ...@@ -372,7 +384,13 @@ async def async_request_sglang_generate(
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - most_recent_timestamp) num_new_tokens = output_len - last_output_len
if num_new_tokens == 0:
continue
adjust_itl = (
timestamp - most_recent_timestamp
) / num_new_tokens
output.itl.extend([adjust_itl] * num_new_tokens)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
generated_text = data["text"] generated_text = data["text"]
...@@ -380,7 +398,7 @@ async def async_request_sglang_generate( ...@@ -380,7 +398,7 @@ async def async_request_sglang_generate(
output.generated_text = generated_text output.generated_text = generated_text
output.success = True output.success = True
output.latency = latency output.latency = latency
output.output_len = request_func_input.output_len output.output_len = output_len
else: else:
output.error = response.reason or "" output.error = response.reason or ""
output.success = False output.success = False
...@@ -388,6 +406,7 @@ async def async_request_sglang_generate( ...@@ -388,6 +406,7 @@ async def async_request_sglang_generate(
output.success = False output.success = False
exc_info = sys.exc_info() exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info)) output.error = "".join(traceback.format_exception(*exc_info))
print(f"{output.error=}")
if pbar: if pbar:
pbar.update(1) pbar.update(1)
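When a streamed chunk delivers more than one new token (e.g. with chunked prefill or speculative decoding), the loop above now divides the inter-chunk time evenly across those tokens instead of recording one oversized ITL sample. A self-contained sketch of that bookkeeping:

```python
# Minimal sketch of the ITL adjustment above: split the time between chunks
# evenly across the new tokens each chunk carries; skip heartbeat chunks that
# carry no new tokens.
def spread_itl(chunk_gaps, new_token_counts):
    itls = []
    for gap, n in zip(chunk_gaps, new_token_counts):
        if n == 0:
            continue
        itls.extend([gap / n] * n)
    return itls

# Three chunks: 1 token, then 3 tokens, then a keep-alive with none.
print(spread_itl([0.05, 0.12, 0.03], [1, 3, 0]))
# -> [0.05, 0.04, 0.04, 0.04]
```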
...@@ -461,6 +480,7 @@ def get_dataset(args, tokenizer): ...@@ -461,6 +480,7 @@ def get_dataset(args, tokenizer):
tokenizer=tokenizer, tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len, fixed_output_len=args.sharegpt_output_len,
context_len=args.sharegpt_context_len, context_len=args.sharegpt_context_len,
prompt_suffix=args.prompt_suffix,
apply_chat_template=args.apply_chat_template, apply_chat_template=args.apply_chat_template,
) )
elif args.dataset_name == "random": elif args.dataset_name == "random":
...@@ -521,7 +541,9 @@ class BenchmarkMetrics: ...@@ -521,7 +541,9 @@ class BenchmarkMetrics:
mean_itl_ms: float mean_itl_ms: float
median_itl_ms: float median_itl_ms: float
std_itl_ms: float std_itl_ms: float
p95_itl_ms: float
p99_itl_ms: float p99_itl_ms: float
max_itl_ms: float
mean_e2e_latency_ms: float mean_e2e_latency_ms: float
median_e2e_latency_ms: float median_e2e_latency_ms: float
std_e2e_latency_ms: float std_e2e_latency_ms: float
...@@ -572,6 +594,7 @@ def sample_sharegpt_requests( ...@@ -572,6 +594,7 @@ def sample_sharegpt_requests(
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None, fixed_output_len: Optional[int] = None,
context_len: Optional[int] = None, context_len: Optional[int] = None,
prompt_suffix: Optional[str] = "",
apply_chat_template=False, apply_chat_template=False,
) -> List[Tuple[str, int, int]]: ) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4: if fixed_output_len is not None and fixed_output_len < 4:
...@@ -584,11 +607,19 @@ def sample_sharegpt_requests( ...@@ -584,11 +607,19 @@ def sample_sharegpt_requests(
# Load the dataset. # Load the dataset.
with open(dataset_path) as f: with open(dataset_path) as f:
dataset = json.load(f) dataset = json.load(f)
# Filter out the conversations with less than 2 turns. # Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2] dataset = [
data
for data in dataset
if len(data.get("conversations", data.get("conversation", []))) >= 2
]
# Only keep the first two turns of each conversation. # Only keep the first two turns of each conversation.
dataset = [ dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"]) (
data.get("conversations", data.get("conversation", []))[0]["value"],
data.get("conversations", data.get("conversation", []))[1]["value"],
)
for data in dataset for data in dataset
] ]
...@@ -603,6 +634,8 @@ def sample_sharegpt_requests( ...@@ -603,6 +634,8 @@ def sample_sharegpt_requests(
# Tokenize the prompts and completions. # Tokenize the prompts and completions.
prompt = dataset[i][0] prompt = dataset[i][0]
if prompt_suffix:
prompt = prompt + prompt_suffix
if apply_chat_template: if apply_chat_template:
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
...@@ -666,10 +699,17 @@ def sample_random_requests( ...@@ -666,10 +699,17 @@ def sample_random_requests(
with open(dataset_path) as f: with open(dataset_path) as f:
dataset = json.load(f) dataset = json.load(f)
# Filter out the conversations with less than 2 turns. # Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2] dataset = [
data
for data in dataset
if len(data.get("conversations", data.get("conversation", []))) >= 2
]
# Only keep the first two turns of each conversation. # Only keep the first two turns of each conversation.
dataset = [ dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"]) (
data.get("conversations", data.get("conversation", []))[0]["value"],
data.get("conversations", data.get("conversation", []))[1]["value"],
)
for data in dataset for data in dataset
] ]
# Shuffle the dataset. # Shuffle the dataset.
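Both sampling functions above now accept ShareGPT-style dumps that name the turn list either "conversations" or "conversation". A small sketch of the shared normalization:

```python
# Sketch of the key normalization used above: tolerate both "conversations"
# and "conversation", drop records with fewer than two turns, and keep only
# the first two turns.
def first_two_turns(record):
    turns = record.get("conversations", record.get("conversation", []))
    if len(turns) < 2:
        return None  # filtered out
    return turns[0]["value"], turns[1]["value"]

print(first_two_turns({"conversation": [{"value": "Hi"}, {"value": "Hello!"}]}))
print(first_two_turns({"conversations": [{"value": "only one turn"}]}))  # None
```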
...@@ -895,7 +935,9 @@ def calculate_metrics( ...@@ -895,7 +935,9 @@ def calculate_metrics(
mean_itl_ms=np.mean(itls or 0) * 1000, mean_itl_ms=np.mean(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000, median_itl_ms=np.median(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000,
p95_itl_ms=np.percentile(itls or 0, 95) * 1000,
p99_itl_ms=np.percentile(itls or 0, 99) * 1000, p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
max_itl_ms=np.max(itls or 0) * 1000,
mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000, mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
median_e2e_latency_ms=np.median(e2e_latencies) * 1000, median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
std_e2e_latency_ms=np.std(e2e_latencies) * 1000, std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
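The metrics above gain p95 and max ITL alongside p99; the `itls or 0` guard keeps numpy from failing when no inter-token latencies were collected. A tiny sketch with made-up values:

```python
# Sketch of the new ITL summary statistics; the latencies below are made up.
import numpy as np

itls = [0.012, 0.018, 0.011, 0.045]  # seconds
print("P95 ITL (ms):", np.percentile(itls or 0, 95) * 1000)
print("P99 ITL (ms):", np.percentile(itls or 0, 99) * 1000)
print("Max ITL (ms):", np.max(itls or 0) * 1000)
```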
...@@ -919,6 +961,7 @@ async def benchmark( ...@@ -919,6 +961,7 @@ async def benchmark(
lora_name: str, lora_name: str,
extra_request_body: Dict[str, Any], extra_request_body: Dict[str, Any],
profile: bool, profile: bool,
pd_seperated: bool = False,
): ):
if backend in ASYNC_REQUEST_FUNCS: if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend] request_func = ASYNC_REQUEST_FUNCS[backend]
...@@ -1004,6 +1047,17 @@ async def benchmark( ...@@ -1004,6 +1047,17 @@ async def benchmark(
if pbar is not None: if pbar is not None:
pbar.close() pbar.close()
if "sglang" in backend:
server_info = requests.get(base_url + "/get_server_info")
if pd_seperated:
accept_length = server_info.json()["decode"][0].get(
"avg_spec_accept_length", None
)
else:
accept_length = server_info.json().get("avg_spec_accept_length", None)
else:
accept_length = None
# Compute metrics and print results # Compute metrics and print results
benchmark_duration = time.perf_counter() - benchmark_start_time benchmark_duration = time.perf_counter() - benchmark_start_time
metrics, output_lens = calculate_metrics( metrics, output_lens = calculate_metrics(
...@@ -1053,6 +1107,8 @@ async def benchmark( ...@@ -1053,6 +1107,8 @@ async def benchmark(
) )
) )
print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency)) print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
if accept_length:
print("{:<40} {:<10.2f}".format("Accept length:", accept_length))
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-")) print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
print( print(
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms) "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
...@@ -1066,16 +1122,12 @@ async def benchmark( ...@@ -1066,16 +1122,12 @@ async def benchmark(
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms)) print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
print( print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
"{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
)
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
print("{:<40} {:<10.2f}".format("P95 ITL (ms):", metrics.p95_itl_ms))
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
print("{:<40} {:<10.2f}".format("Max ITL (ms):", metrics.max_itl_ms))
print("=" * 50) print("=" * 50)
if ( if (
...@@ -1117,8 +1169,10 @@ async def benchmark( ...@@ -1117,8 +1169,10 @@ async def benchmark(
"mean_itl_ms": metrics.mean_itl_ms, "mean_itl_ms": metrics.mean_itl_ms,
"median_itl_ms": metrics.median_itl_ms, "median_itl_ms": metrics.median_itl_ms,
"std_itl_ms": metrics.std_itl_ms, "std_itl_ms": metrics.std_itl_ms,
"p95_itl_ms": metrics.p95_itl_ms,
"p99_itl_ms": metrics.p99_itl_ms, "p99_itl_ms": metrics.p99_itl_ms,
"concurrency": metrics.concurrency, "concurrency": metrics.concurrency,
"accept_length": accept_length,
} }
else: else:
print(f"Error running benchmark for request rate: {request_rate}") print(f"Error running benchmark for request rate: {request_rate}")
...@@ -1151,14 +1205,6 @@ async def benchmark( ...@@ -1151,14 +1205,6 @@ async def benchmark(
return result return result
def parse_request_rate_range(request_rate_range):
if len(request_rate_range.split(",")) == 3:
start, stop, step = map(int, request_rate_range.split(","))
return list(range(start, stop, step))
else:
return list(map(int, request_rate_range.split(",")))
def check_chat_template(model_path): def check_chat_template(model_path):
try: try:
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
...@@ -1168,6 +1214,12 @@ def check_chat_template(model_path): ...@@ -1168,6 +1214,12 @@ def check_chat_template(model_path):
return False return False
def set_global_args(args_: argparse.Namespace):
"""Set the global args."""
global args
args = args_
def run_benchmark(args_: argparse.Namespace): def run_benchmark(args_: argparse.Namespace):
global args global args
args = args_ args = args_
...@@ -1176,6 +1228,8 @@ def run_benchmark(args_: argparse.Namespace): ...@@ -1176,6 +1228,8 @@ def run_benchmark(args_: argparse.Namespace):
if not hasattr(args, "max_concurrency"): if not hasattr(args, "max_concurrency"):
args.max_concurrency = None args.max_concurrency = None
print(f"benchmark_args={args}")
# Set global environments # Set global environments
set_ulimit() set_ulimit()
random.seed(args.seed) random.seed(args.seed)
...@@ -1272,49 +1326,26 @@ def run_benchmark(args_: argparse.Namespace): ...@@ -1272,49 +1326,26 @@ def run_benchmark(args_: argparse.Namespace):
backend = args.backend backend = args.backend
model_id = args.model model_id = args.model
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
tokenizer = get_tokenizer(tokenizer_id) tokenizer = get_tokenizer(tokenizer_id)
input_requests = get_dataset(args, tokenizer) input_requests = get_dataset(args, tokenizer)
if not args.multi: return asyncio.run(
return asyncio.run( benchmark(
benchmark( backend=backend,
backend=backend, api_url=api_url,
api_url=api_url, base_url=base_url,
base_url=base_url, model_id=model_id,
model_id=model_id, tokenizer=tokenizer,
tokenizer=tokenizer, input_requests=input_requests,
input_requests=input_requests, request_rate=args.request_rate,
request_rate=args.request_rate, max_concurrency=args.max_concurrency,
max_concurrency=args.max_concurrency, disable_tqdm=args.disable_tqdm,
disable_tqdm=args.disable_tqdm, lora_name=args.lora_name,
lora_name=args.lora_name, extra_request_body=extra_request_body,
extra_request_body=extra_request_body, profile=args.profile,
profile=args.profile, pd_seperated=args.pd_seperated,
)
) )
else: )
# Benchmark multiple rps. TODO: use a fixed duration to compute num_prompts
request_rates = parse_request_rate_range(args.request_rate_range)
for rate in request_rates:
asyncio.run(
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=rate,
max_concurrency=args.max_concurrency,
disable_tqdm=args.disable_tqdm,
lora_name=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
)
)
def set_ulimit(target_soft_limit=65535): def set_ulimit(target_soft_limit=65535):
...@@ -1428,17 +1459,6 @@ if __name__ == "__main__": ...@@ -1428,17 +1459,6 @@ if __name__ == "__main__":
"actual request rate may be lower than specified with --request-rate, " "actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.", "if the server is not processing requests fast enough to keep up.",
) )
parser.add_argument(
"--multi",
action="store_true",
help="Use request rate range rather than single value.",
)
parser.add_argument(
"--request-rate-range",
type=str,
default="2,34,2",
help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.",
)
parser.add_argument("--output-file", type=str, help="Output JSONL file name.") parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
parser.add_argument( parser.add_argument(
"--disable-tqdm", "--disable-tqdm",
...@@ -1485,6 +1505,17 @@ if __name__ == "__main__": ...@@ -1485,6 +1505,17 @@ if __name__ == "__main__":
default=None, default=None,
help="The name of LoRA adapter", help="The name of LoRA adapter",
) )
parser.add_argument(
"--prompt-suffix",
type=str,
default="",
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
)
parser.add_argument(
"--pd-seperated",
action="store_true",
help="Benchmark PD disaggregation server",
)
group = parser.add_argument_group("generated-shared-prefix dataset arguments") group = parser.add_argument_group("generated-shared-prefix dataset arguments")
group.add_argument( group.add_argument(
......
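bench_serving now queries /get_server_info after the run to report the average speculative-decoding accept length, reading it from the first decode worker when --pd-seperated is set. A hedged sketch of that lookup; the base URL is a placeholder.

```python
# Hedged sketch of the accept-length lookup added above. The base URL is a
# placeholder; avg_spec_accept_length is only present when speculative
# decoding is enabled.
import requests

def get_accept_length(base_url: str, pd_seperated: bool = False):
    info = requests.get(base_url + "/get_server_info").json()
    if pd_seperated:
        # PD-disaggregated servers report per-decode-worker statistics.
        return info["decode"][0].get("avg_spec_accept_length")
    return info.get("avg_spec_accept_length")

# Example (requires a running server):
# print(get_accept_length("http://127.0.0.1:30000"))
```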
...@@ -34,11 +34,9 @@ class GlobalConfig: ...@@ -34,11 +34,9 @@ class GlobalConfig:
self.skip_special_tokens_in_output = True self.skip_special_tokens_in_output = True
self.spaces_between_special_tokens_in_out = True self.spaces_between_special_tokens_in_out = True
# Interpreter optimization configs # Language frontend interpreter optimization configs
self.enable_precache_with_tracing = True self.enable_precache_with_tracing = True
self.enable_parallel_encoding = True self.enable_parallel_encoding = True
self.enable_flashinfer_mla = False
global_config = GlobalConfig() global_config = GlobalConfig()
...@@ -329,7 +329,12 @@ class RuntimeEndpoint(BaseBackend): ...@@ -329,7 +329,12 @@ class RuntimeEndpoint(BaseBackend):
def compute_normalized_prompt_logprobs(input_logprobs): def compute_normalized_prompt_logprobs(input_logprobs):
values = [x[0] for x in input_logprobs if x[0]] values = [x[0] for x in input_logprobs if x[0]]
return sum(values) / len(values) try:
return sum(values) / len(values)
except TypeError:
print(f"{input_logprobs=}", flush=True)
print(f"{input_logprobs[0]=}", flush=True)
exit(-1)
class Runtime: class Runtime:
......
...@@ -21,6 +21,7 @@ class LoadFormat(str, enum.Enum): ...@@ -21,6 +21,7 @@ class LoadFormat(str, enum.Enum):
BITSANDBYTES = "bitsandbytes" BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral" MISTRAL = "mistral"
LAYERED = "layered" LAYERED = "layered"
JAX = "jax"
@dataclass @dataclass
...@@ -42,13 +43,15 @@ class LoadConfig: ...@@ -42,13 +43,15 @@ class LoadConfig:
ignore_patterns: The list of patterns to ignore when loading the model. ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's Default to "original/**/*" to avoid repeated loading of llama's
checkpoints. checkpoints.
decryption_key_file: If set, decrypts the output files with a password read
from this file (after PBKDF2).
""" """
load_format: Union[str, LoadFormat] = LoadFormat.AUTO load_format: Union[str, LoadFormat] = LoadFormat.AUTO
download_dir: Optional[str] = None download_dir: Optional[str] = None
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict) model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
ignore_patterns: Optional[Union[List[str], str]] = None ignore_patterns: Optional[Union[List[str], str]] = None
decryption_key_file: Optional[str] = None
def __post_init__(self): def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {} model_loader_extra_config = self.model_loader_extra_config or {}
......
...@@ -44,6 +44,7 @@ class ModelConfig: ...@@ -44,6 +44,7 @@ class ModelConfig:
is_embedding: Optional[bool] = None, is_embedding: Optional[bool] = None,
dtype: str = "auto", dtype: str = "auto",
quantization: Optional[str] = None, quantization: Optional[str] = None,
override_config_file: Optional[str] = None,
) -> None: ) -> None:
self.model_path = model_path self.model_path = model_path
self.revision = revision self.revision = revision
...@@ -51,11 +52,16 @@ class ModelConfig: ...@@ -51,11 +52,16 @@ class ModelConfig:
# Parse args # Parse args
self.model_override_args = json.loads(model_override_args) self.model_override_args = json.loads(model_override_args)
kwargs = {}
if override_config_file and override_config_file.strip():
kwargs["_configuration_file"] = override_config_file.strip()
self.hf_config = get_config( self.hf_config = get_config(
model_path, model_path,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
revision=revision, revision=revision,
model_override_args=self.model_override_args, model_override_args=self.model_override_args,
**kwargs,
) )
self.hf_text_config = get_hf_text_config(self.hf_config) self.hf_text_config = get_hf_text_config(self.hf_config)
...@@ -64,6 +70,9 @@ class ModelConfig: ...@@ -64,6 +70,9 @@ class ModelConfig:
self.hf_config.architectures, is_embedding self.hf_config.architectures, is_embedding
) )
self.is_multimodal = is_multimodal_model(self.hf_config.architectures) self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures)
self.is_image_gen = is_image_gen_model(self.hf_config.architectures)
self.is_audio_model = is_audio_model(self.hf_config.architectures)
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures) self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
...@@ -71,7 +80,9 @@ class ModelConfig: ...@@ -71,7 +80,9 @@ class ModelConfig:
derived_context_len = get_context_length(self.hf_text_config) derived_context_len = get_context_length(self.hf_text_config)
if context_length is not None: if context_length is not None:
if context_length > derived_context_len: if context_length > derived_context_len:
if get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"): if get_bool_env_var(
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="False"
):
logger.warning( logger.warning(
f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). " f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
f"This may lead to incorrect model outputs or CUDA errors." f"This may lead to incorrect model outputs or CUDA errors."
...@@ -416,6 +427,8 @@ def is_multimodal_model(model_architectures: List[str]): ...@@ -416,6 +427,8 @@ def is_multimodal_model(model_architectures: List[str]):
or "LlavaQwenForCausalLM" in model_architectures or "LlavaQwenForCausalLM" in model_architectures
or "LlavaMistralForCausalLM" in model_architectures or "LlavaMistralForCausalLM" in model_architectures
or "LlavaVidForCausalLM" in model_architectures or "LlavaVidForCausalLM" in model_architectures
or "Grok1VForCausalLM" in model_architectures
or "Grok1AForCausalLM" in model_architectures
or "MllamaForConditionalGeneration" in model_architectures or "MllamaForConditionalGeneration" in model_architectures
or "Qwen2VLForConditionalGeneration" in model_architectures or "Qwen2VLForConditionalGeneration" in model_architectures
or "Qwen2_5_VLForConditionalGeneration" in model_architectures or "Qwen2_5_VLForConditionalGeneration" in model_architectures
...@@ -426,6 +439,18 @@ def is_multimodal_model(model_architectures: List[str]): ...@@ -426,6 +439,18 @@ def is_multimodal_model(model_architectures: List[str]):
return False return False
def is_multimodal_gen_model(model_architectures: List[str]):
return False
def is_image_gen_model(model_architectures: List[str]):
return False
def is_audio_model(model_architectures: List[str]):
return False
def is_encoder_decoder_model(model_architectures: List[str]): def is_encoder_decoder_model(model_architectures: List[str]):
return "MllamaForConditionalGeneration" in model_architectures return "MllamaForConditionalGeneration" in model_architectures
......
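The context-length override above now passes an explicit default to get_bool_env_var. The helper belongs to sglang; the stand-in below only illustrates the assumed truthy semantics of the guard.

```python
# Stand-in sketch (assumed semantics) for the boolean env-var guard above.
import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    return os.getenv(name, default).lower() in ("1", "true")

allowed = get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="False")
print("longer user-specified context_length allowed:", allowed)
```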
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import json import json
import logging import logging
from typing import List, Tuple from typing import List, Optional, Tuple, Union
import torch import torch
from xgrammar import ( from xgrammar import (
...@@ -42,11 +42,16 @@ MAX_ROLLBACK_TOKENS = 200 ...@@ -42,11 +42,16 @@ MAX_ROLLBACK_TOKENS = 200
class XGrammarGrammar(BaseGrammarObject): class XGrammarGrammar(BaseGrammarObject):
def __init__( def __init__(
self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar self,
matcher: GrammarMatcher,
vocab_size: int,
ctx: CompiledGrammar,
override_stop_tokens: Optional[Union[List[int], int]],
) -> None: ) -> None:
self.matcher = matcher self.matcher = matcher
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.ctx = ctx self.ctx = ctx
self.override_stop_tokens = override_stop_tokens
self.finished = False self.finished = False
def accept_token(self, token: int): def accept_token(self, token: int):
...@@ -96,8 +101,14 @@ class XGrammarGrammar(BaseGrammarObject): ...@@ -96,8 +101,14 @@ class XGrammarGrammar(BaseGrammarObject):
apply_token_bitmask_inplace(logits, vocab_mask) apply_token_bitmask_inplace(logits, vocab_mask)
def copy(self): def copy(self):
matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS) matcher = GrammarMatcher(
return XGrammarGrammar(matcher, self.vocab_size, self.ctx) self.ctx,
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
override_stop_tokens=self.override_stop_tokens,
)
return XGrammarGrammar(
matcher, self.vocab_size, self.ctx, self.override_stop_tokens
)
class XGrammarGrammarBackend(BaseGrammarBackend): class XGrammarGrammarBackend(BaseGrammarBackend):
...@@ -111,8 +122,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -111,8 +122,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
tokenizer_info = TokenizerInfo.from_huggingface( tokenizer_info = TokenizerInfo.from_huggingface(
tokenizer, vocab_size=vocab_size tokenizer, vocab_size=vocab_size
) )
override_stop_tokens = None
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info) self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.override_stop_tokens = override_stop_tokens
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar: def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
...@@ -161,7 +175,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -161,7 +175,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
raise ValueError(f"Invalid key_type: {key_type}") raise ValueError(f"Invalid key_type: {key_type}")
matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS) matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
return XGrammarGrammar(matcher, self.vocab_size, ctx) return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
def reset(self): def reset(self):
if self.grammar_compiler: if self.grammar_compiler:
......
...@@ -121,6 +121,7 @@ class Engine: ...@@ -121,6 +121,7 @@ class Engine:
return_logprob: Optional[Union[List[bool], bool]] = False, return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None, logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None,
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
lora_path: Optional[List[Optional[str]]] = None, lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[str], str]] = None, custom_logit_processor: Optional[Union[List[str], str]] = None,
return_hidden_states: bool = False, return_hidden_states: bool = False,
...@@ -142,6 +143,7 @@ class Engine: ...@@ -142,6 +143,7 @@ class Engine:
return_logprob=return_logprob, return_logprob=return_logprob,
logprob_start_len=logprob_start_len, logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num, top_logprobs_num=top_logprobs_num,
token_ids_logprob=token_ids_logprob,
lora_path=lora_path, lora_path=lora_path,
modalities=modalities_list, modalities=modalities_list,
custom_logit_processor=custom_logit_processor, custom_logit_processor=custom_logit_processor,
...@@ -179,6 +181,7 @@ class Engine: ...@@ -179,6 +181,7 @@ class Engine:
return_logprob: Optional[Union[List[bool], bool]] = False, return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None, logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None,
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
lora_path: Optional[List[Optional[str]]] = None, lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[str], str]] = None, custom_logit_processor: Optional[Union[List[str], str]] = None,
stream: bool = False, stream: bool = False,
...@@ -195,6 +198,7 @@ class Engine: ...@@ -195,6 +198,7 @@ class Engine:
return_logprob=return_logprob, return_logprob=return_logprob,
logprob_start_len=logprob_start_len, logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num, top_logprobs_num=top_logprobs_num,
token_ids_logprob=token_ids_logprob,
lora_path=lora_path, lora_path=lora_path,
stream=stream, stream=stream,
custom_logit_processor=custom_logit_processor, custom_logit_processor=custom_logit_processor,
...@@ -226,15 +230,22 @@ class Engine: ...@@ -226,15 +230,22 @@ class Engine:
kill_process_tree(os.getpid(), include_parent=False) kill_process_tree(os.getpid(), include_parent=False)
def start_profile(self): def start_profile(self):
self.tokenizer_manager.start_profile() loop = asyncio.get_event_loop()
loop.run_until_complete(self.tokenizer_manager.start_profile())
def stop_profile(self): def stop_profile(self):
self.tokenizer_manager.stop_profile() self.tokenizer_manager.stop_profile()
def get_server_info(self): def get_server_info(self):
loop = asyncio.get_event_loop()
internal_states = loop.run_until_complete(
self.tokenizer_manager.get_internal_state()
)
return { return {
**dataclasses.asdict(self.tokenizer_manager.server_args), # server args **dataclasses.asdict(self.tokenizer_manager.server_args),
**self.scheduler_info, **self.scheduler_info,
**internal_states,
"version": __version__, "version": __version__,
} }
...@@ -323,6 +334,7 @@ def _set_envs_and_config(server_args: ServerArgs): ...@@ -323,6 +334,7 @@ def _set_envs_and_config(server_args: ServerArgs):
os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls)) os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
os.environ["CUDA_MODULE_LOADING"] = "AUTO"
# Set prometheus env vars # Set prometheus env vars
if server_args.enable_metrics: if server_args.enable_metrics:
...@@ -346,12 +358,23 @@ def _set_envs_and_config(server_args: ServerArgs): ...@@ -346,12 +358,23 @@ def _set_envs_and_config(server_args: ServerArgs):
"at https://docs.flashinfer.ai/installation.html.", "at https://docs.flashinfer.ai/installation.html.",
) )
def sigchld_handler(signum, frame):
pid, exitcode = os.waitpid(0, os.WNOHANG)
if exitcode != 0:
logger.warning(
"Child process unexpectedly failed with an exit code %d. pid=%d",
exitcode,
pid,
)
signal.signal(signal.SIGCHLD, sigchld_handler)
# Register the signal handler. # Register the signal handler.
# The child processes will send SIGQUIT to this process when any error happens # The child processes will send SIGQUIT to this process when any error happens
# This process then clean up the whole process tree # This process then clean up the whole process tree
def sigquit_handler(signum, frame): def sigquit_handler(signum, frame):
logger.error( logger.error(
"Received sigquit from a child proces. It usually means the child failed." "Received sigquit from a child process. It usually means the child failed."
) )
kill_process_tree(os.getpid()) kill_process_tree(os.getpid())
......
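Engine.get_server_info() now also merges internal states (such as last_gen_throughput) fetched through the tokenizer manager's event loop, and start_profile() runs that loop as well. A hedged offline usage sketch; the model path is a placeholder and the exact keys depend on the running scheduler.

```python
# Hedged usage sketch for the Engine changes above; the model path is a
# placeholder and get_server_info() keys depend on the running scheduler.
import sglang as sgl

llm = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")
info = llm.get_server_info()
print(info.get("version"), info.get("last_gen_throughput"))

out = llm.generate("The capital of France is", {"max_new_tokens": 8})
print(out["text"])
llm.shutdown()
```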
...@@ -25,11 +25,14 @@ import os ...@@ -25,11 +25,14 @@ import os
import threading import threading
import time import time
from http import HTTPStatus from http import HTTPStatus
from typing import AsyncIterator, Dict, Optional from typing import AsyncIterator, Callable, Dict, Optional
# Fix a bug of Python threading # Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None) setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
from contextlib import asynccontextmanager
import numpy as np
import orjson import orjson
import requests import requests
import uvicorn import uvicorn
...@@ -49,8 +52,10 @@ from sglang.srt.managers.io_struct import ( ...@@ -49,8 +52,10 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput, InitWeightsUpdateGroupReqInput,
OpenSessionReqInput, OpenSessionReqInput,
ParseFunctionCallReq, ParseFunctionCallReq,
ProfileReqInput,
ReleaseMemoryOccupationReqInput, ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqInput,
SetInternalStateReq,
UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqInput,
VertexGenerateReqInput, VertexGenerateReqInput,
...@@ -78,22 +83,13 @@ from sglang.srt.utils import ( ...@@ -78,22 +83,13 @@ from sglang.srt.utils import (
kill_process_tree, kill_process_tree,
set_uvicorn_logging_configs, set_uvicorn_logging_configs,
) )
from sglang.srt.warmup import execute_warmups
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
from sglang.version import __version__ from sglang.version import __version__
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Fast API
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Store global states # Store global states
@dataclasses.dataclass @dataclasses.dataclass
...@@ -110,6 +106,34 @@ def set_global_state(global_state: _GlobalState): ...@@ -110,6 +106,34 @@ def set_global_state(global_state: _GlobalState):
_global_state = global_state _global_state = global_state
@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
server_args: ServerArgs = fast_api_app.server_args
if server_args.warmups is not None:
await execute_warmups(
server_args.warmups.split(","), _global_state.tokenizer_manager
)
logger.info("Warmup ended")
warmup_thread = getattr(fast_api_app, "warmup_thread", None)
if warmup_thread is not None:
warmup_thread.start()
yield
# Fast API
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
##### Native API endpoints ##### ##### Native API endpoints #####
...@@ -123,24 +147,48 @@ async def health() -> Response: ...@@ -123,24 +147,48 @@ async def health() -> Response:
async def health_generate(request: Request) -> Response: async def health_generate(request: Request) -> Response:
"""Check the health of the inference server by generating one token.""" """Check the health of the inference server by generating one token."""
sampling_params = {"max_new_tokens": 1, "temperature": 0.7} sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
rid = f"HEALTH_CHECK_{time.time()}"
if _global_state.tokenizer_manager.is_generation: if _global_state.tokenizer_manager.is_image_gen:
raise NotImplementedError()
elif _global_state.tokenizer_manager.is_generation:
gri = GenerateReqInput( gri = GenerateReqInput(
input_ids=[0], sampling_params=sampling_params, log_metrics=False rid=rid,
input_ids=[0],
sampling_params=sampling_params,
log_metrics=False,
) )
else: else:
gri = EmbeddingReqInput( gri = EmbeddingReqInput(
input_ids=[0], sampling_params=sampling_params, log_metrics=False rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False
) )
try: async def gen():
async for _ in _global_state.tokenizer_manager.generate_request(gri, request): async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
break break
return Response(status_code=200)
except Exception as e: tic = time.time()
logger.exception(e) task = asyncio.create_task(gen())
return Response(status_code=503) while time.time() < tic + HEALTH_CHECK_TIMEOUT:
await asyncio.sleep(1)
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
task.cancel()
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
return Response(status_code=200)
task.cancel()
tic_time = time.strftime("%H:%M:%S", time.localtime(tic))
last_receive_time = time.strftime(
"%H:%M:%S", time.localtime(_global_state.tokenizer_manager.last_receive_tstamp)
)
logger.error(
f"Health check failed. Server couldn't get a response from detokenizer for last "
f"{HEALTH_CHECK_TIMEOUT} seconds. tic start time: {tic_time}. "
f"last_heartbeat time: {last_receive_time}"
)
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
return Response(status_code=503)
@app.get("/get_model_info") @app.get("/get_model_info")
...@@ -156,13 +204,21 @@ async def get_model_info(): ...@@ -156,13 +204,21 @@ async def get_model_info():
@app.get("/get_server_info") @app.get("/get_server_info")
async def get_server_info(): async def get_server_info():
internal_states = await _global_state.tokenizer_manager.get_internal_state()
return { return {
**dataclasses.asdict(_global_state.tokenizer_manager.server_args), **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
**_global_state.scheduler_info, **_global_state.scheduler_info,
**internal_states,
"version": __version__, "version": __version__,
} }
@app.api_route("/set_internal_state", methods=["POST", "PUT"])
async def set_internal_state(obj: SetInternalStateReq, request: Request):
res = await _global_state.tokenizer_manager.set_internal_state(obj)
return res
# fastapi implicitly converts json in the request to obj (dataclass) # fastapi implicitly converts json in the request to obj (dataclass)
@app.api_route("/generate", methods=["POST", "PUT"]) @app.api_route("/generate", methods=["POST", "PUT"])
async def generate_request(obj: GenerateReqInput, request: Request): async def generate_request(obj: GenerateReqInput, request: Request):
...@@ -179,6 +235,7 @@ async def generate_request(obj: GenerateReqInput, request: Request): ...@@ -179,6 +235,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
) + b"\n\n" ) + b"\n\n"
except ValueError as e: except ValueError as e:
out = {"error": {"message": str(e)}} out = {"error": {"message": str(e)}}
logger.error(f"Error: {e}")
yield b"data: " + orjson.dumps( yield b"data: " + orjson.dumps(
out, option=orjson.OPT_NON_STR_KEYS out, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n" ) + b"\n\n"
...@@ -236,9 +293,14 @@ async def flush_cache(): ...@@ -236,9 +293,14 @@ async def flush_cache():
@app.api_route("/start_profile", methods=["GET", "POST"]) @app.api_route("/start_profile", methods=["GET", "POST"])
async def start_profile_async(): async def start_profile_async(obj: Optional[ProfileReqInput] = None):
"""Start profiling.""" """Start profiling."""
_global_state.tokenizer_manager.start_profile() if obj is None:
obj = ProfileReqInput()
await _global_state.tokenizer_manager.start_profile(
obj.output_dir, obj.num_steps, obj.activities
)
return Response( return Response(
content="Start profiling.\n", content="Start profiling.\n",
status_code=200, status_code=200,
...@@ -257,11 +319,15 @@ async def stop_profile_async(): ...@@ -257,11 +319,15 @@ async def stop_profile_async():
@app.post("/update_weights_from_disk") @app.post("/update_weights_from_disk")
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
"""Update the weights from disk in-place without re-launching the server.""" """Update the weights from disk inplace without re-launching the server."""
success, message = await _global_state.tokenizer_manager.update_weights_from_disk( success, message, num_paused_requests = (
obj, request await _global_state.tokenizer_manager.update_weights_from_disk(obj, request)
) )
content = {"success": success, "message": message} content = {
"success": success,
"message": message,
"num_paused_requests": num_paused_requests,
}
if success: if success:
return ORJSONResponse( return ORJSONResponse(
content, content,
...@@ -323,7 +389,7 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request): ...@@ -323,7 +389,7 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
async def release_memory_occupation( async def release_memory_occupation(
obj: ReleaseMemoryOccupationReqInput, request: Request obj: ReleaseMemoryOccupationReqInput, request: Request
): ):
"""Release GPU occupation temporarily""" """Release GPU memory occupation temporarily."""
try: try:
await _global_state.tokenizer_manager.release_memory_occupation(obj, request) await _global_state.tokenizer_manager.release_memory_occupation(obj, request)
except Exception as e: except Exception as e:
...@@ -334,7 +400,7 @@ async def release_memory_occupation( ...@@ -334,7 +400,7 @@ async def release_memory_occupation(
async def resume_memory_occupation( async def resume_memory_occupation(
obj: ResumeMemoryOccupationReqInput, request: Request obj: ResumeMemoryOccupationReqInput, request: Request
): ):
"""Resume GPU occupation""" """Resume GPU memory occupation."""
try: try:
await _global_state.tokenizer_manager.resume_memory_occupation(obj, request) await _global_state.tokenizer_manager.resume_memory_occupation(obj, request)
except Exception as e: except Exception as e:
...@@ -357,7 +423,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request): ...@@ -357,7 +423,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
@app.api_route("/close_session", methods=["GET", "POST"]) @app.api_route("/close_session", methods=["GET", "POST"])
async def close_session(obj: CloseSessionReqInput, request: Request): async def close_session(obj: CloseSessionReqInput, request: Request):
"""Close the session""" """Close the session."""
try: try:
await _global_state.tokenizer_manager.close_session(obj, request) await _global_state.tokenizer_manager.close_session(obj, request)
return Response(status_code=200) return Response(status_code=200)
...@@ -367,7 +433,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request): ...@@ -367,7 +433,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
@app.api_route("/configure_logging", methods=["GET", "POST"]) @app.api_route("/configure_logging", methods=["GET", "POST"])
async def configure_logging(obj: ConfigureLoggingReq, request: Request): async def configure_logging(obj: ConfigureLoggingReq, request: Request):
"""Close the session""" """Configure the request logging options."""
_global_state.tokenizer_manager.configure_logging(obj) _global_state.tokenizer_manager.configure_logging(obj)
return Response(status_code=200) return Response(status_code=200)
...@@ -511,6 +577,7 @@ def _create_error_response(e): ...@@ -511,6 +577,7 @@ def _create_error_response(e):
def launch_server( def launch_server(
server_args: ServerArgs, server_args: ServerArgs,
pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None, pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None,
launch_callback: Optional[Callable[[], None]] = None,
): ):
""" """
Launch SRT (SGLang Runtime) Server. Launch SRT (SGLang Runtime) Server.
...@@ -544,21 +611,23 @@ def launch_server( ...@@ -544,21 +611,23 @@ def launch_server(
add_prometheus_middleware(app) add_prometheus_middleware(app)
enable_func_timer() enable_func_timer()
# Send a warmup request # Send a warmup request - we will create the thread and launch it
t = threading.Thread( # in the lifespan after all other warmups have fired.
warmup_thread = threading.Thread(
target=_wait_and_warmup, target=_wait_and_warmup,
args=( args=(
server_args, server_args,
pipe_finish_writer, pipe_finish_writer,
_global_state.tokenizer_manager.image_token_id, _global_state.tokenizer_manager.image_token_id,
launch_callback,
), ),
) )
t.start() app.warmup_thread = warmup_thread
try: try:
# Update logging configs # Update logging configs
set_uvicorn_logging_configs() set_uvicorn_logging_configs()
app.server_args = server_args
# Listen for HTTP requests # Listen for HTTP requests
uvicorn.run( uvicorn.run(
app, app,
...@@ -569,10 +638,15 @@ def launch_server( ...@@ -569,10 +638,15 @@ def launch_server(
loop="uvloop", loop="uvloop",
) )
finally: finally:
t.join() warmup_thread.join()
def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text): def _wait_and_warmup(
server_args: ServerArgs,
pipe_finish_writer: Optional[multiprocessing.connection.Connection],
image_token_text: str,
launch_callback: Optional[Callable[[], None]] = None,
):
headers = {} headers = {}
url = server_args.url() url = server_args.url()
if server_args.api_key: if server_args.api_key:
...@@ -614,8 +688,16 @@ def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text): ...@@ -614,8 +688,16 @@ def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
else: else:
json_data["text"] = "The capital city of France is" json_data["text"] = "The capital city of France is"
# Debug dumping
if server_args.debug_tensor_dump_input_file:
json_data.pop("text", None)
json_data["input_ids"] = np.load(
server_args.debug_tensor_dump_input_file
).tolist()
json_data["sampling_params"]["max_new_tokens"] = 0
try: try:
for _ in range(server_args.dp_size): for i in range(server_args.dp_size):
res = requests.post( res = requests.post(
url + request_name, url + request_name,
json=json_data, json=json_data,
...@@ -640,3 +722,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text): ...@@ -640,3 +722,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
if server_args.delete_ckpt_after_loading: if server_args.delete_ckpt_after_loading:
delete_directory(server_args.model_path) delete_directory(server_args.model_path)
if server_args.debug_tensor_dump_input_file:
kill_process_tree(os.getpid())
if launch_callback is not None:
launch_callback()
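A minimal sketch of the new launch_callback hook, using the launch_server defined above (assumptions: ServerArgs is built with its usual model_path/port fields; the callback body is illustrative):

from sglang.srt.server_args import ServerArgs

def on_ready():
    # Invoked at the end of _wait_and_warmup, i.e. after the warmup request succeeds.
    print("SGLang server warmed up and ready")

# Blocks until the uvicorn server exits.
launch_server(
    ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct", port=30000),
    launch_callback=on_ready,
)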
...@@ -60,6 +60,7 @@ class VerlEngine: ...@@ -60,6 +60,7 @@ class VerlEngine:
return_logprob: Optional[Union[List[bool], bool]] = False, return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None, logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None,
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
lora_path: Optional[List[Optional[str]]] = None, lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[str], str]] = None, custom_logit_processor: Optional[Union[List[str], str]] = None,
) -> Dict: ) -> Dict:
...@@ -76,6 +77,7 @@ class VerlEngine: ...@@ -76,6 +77,7 @@ class VerlEngine:
return_logprob=return_logprob, return_logprob=return_logprob,
logprob_start_len=logprob_start_len, logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num, top_logprobs_num=top_logprobs_num,
token_ids_logprob=token_ids_logprob,
lora_path=lora_path, lora_path=lora_path,
custom_logit_processor=custom_logit_processor, custom_logit_processor=custom_logit_processor,
) )
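A hedged sketch of requesting log-probabilities for specific token ids through the new argument (only the *_logprob, lora_path, and custom_logit_processor parameters appear in this diff; the prompt and sampling_params names, and the token ids themselves, are assumptions):

out = engine.generate(                      # engine: an already-initialized VerlEngine
    prompt="The capital city of France is",
    sampling_params={"temperature": 0.0, "max_new_tokens": 4},
    return_logprob=True,
    token_ids_logprob=[[3681, 12366]],      # hypothetical token ids to score for this request
)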
......
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional, Union
import torch import torch
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.spec_info import SpecInfo from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
class AttentionBackend(ABC): class AttentionBackend(ABC):
...@@ -31,7 +31,7 @@ class AttentionBackend(ABC): ...@@ -31,7 +31,7 @@ class AttentionBackend(ABC):
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor], encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode, forward_mode: ForwardMode,
spec_info: Optional[SpecInfo], spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
): ):
"""Init the metadata for a forward pass for capturing a cuda graph.""" """Init the metadata for a forward pass for capturing a cuda graph."""
raise NotImplementedError() raise NotImplementedError()
...@@ -44,7 +44,7 @@ class AttentionBackend(ABC): ...@@ -44,7 +44,7 @@ class AttentionBackend(ABC):
seq_lens_sum: int, seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor], encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode, forward_mode: ForwardMode,
spec_info: Optional[SpecInfo], spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
): ):
"""Init the metadata for a forward pass for replying a cuda graph.""" """Init the metadata for a forward pass for replying a cuda graph."""
raise NotImplementedError() raise NotImplementedError()
...@@ -64,7 +64,14 @@ class AttentionBackend(ABC): ...@@ -64,7 +64,14 @@ class AttentionBackend(ABC):
): ):
"""Run forward on an attention layer.""" """Run forward on an attention layer."""
if forward_batch.forward_mode.is_decode(): if forward_batch.forward_mode.is_decode():
return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache) return self.forward_decode(
q,
k,
v,
layer,
forward_batch,
save_kv_cache=save_kv_cache,
)
else: else:
return self.forward_extend( return self.forward_extend(
q, q,
...@@ -72,7 +79,7 @@ class AttentionBackend(ABC): ...@@ -72,7 +79,7 @@ class AttentionBackend(ABC):
v, v,
layer, layer,
forward_batch, forward_batch,
save_kv_cache, save_kv_cache=save_kv_cache,
) )
def forward_decode( def forward_decode(
......
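A hedged sketch of the dispatch contract above: a backend subclass has to accept save_kv_cache as a keyword argument in both code paths (the class below is purely illustrative, not one of the real backends):

class NoOpAttnBackend(AttentionBackend):
    def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        # A real backend computes attention here and writes k/v into the KV cache
        # when save_kv_cache is True.
        return q

    def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        return q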
...@@ -68,6 +68,7 @@ class FlashInferAttnBackend(AttentionBackend): ...@@ -68,6 +68,7 @@ class FlashInferAttnBackend(AttentionBackend):
model_runner: ModelRunner, model_runner: ModelRunner,
skip_prefill: bool = False, skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None, kv_indptr_buf: Optional[torch.Tensor] = None,
kv_last_page_len_buf: Optional[torch.Tensor] = None,
): ):
super().__init__() super().__init__()
...@@ -125,9 +126,14 @@ class FlashInferAttnBackend(AttentionBackend): ...@@ -125,9 +126,14 @@ class FlashInferAttnBackend(AttentionBackend):
assert self.num_wrappers == 1 assert self.num_wrappers == 1
self.kv_indptr = [kv_indptr_buf] self.kv_indptr = [kv_indptr_buf]
self.kv_last_page_len = torch.ones( if kv_last_page_len_buf is None:
(max_bs,), dtype=torch.int32, device=model_runner.device self.kv_last_page_len = torch.ones(
) (max_bs,), dtype=torch.int32, device=model_runner.device
)
else:
assert self.num_wrappers == 1
self.kv_last_page_len = kv_last_page_len_buf
self.qo_indptr = [ self.qo_indptr = [
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device) torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
for _ in range(self.num_wrappers) for _ in range(self.num_wrappers)
...@@ -922,6 +928,9 @@ class FlashInferMultiStepDraftBackend: ...@@ -922,6 +928,9 @@ class FlashInferMultiStepDraftBackend:
dtype=torch.int32, dtype=torch.int32,
device=model_runner.device, device=model_runner.device,
) )
self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)
self.attn_backends = [] self.attn_backends = []
for i in range(self.speculative_num_steps): for i in range(self.speculative_num_steps):
self.attn_backends.append( self.attn_backends.append(
...@@ -929,6 +938,7 @@ class FlashInferMultiStepDraftBackend: ...@@ -929,6 +938,7 @@ class FlashInferMultiStepDraftBackend:
model_runner, model_runner,
skip_prefill=True, skip_prefill=True,
kv_indptr_buf=self.kv_indptr[i], kv_indptr_buf=self.kv_indptr[i],
kv_last_page_len_buf=self.kv_last_page_len,
) )
) )
self.max_context_len = self.attn_backends[0].max_context_len self.max_context_len = self.attn_backends[0].max_context_len
......
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional, Union
import torch import torch
import triton import triton
...@@ -15,7 +15,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo ...@@ -15,7 +15,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
class TritonAttnBackend(AttentionBackend): class TritonAttnBackend(AttentionBackend):
...@@ -232,7 +232,7 @@ class TritonAttnBackend(AttentionBackend): ...@@ -232,7 +232,7 @@ class TritonAttnBackend(AttentionBackend):
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor], encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode, forward_mode: ForwardMode,
spec_info: Optional[SpecInfo], spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
): ):
assert encoder_lens is None, "Not supported" assert encoder_lens is None, "Not supported"
...@@ -310,7 +310,7 @@ class TritonAttnBackend(AttentionBackend): ...@@ -310,7 +310,7 @@ class TritonAttnBackend(AttentionBackend):
seq_lens_sum: int, seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor], encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode, forward_mode: ForwardMode,
spec_info: Optional[SpecInfo], spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
): ):
# NOTE: encoder_lens expected to be zeros or None # NOTE: encoder_lens expected to be zeros or None
if forward_mode.is_decode_or_idle(): if forward_mode.is_decode_or_idle():
......