Commit 852a49c5 authored by maxiao

adapt to dsv32 on dcu

parent 8f7453e3
...@@ -57,7 +57,7 @@ dependencies = [ ...@@ -57,7 +57,7 @@ dependencies = [
"uvicorn", "uvicorn",
"uvloop", "uvloop",
"xgrammar==0.1.24", "xgrammar==0.1.24",
"sgl-kernel==0.3.13", "sgl-kernel==0.3.11",
"torch==2.8.0", "torch==2.8.0",
"torchaudio==2.8.0", "torchaudio==2.8.0",
"torchvision", "torchvision",
...@@ -67,7 +67,7 @@ dependencies = [ ...@@ -67,7 +67,7 @@ dependencies = [
"tiktoken", "tiktoken",
"anthropic>=0.20.0", "anthropic>=0.20.0",
"torch_memory_saver==0.0.8", "torch_memory_saver==0.0.8",
"nvidia-cutlass-dsl==4.2.1", "nvidia-cutlass-dsl==4.2.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]
...@@ -103,8 +103,8 @@ dev = ["sglang[test]", "sglang[decord]"] ...@@ -103,8 +103,8 @@ dev = ["sglang[test]", "sglang[decord]"]
"srt/layers/moe/fused_moe_triton/configs/*/*.json", "srt/layers/moe/fused_moe_triton/configs/*/*.json",
"srt/layers/quantization/configs/*.json", "srt/layers/quantization/configs/*.json",
"srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp", "srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp",
"srt/speculative/cpp_ngram/*.cpp", "srt/speculative/cpp_lookahead/*.cpp",
"srt/speculative/cpp_ngram/*.h", "srt/speculative/cpp_lookahead/*.h",
] ]
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
......
...@@ -65,23 +65,23 @@ tracing = [ ...@@ -65,23 +65,23 @@ tracing = [
srt = [ srt = [
"sglang[runtime_common]", "sglang[runtime_common]",
"sgl-kernel==0.3.13", "sgl-kernel==0.3.11",
"torch==2.8.0", "torch==2.8.0",
"torchaudio==2.8.0", "torchaudio==2.8.0",
"torchvision", "torchvision",
"cuda-python", "cuda-python",
"flashinfer_python==0.4.0rc1", "flashinfer_python==0.3.1",
] ]
blackwell = [ blackwell = [
"sglang[runtime_common]", "sglang[runtime_common]",
"sgl-kernel==0.3.13", "sgl-kernel==0.3.11",
"torch==2.8.0", "torch==2.8.0",
"torchaudio==2.8.0", "torchaudio==2.8.0",
"torchvision", "torchvision",
"cuda-python", "cuda-python",
"flashinfer_python==0.4.0rc1", "flashinfer_python==0.3.1",
"nvidia-cutlass-dsl==4.2.1", "nvidia-cutlass-dsl==4.2.0",
] ]
# HIP (Heterogeneous-computing Interface for Portability) for AMD # HIP (Heterogeneous-computing Interface for Portability) for AMD
......
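The pins above differ between the two sides of the diff (for example `sgl-kernel` 0.3.13 vs 0.3.11 and `flashinfer_python` 0.4.0rc1 vs 0.3.1). When checking which side a given environment actually matches, installed package metadata can be queried at runtime; a minimal sketch using only the standard library (distribution names taken from the pins above):

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the dependency pins in this diff.
for pkg in ("sgl-kernel", "torch", "flashinfer_python", "nvidia-cutlass-dsl"):
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
```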
...@@ -443,9 +443,11 @@ def latency_test_run_once( ...@@ -443,9 +443,11 @@ def latency_test_run_once(
if profile: if profile:
profiler.stop() profiler.stop()
trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz" profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
_save_profile_trace_results(profiler, trace_filename) _save_profile_trace_results(profiler, profile_filename)
rank_print(f"torch profiler chrome trace for prefill saved to {trace_filename}") rank_print(
f"torch profiler chrome trace for prefill saved to {profile_filename}"
)
# Decode # Decode
decode_latencies = [] decode_latencies = []
...@@ -477,10 +479,10 @@ def latency_test_run_once( ...@@ -477,10 +479,10 @@ def latency_test_run_once(
if profile and i == output_len / 2: if profile and i == output_len / 2:
profiler.stop() profiler.stop()
trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz" profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
_save_profile_trace_results(profiler, trace_filename) _save_profile_trace_results(profiler, profile_filename)
rank_print( rank_print(
f"torch profiler chrome trace for decoding 1 token saved to {trace_filename}" f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
) )
# Record decode timing from 2nd output # Record decode timing from 2nd output
......
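Both sides of the hunk above save the prefill profile as a gzipped Chrome trace named `{prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz`; only the local variable name changes. Such a trace can be inspected with the standard library alone; a minimal sketch (the filename below is hypothetical):

```python
import gzip
import json

# Hypothetical filename following the prefix_batch{b}_input{i}_output{o}_prefill pattern above.
trace_path = "sglang_profile_batch16_input1024_output8_prefill.trace.json.gz"

with gzip.open(trace_path, "rt", encoding="utf-8") as f:
    trace = json.load(f)

# Chrome-trace files store profiler events under the "traceEvents" key.
events = trace.get("traceEvents", [])
print(f"{len(events)} profiler events in {trace_path}")
```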
...@@ -9,7 +9,6 @@ python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B -- ...@@ -9,7 +9,6 @@ python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --output-path results.json --profile
""" """
import argparse import argparse
...@@ -20,17 +19,12 @@ import multiprocessing ...@@ -20,17 +19,12 @@ import multiprocessing
import os import os
import random import random
import time import time
from typing import List, Optional, Tuple from typing import List, Tuple
import numpy as np import numpy as np
import requests import requests
from pydantic import BaseModel
from sglang.bench_serving import ( from sglang.bench_serving import get_tokenizer, sample_random_requests
get_tokenizer,
sample_mmmu_requests,
sample_random_requests,
)
from sglang.profiler import run_profile from sglang.profiler import run_profile
from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
...@@ -38,108 +32,6 @@ from sglang.srt.utils import is_blackwell, kill_process_tree ...@@ -38,108 +32,6 @@ from sglang.srt.utils import is_blackwell, kill_process_tree
from sglang.test.test_utils import is_in_ci, write_github_step_summary from sglang.test.test_utils import is_in_ci, write_github_step_summary
class ProfileLinks(BaseModel):
"""Pydantic model for profile trace links."""
extend: Optional[str] = None
decode: Optional[str] = None
class BenchmarkResult(BaseModel):
"""Pydantic model for benchmark results table data, for a single isl and osl"""
model_path: str
run_name: str
batch_size: int
input_len: int
output_len: int
latency: float
ttft: float
input_throughput: float
output_throughput: float
overall_throughput: float
last_gen_throughput: float
acc_length: Optional[float] = None
profile_links: Optional[ProfileLinks] = None
@staticmethod
def help_str() -> str:
return f"""
Note: To view the traces through perfetto-ui, please:
1. open with Google Chrome
2. allow popup
"""
def to_markdown_row(
self, trace_dir, base_url: str = "", relay_base: str = ""
) -> str:
"""Convert this benchmark result to a markdown table row."""
# Calculate costs (assuming H100 pricing for now)
hourly_cost_per_gpu = 2 # $2/hour for one H100
hourly_cost = hourly_cost_per_gpu * 1 # Assuming tp_size = 1 for simplicity
input_util = 0.7
accept_length = (
round(self.acc_length, 2) if self.acc_length is not None else "n/a"
)
itl = 1 / (self.output_throughput / self.batch_size) * 1000
input_cost = 1e6 / (self.input_throughput * input_util) / 3600 * hourly_cost
output_cost = 1e6 / self.output_throughput / 3600 * hourly_cost
def get_perfetto_relay_link_from_trace_file(trace_file: str):
import os
from urllib.parse import quote
rel_path = os.path.relpath(trace_file, trace_dir)
raw_file_link = f"{base_url}/{rel_path}"
relay_link = (
f"{relay_base}?src={quote(raw_file_link, safe='')}"
if relay_base and quote
else raw_file_link
)
return relay_link
# Handle profile links
profile_link = "NA | NA"
if self.profile_links:
if self.profile_links.extend or self.profile_links.decode:
# Create a combined link or use the first available one
trace_files = [self.profile_links.extend, self.profile_links.decode]
trace_files_relay_links = [
f"[trace]({get_perfetto_relay_link_from_trace_file(trace_file)})"
for trace_file in trace_files
]
profile_link = " | ".join(trace_files_relay_links)
# Build the row
return f"| {self.batch_size} | {self.input_len} | {self.latency:.2f} | {self.input_throughput:.2f} | {self.output_throughput:.2f} | {accept_length} | {itl:.2f} | {input_cost:.2f} | {output_cost:.2f} | {profile_link} |\n"
@classmethod
def generate_markdown_report(
cls, trace_dir, results: List["BenchmarkResult"]
) -> str:
"""Generate a markdown report from a list of BenchmarkResult object from a single run."""
import os
summary = f"### {results[0].model_path}\n"
# summary += (
# f"Input lens: {result.input_len}. Output lens: {result.output_len}.\n"
# )
summary += "| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) | profile (extend) | profile (decode)|\n"
summary += "| ---------- | --------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ | --------------- | -------------- |\n"
# all results should share the same isl & osl
for result in results:
base_url = os.getenv("TRACE_BASE_URL", "").rstrip("/")
relay_base = os.getenv("PERFETTO_RELAY_URL", "").rstrip("/")
relay_base = "https://docs.sglang.ai/ci-data/pages/perfetto_relay.html"
# base_url = "https://github.com/sgl-project/ci-data/traces"
summary += result.to_markdown_row(trace_dir, base_url, relay_base)
return summary
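For reference, the ITL and per-token cost columns produced by `to_markdown_row` above follow directly from the measured throughputs and the fixed $2/hour single-GPU assumption. A small worked sketch of the same arithmetic with made-up measurements:

```python
# Made-up measurements, only to illustrate the arithmetic used in to_markdown_row above.
batch_size = 16
input_throughput = 8000.0   # prefill tokens/s
output_throughput = 1600.0  # decode tokens/s

hourly_cost_per_gpu = 2     # $2/hour, single GPU, as assumed above
hourly_cost = hourly_cost_per_gpu * 1
input_util = 0.7

itl = 1 / (output_throughput / batch_size) * 1000                        # inter-token latency, ms
input_cost = 1e6 / (input_throughput * input_util) / 3600 * hourly_cost  # $ per 1M input tokens
output_cost = 1e6 / output_throughput / 3600 * hourly_cost               # $ per 1M output tokens

print(f"ITL: {itl:.2f} ms, input: ${input_cost:.2f}/1M, output: ${output_cost:.2f}/1M")
```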
@dataclasses.dataclass @dataclasses.dataclass
class BenchArgs: class BenchArgs:
run_name: str = "default" run_name: str = "default"
...@@ -158,12 +50,8 @@ class BenchArgs: ...@@ -158,12 +50,8 @@ class BenchArgs:
profile: bool = False profile: bool = False
profile_steps: int = 3 profile_steps: int = 3
profile_by_stage: bool = False profile_by_stage: bool = False
profile_filename_prefix: str = None
append_to_github_summary: bool = True
dataset_path: str = "" dataset_path: str = ""
parallel_batch: bool = False parallel_batch: bool = False
dataset_name: str = "random"
output_path: Optional[str] = None
@staticmethod @staticmethod
def add_cli_args(parser: argparse.ArgumentParser): def add_cli_args(parser: argparse.ArgumentParser):
...@@ -179,13 +67,6 @@ class BenchArgs: ...@@ -179,13 +67,6 @@ class BenchArgs:
"--output-len", type=int, nargs="+", default=BenchArgs.output_len "--output-len", type=int, nargs="+", default=BenchArgs.output_len
) )
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature) parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
parser.add_argument(
"--dataset-name",
type=str,
default=BenchArgs.dataset_name,
choices=["mmmu", "random"],
help="Name of the dataset to benchmark on.",
)
parser.add_argument("--return-logprob", action="store_true") parser.add_argument("--return-logprob", action="store_true")
parser.add_argument( parser.add_argument(
"--client-stream-interval", "--client-stream-interval",
...@@ -215,36 +96,14 @@ class BenchArgs: ...@@ -215,36 +96,14 @@ class BenchArgs:
help="Path to the dataset.", help="Path to the dataset.",
) )
parser.add_argument("--parallel-batch", action="store_true") parser.add_argument("--parallel-batch", action="store_true")
parser.add_argument(
"--profile-filename-prefix",
type=str,
default=BenchArgs.profile_filename_prefix,
)
parser.add_argument(
"--no-append-to-github-summary",
action="store_false",
dest="append_to_github_summary",
help="Disable appending the output of this run to github ci summary",
)
parser.add_argument(
"--output-path",
type=str,
default=BenchArgs.output_path,
help="Path to save benchmark results as JSON format. If not specified, results will only be saved to result-filename.",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
# use the default value's type to cast the args into correct types. # use the default value's type to cast the args into correct types.
attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)] attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
kwargs = {} return cls(
for attr, attr_type in attrs: **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
val = getattr(args, attr) )
if attr_type is type(None):
kwargs[attr] = val
else:
kwargs[attr] = attr_type(val)
return cls(**kwargs)
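The two versions of `from_cli_args` above differ in how parsed arguments are cast back to the dataclass field types: the loop-based variant skips the cast when a field's default is `None`, because `type(None)` cannot be called with an argument. A self-contained sketch of that pattern (the `Args` dataclass and its flags are hypothetical):

```python
import argparse
import dataclasses
from typing import Optional


@dataclasses.dataclass
class Args:
    # Hypothetical fields mirroring the casting pattern in from_cli_args above.
    batch_size: int = 1
    output_path: Optional[str] = None  # default None -> type(attr.default) is NoneType

    @classmethod
    def from_cli_args(cls, ns: argparse.Namespace) -> "Args":
        attrs = [(f.name, type(f.default)) for f in dataclasses.fields(cls)]
        kwargs = {}
        for name, attr_type in attrs:
            val = getattr(ns, name)
            # NoneType(val) raises TypeError, so pass None-defaulted fields through unchanged.
            kwargs[name] = val if attr_type is type(None) else attr_type(val)
        return cls(**kwargs)


parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, default=Args.batch_size)
parser.add_argument("--output-path", type=str, default=Args.output_path)
print(Args.from_cli_args(parser.parse_args([])))
```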
def launch_server_internal(server_args): def launch_server_internal(server_args):
...@@ -289,25 +148,13 @@ def run_one_case( ...@@ -289,25 +148,13 @@ def run_one_case(
run_name: str, run_name: str,
result_filename: str, result_filename: str,
tokenizer, tokenizer,
dataset_name="",
profile: bool = False, profile: bool = False,
profile_steps: int = 3, profile_steps: int = 3,
profile_by_stage: bool = False, profile_by_stage: bool = False,
profile_filename_prefix: str = None,
dataset_path: str = "", dataset_path: str = "",
parallel_batch: bool = False, parallel_batch: bool = False,
): ):
requests.post(url + "/flush_cache") requests.post(url + "/flush_cache")
# TODO: reuse bench_serving.get_dataset ?
if dataset_name == "mmmu":
input_requests = sample_mmmu_requests(
num_requests=batch_size,
tokenizer=tokenizer,
fixed_output_len=output_len,
apply_chat_template=True,
random_sample=False,
)
elif dataset_name == "random":
input_requests = sample_random_requests( input_requests = sample_random_requests(
input_len=input_len, input_len=input_len,
output_len=output_len, output_len=output_len,
...@@ -334,22 +181,15 @@ def run_one_case( ...@@ -334,22 +181,15 @@ def run_one_case(
profile_link = None profile_link = None
if profile: if profile:
output_dir, profile_name = None, None
if profile_filename_prefix:
output_dir = os.path.dirname(profile_filename_prefix)
profile_name = os.path.basename(profile_filename_prefix)
profile_link: str = run_profile( profile_link: str = run_profile(
url, url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
profile_steps,
["CPU", "GPU"],
output_dir,
profile_name,
profile_by_stage,
) )
tic = time.perf_counter() tic = time.perf_counter()
response = requests.post(
payload = { url + "/generate",
json={
"input_ids": [req.prompt for req in input_requests],
"sampling_params": { "sampling_params": {
"temperature": temperature, "temperature": temperature,
"max_new_tokens": output_len, "max_new_tokens": output_len,
...@@ -360,22 +200,7 @@ def run_one_case( ...@@ -360,22 +200,7 @@ def run_one_case(
"return_logprob": return_logprob, "return_logprob": return_logprob,
"stream": True, "stream": True,
**({"parallel_batch": parallel_batch} if parallel_batch else {}), **({"parallel_batch": parallel_batch} if parallel_batch else {}),
} },
if dataset_name == "mmmu":
# vlm
input_ids = []
for input_req in input_requests:
input_ids += [tokenizer.encode(input_req.prompt)]
payload["image_data"] = [req.image_data for req in input_requests]
else:
input_ids = [req.prompt for req in input_requests]
payload["input_ids"] = input_ids
response = requests.post(
url + "/generate",
json=payload,
stream=True, stream=True,
) )
...@@ -439,100 +264,10 @@ def run_one_case( ...@@ -439,100 +264,10 @@ def run_one_case(
overall_throughput, overall_throughput,
last_gen_throughput, last_gen_throughput,
acc_length, acc_length,
profile_link, profile_link if profile else None,
) )
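Both variants of `run_one_case` above post the tokenized prompts to the server's `/generate` endpoint with `stream=True`. For context, a minimal standalone sketch of that call pattern (the URL, token ids, and sampling parameters here are placeholders):

```python
import requests

# Placeholder values; the benchmark builds input_ids from tokenized benchmark prompts.
url = "http://localhost:30000"
payload = {
    "input_ids": [[1, 2, 3, 4]],
    "sampling_params": {"temperature": 0.0, "max_new_tokens": 8},
    "stream": True,
}

with requests.post(url + "/generate", json=payload, stream=True) as response:
    response.raise_for_status()
    for chunk in response.iter_lines():
        if chunk:
            # Print each non-empty streamed chunk as-is.
            print(chunk.decode("utf-8"))
```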
def save_results_as_json(result: List[Tuple], bench_args: BenchArgs, model: str):
"""Save benchmark results as JSON using Pydantic models."""
json_results = []
# Generate all parameter combinations to match with results
param_combinations = list(
itertools.product(
bench_args.batch_size, bench_args.input_len, bench_args.output_len
)
)
for i, (
batch_size,
latency,
ttft,
input_throughput,
output_throughput,
overall_throughput,
last_gen_throughput,
acc_length,
profile_link,
) in enumerate(result):
# Get the corresponding parameters for this result
bs, input_len, output_len = param_combinations[i]
# Parse profile links if available
profile_links = None
if profile_link:
profile_links = parse_profile_links(
profile_link, batch_size, input_len, output_len
)
benchmark_result = BenchmarkResult(
model_path=model,
run_name=bench_args.run_name,
batch_size=batch_size,
input_len=input_len,
output_len=output_len,
latency=latency,
ttft=ttft,
input_throughput=input_throughput,
output_throughput=output_throughput,
overall_throughput=overall_throughput,
last_gen_throughput=last_gen_throughput,
acc_length=acc_length,
profile_links=profile_links,
)
json_results.append(benchmark_result.model_dump())
# Save to JSON file
with open(bench_args.output_path, "w", encoding="utf-8") as f:
json.dump(json_results, f, indent=2, ensure_ascii=False)
print(f"Results saved as JSON to {bench_args.output_path}")
def parse_profile_links(
profile_dir: str, batch_size: int, input_len: int, output_len: int
) -> Optional[ProfileLinks]:
"""Parse profile directory to extract extend and decode trace file links."""
if not profile_dir or not os.path.exists(profile_dir):
return None
extend_link = None
decode_link = None
# Look for extend/prefill trace files
for file in os.listdir(profile_dir):
if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
if "extend" in file.lower() or "prefill" in file.lower():
extend_link = os.path.join(profile_dir, file)
elif "decode" in file.lower():
decode_link = os.path.join(profile_dir, file)
# If no specific extend/decode files found, try to find files with batch/input/output info
if not extend_link or not decode_link:
for file in os.listdir(profile_dir):
if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
if f"_batch{batch_size}_input{input_len}_output{output_len}_" in file:
if "prefill" in file.lower() or "extend" in file.lower():
extend_link = os.path.join(profile_dir, file)
elif "decode" in file.lower():
decode_link = os.path.join(profile_dir, file)
if extend_link or decode_link:
return ProfileLinks(extend=extend_link, decode=decode_link)
return None
def get_report_summary( def get_report_summary(
result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs
): ):
...@@ -623,7 +358,6 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): ...@@ -623,7 +358,6 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
return_logprob=bench_args.return_logprob, return_logprob=bench_args.return_logprob,
stream_interval=bench_args.client_stream_interval, stream_interval=bench_args.client_stream_interval,
input_len_step_percentage=bench_args.input_len_step_percentage, input_len_step_percentage=bench_args.input_len_step_percentage,
dataset_name=bench_args.dataset_name,
run_name="", run_name="",
result_filename="", result_filename="",
tokenizer=tokenizer, tokenizer=tokenizer,
...@@ -650,12 +384,10 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): ...@@ -650,12 +384,10 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
stream_interval=bench_args.client_stream_interval, stream_interval=bench_args.client_stream_interval,
input_len_step_percentage=bench_args.input_len_step_percentage, input_len_step_percentage=bench_args.input_len_step_percentage,
run_name=bench_args.run_name, run_name=bench_args.run_name,
dataset_name=bench_args.dataset_name,
result_filename=bench_args.result_filename, result_filename=bench_args.result_filename,
tokenizer=tokenizer, tokenizer=tokenizer,
dataset_path=bench_args.dataset_path, dataset_path=bench_args.dataset_path,
parallel_batch=bench_args.parallel_batch, parallel_batch=bench_args.parallel_batch,
profile_filename_prefix=bench_args.profile_filename_prefix,
) )
) )
...@@ -678,13 +410,11 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): ...@@ -678,13 +410,11 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
run_name=bench_args.run_name, run_name=bench_args.run_name,
result_filename=bench_args.result_filename, result_filename=bench_args.result_filename,
tokenizer=tokenizer, tokenizer=tokenizer,
dataset_name=bench_args.dataset_name,
profile=bench_args.profile, profile=bench_args.profile,
profile_steps=bench_args.profile_steps, profile_steps=bench_args.profile_steps,
profile_by_stage=bench_args.profile_by_stage, profile_by_stage=bench_args.profile_by_stage,
dataset_path=bench_args.dataset_path, dataset_path=bench_args.dataset_path,
parallel_batch=bench_args.parallel_batch, parallel_batch=bench_args.parallel_batch,
profile_filename_prefix=bench_args.profile_filename_prefix,
)[-1], )[-1],
) )
) )
...@@ -697,16 +427,13 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): ...@@ -697,16 +427,13 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
print(f"\nResults are saved to {bench_args.result_filename}") print(f"\nResults are saved to {bench_args.result_filename}")
# Save results as JSON if output_path is specified
if bench_args.output_path:
save_results_as_json(result, bench_args, model=server_args.model_path)
if not bench_args.show_report: if not bench_args.show_report:
return return
summary = get_report_summary(result, server_args, bench_args) summary = get_report_summary(result, server_args, bench_args)
print(summary)
if is_in_ci() and bench_args.append_to_github_summary: if is_in_ci():
write_github_step_summary(summary) write_github_step_summary(summary)
......
...@@ -208,10 +208,6 @@ async def async_request_openai_completions( ...@@ -208,10 +208,6 @@ async def async_request_openai_completions(
"ignore_eos": not args.disable_ignore_eos, "ignore_eos": not args.disable_ignore_eos,
**request_func_input.extra_request_body, **request_func_input.extra_request_body,
} }
if request_func_input.image_data:
payload.update({"image_data": request_func_input.image_data})
headers = get_auth_headers() headers = get_auth_headers()
output = RequestFuncOutput.init_new(request_func_input) output = RequestFuncOutput.init_new(request_func_input)
...@@ -1763,9 +1759,7 @@ async def benchmark( ...@@ -1763,9 +1759,7 @@ async def benchmark(
pbar.close() pbar.close()
if "sglang" in backend: if "sglang" in backend:
server_info = requests.get( server_info = requests.get(base_url + "/get_server_info")
base_url + "/get_server_info", headers=get_auth_headers()
)
if server_info.status_code == 200: if server_info.status_code == 200:
server_info_json = server_info.json() server_info_json = server_info.json()
if "decode" in server_info_json: if "decode" in server_info_json:
......
...@@ -124,8 +124,6 @@ class Envs: ...@@ -124,8 +124,6 @@ class Envs:
SGLANG_TEST_REQUEST_TIME_STATS = EnvBool(False) SGLANG_TEST_REQUEST_TIME_STATS = EnvBool(False)
SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK = EnvBool(False) SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK = EnvBool(False)
SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False) SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False)
SGLANG_SIMULATE_ACC_LEN = EnvFloat(-1)
SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
# Model Parallel # Model Parallel
SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True) SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True)
......
...@@ -37,8 +37,8 @@ class GlobalConfig: ...@@ -37,8 +37,8 @@ class GlobalConfig:
) )
# Runtime constants: others # Runtime constants: others
self.retract_decode_steps = 20 self.retract_decode_steps = 20
self.flashinfer_workspace_size = int( self.flashinfer_workspace_size = os.environ.get(
os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024) "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
) )
# Output tokenization configs # Output tokenization configs
......
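The two sides of the hunk above differ in whether the `FLASHINFER_WORKSPACE_SIZE` override is wrapped in `int(...)`. Because `os.environ.get` returns a string whenever the variable is set, only the explicit cast guarantees an integer workspace size; a minimal sketch:

```python
import os

default = 384 * 1024 * 1024

# With the variable unset, the default (an int) is returned as-is.
os.environ.pop("FLASHINFER_WORKSPACE_SIZE", None)
print(type(os.environ.get("FLASHINFER_WORKSPACE_SIZE", default)))       # <class 'int'>

# With the variable set, the raw value is a string unless cast explicitly.
os.environ["FLASHINFER_WORKSPACE_SIZE"] = str(512 * 1024 * 1024)
print(type(os.environ.get("FLASHINFER_WORKSPACE_SIZE", default)))       # <class 'str'>
print(type(int(os.environ.get("FLASHINFER_WORKSPACE_SIZE", default))))  # <class 'int'>
```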
...@@ -7,23 +7,9 @@ from sglang.srt.entrypoints.http_server import launch_server ...@@ -7,23 +7,9 @@ from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import prepare_server_args from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
MOVE_ENVS_WARN = """
########################################################################
# For contributors and developers: #
# Please move environment variable definitions to sglang.srt.environ #
# using the following pattern: #
# SGLANG_XXX = EnvBool(False) #
# #
########################################################################
"""
if __name__ == "__main__": if __name__ == "__main__":
server_args = prepare_server_args(sys.argv[1:]) server_args = prepare_server_args(sys.argv[1:])
from sglang.srt.server_args import print_deprecated_warning
print_deprecated_warning(MOVE_ENVS_WARN)
try: try:
launch_server(server_args) launch_server(server_args)
finally: finally:
......
...@@ -24,8 +24,6 @@ class LoadFormat(str, enum.Enum): ...@@ -24,8 +24,6 @@ class LoadFormat(str, enum.Enum):
JAX = "jax" JAX = "jax"
REMOTE = "remote" REMOTE = "remote"
REMOTE_INSTANCE = "remote_instance" REMOTE_INSTANCE = "remote_instance"
RDMA = "rdma"
LOCAL_CACHED = "local_cached"
@dataclass @dataclass
...@@ -49,7 +47,6 @@ class LoadConfig: ...@@ -49,7 +47,6 @@ class LoadConfig:
checkpoints. checkpoints.
decryption_key_file: If set, decrypts the output files with a password read decryption_key_file: If set, decrypts the output files with a password read
from this file (after PBKDF2). from this file (after PBKDF2).
decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit.
""" """
load_format: Union[str, LoadFormat] = LoadFormat.AUTO load_format: Union[str, LoadFormat] = LoadFormat.AUTO
...@@ -57,11 +54,6 @@ class LoadConfig: ...@@ -57,11 +54,6 @@ class LoadConfig:
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict) model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
ignore_patterns: Optional[Union[List[str], str]] = None ignore_patterns: Optional[Union[List[str], str]] = None
decryption_key_file: Optional[str] = None decryption_key_file: Optional[str] = None
decrypt_max_concurrency: int = -1
tp_rank: Optional[int] = None
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
def __post_init__(self): def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {} model_loader_extra_config = self.model_loader_extra_config or {}
......
...@@ -31,7 +31,7 @@ from sglang.srt.hf_transformers_utils import ( ...@@ -31,7 +31,7 @@ from sglang.srt.hf_transformers_utils import (
) )
from sglang.srt.layers.quantization import QUANTIZATION_METHODS from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var, is_hip, retry from sglang.srt.utils import get_bool_env_var, is_hip
from sglang.utils import is_in_ci from sglang.utils import is_in_ci
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -48,6 +48,30 @@ class ModelImpl(str, Enum): ...@@ -48,6 +48,30 @@ class ModelImpl(str, Enum):
TRANSFORMERS = "transformers" TRANSFORMERS = "transformers"
def is_deepseek_nsa(config: PretrainedConfig) -> bool:
return (
config.architectures is not None
and config.architectures[0]
in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
and getattr(config, "index_topk", None) is not None
)
def get_nsa_index_head_dim(config: PretrainedConfig) -> int:
assert is_deepseek_nsa(config)
return config.index_head_dim
def get_nsa_index_topk(config: PretrainedConfig) -> int:
assert is_deepseek_nsa(config)
return config.index_topk
def get_nsa_index_n_heads(config: PretrainedConfig) -> int:
assert is_deepseek_nsa(config)
return config.index_n_heads
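The helpers added above detect DeepSeek checkpoints that ship the NSA sparse-attention indexer by checking the architecture name and the presence of `index_topk`. A minimal sketch of the same check against a plain config object (`SimpleNamespace` stands in for `PretrainedConfig`; the numeric values are placeholders):

```python
from types import SimpleNamespace


def is_deepseek_nsa(config) -> bool:
    # Same logic as the helper above: a DeepSeek V3/V3.2 architecture plus an index_topk field.
    return (
        config.architectures is not None
        and config.architectures[0]
        in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
        and getattr(config, "index_topk", None) is not None
    )


# Placeholder stand-in for a parsed hf_config.
cfg = SimpleNamespace(
    architectures=["DeepseekV32ForCausalLM"],
    index_topk=2048,
    index_head_dim=128,
    index_n_heads=64,
)
print(is_deepseek_nsa(cfg))                                  # True
print(cfg.index_head_dim if is_deepseek_nsa(cfg) else None)  # 128
```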
class ModelConfig: class ModelConfig:
def __init__( def __init__(
self, self,
...@@ -64,20 +88,35 @@ class ModelConfig: ...@@ -64,20 +88,35 @@ class ModelConfig:
is_draft_model: bool = False, is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[float] = None, hybrid_kvcache_ratio: Optional[float] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
tp_rank: Optional[int] = None,
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
remote_instance_weight_loader_send_weights_group_ports: Optional[
List[int]
] = None,
) -> None: ) -> None:
# Parse args # Parse args
self.model_path = model_path self.model_path = model_path
self.revision = revision self.revision = revision
self.quantization = quantization self.quantization = quantization
self.is_draft_model = is_draft_model
self.model_impl = model_impl self.model_impl = model_impl
self.tp_rank = tp_rank
self.remote_instance_weight_loader_seed_instance_ip = (
remote_instance_weight_loader_seed_instance_ip
)
self.remote_instance_weight_loader_seed_instance_service_port = (
remote_instance_weight_loader_seed_instance_service_port
)
self.remote_instance_weight_loader_send_weights_group_ports = (
remote_instance_weight_loader_send_weights_group_ports
)
# Get hf config self.maybe_pull_model_tokenizer_from_remote()
self._maybe_pull_model_tokenizer_from_remote()
self.model_override_args = json.loads(model_override_args) self.model_override_args = json.loads(model_override_args)
kwargs = {} kwargs = {}
if override_config_file and override_config_file.strip(): if override_config_file and override_config_file.strip():
kwargs["_configuration_file"] = override_config_file.strip() kwargs["_configuration_file"] = override_config_file.strip()
self.hf_config = get_config( self.hf_config = get_config(
self.model_path, self.model_path,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
...@@ -85,7 +124,7 @@ class ModelConfig: ...@@ -85,7 +124,7 @@ class ModelConfig:
model_override_args=self.model_override_args, model_override_args=self.model_override_args,
**kwargs, **kwargs,
) )
self.hf_text_config = get_hf_text_config(self.hf_config)
self.hf_generation_config = get_generation_config( self.hf_generation_config = get_generation_config(
self.model_path, self.model_path,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
...@@ -93,25 +132,7 @@ class ModelConfig: ...@@ -93,25 +132,7 @@ class ModelConfig:
**kwargs, **kwargs,
) )
# Set enable_multimodal self.hf_text_config = get_hf_text_config(self.hf_config)
if enable_multimodal is None:
mm_disabled_models = [
"Gemma3ForConditionalGeneration",
"Llama4ForConditionalGeneration",
"Step3VLForConditionalGeneration",
]
if self.hf_config.architectures[0] in mm_disabled_models:
enable_multimodal = False
logger.info(
f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
)
else:
enable_multimodal = True
# Config draft model
self._config_draft_model()
# Check model type
self.attention_chunk_size = getattr( self.attention_chunk_size = getattr(
self.hf_text_config, "attention_chunk_size", None self.hf_text_config, "attention_chunk_size", None
) )
...@@ -127,70 +148,20 @@ class ModelConfig: ...@@ -127,70 +148,20 @@ class ModelConfig:
self.hf_config.architectures, self.hf_text_config.num_hidden_layers self.hf_config.architectures, self.hf_text_config.num_hidden_layers
) )
) )
self.is_generation = is_generation_model(
self.hf_config.architectures, is_embedding
)
self.is_multimodal = enable_multimodal and is_multimodal_model(
self.hf_config.architectures
)
self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
self.hf_config.architectures
)
self.is_image_gen = enable_multimodal and is_image_gen_model(
self.hf_config.architectures
)
self.is_audio_model = enable_multimodal and is_audio_model(
self.hf_config.architectures
)
self.is_multimodal_chunked_prefill_supported = (
enable_multimodal
and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
)
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
# Derive context length and model shapes
self._derive_context_length(context_length)
self._derive_model_shapes()
# Verify quantization
self._verify_quantization()
# Verify dual-chunk attention config
self._verify_dual_chunk_attention_config()
# Cache attributes if enable_multimodal is None:
self.hf_eos_token_id = self._get_hf_eos_token_id() mm_disabled_models = [
"Gemma3ForConditionalGeneration",
# multimodal "Llama4ForConditionalGeneration",
self.image_token_id = getattr( "Step3VLForConditionalGeneration",
self.hf_config, "image_token_id", None ]
) or getattr(self.hf_config, "image_token_index", None) if self.hf_config.architectures[0] in mm_disabled_models:
enable_multimodal = False
@staticmethod logger.info(
def from_server_args( f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
server_args: ServerArgs,
model_path: str = None,
model_revision: str = None,
**kwargs,
):
return ModelConfig(
model_path=model_path or server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=model_revision or server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
model_impl=server_args.model_impl,
**kwargs,
) )
else:
def _config_draft_model(self): enable_multimodal = True
is_draft_model = self.is_draft_model
if ( if (
is_draft_model is_draft_model
...@@ -225,10 +196,31 @@ class ModelConfig: ...@@ -225,10 +196,31 @@ class ModelConfig:
self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP" self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
self.hf_config.num_nextn_predict_layers = 1 self.hf_config.num_nextn_predict_layers = 1
def _derive_context_length(self, context_length: int): # Check model type
is_draft_model = self.is_draft_model self.is_generation = is_generation_model(
derived_context_len = get_context_length(self.hf_text_config) self.hf_config.architectures, is_embedding
)
self.is_multimodal = enable_multimodal and is_multimodal_model(
self.hf_config.architectures
)
self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
self.hf_config.architectures
)
self.is_image_gen = enable_multimodal and is_image_gen_model(
self.hf_config.architectures
)
self.is_audio_model = enable_multimodal and is_audio_model(
self.hf_config.architectures
)
self.is_multimodal_chunked_prefill_supported = (
enable_multimodal
and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
)
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
# Derive context length
derived_context_len = get_context_length(self.hf_text_config)
if context_length is not None: if context_length is not None:
if context_length > derived_context_len: if context_length > derived_context_len:
reason = "Target model's" if is_draft_model else "User-specified" reason = "Target model's" if is_draft_model else "User-specified"
...@@ -242,11 +234,6 @@ class ModelConfig: ...@@ -242,11 +234,6 @@ class ModelConfig:
): ):
logger.warning(msg) logger.warning(msg)
self.context_len = context_length self.context_len = context_length
if is_draft_model:
self.hf_text_config.max_position_embeddings = context_length
logger.warning(
f"Overriding the draft model's max_position_embeddings to {context_length}."
)
else: else:
raise ValueError( raise ValueError(
f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1" f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
...@@ -256,10 +243,6 @@ class ModelConfig: ...@@ -256,10 +243,6 @@ class ModelConfig:
else: else:
self.context_len = derived_context_len self.context_len = derived_context_len
# Transfer context_len to HuggingFace config so models can access it
self.hf_config.context_len = self.context_len
def _derive_model_shapes(self):
# Unify the config keys for hf_text_config # Unify the config keys for hf_text_config
self.head_dim = getattr( self.head_dim = getattr(
self.hf_text_config, self.hf_text_config,
...@@ -270,6 +253,7 @@ class ModelConfig: ...@@ -270,6 +253,7 @@ class ModelConfig:
# FIXME: temporary special judge for MLA architecture # FIXME: temporary special judge for MLA architecture
if ( if (
"DeepseekV2ForCausalLM" in self.hf_config.architectures "DeepseekV2ForCausalLM" in self.hf_config.architectures
or "DeepseekV32ForCausalLM" in self.hf_config.architectures
or "DeepseekV3ForCausalLM" in self.hf_config.architectures or "DeepseekV3ForCausalLM" in self.hf_config.architectures
or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
or "LongcatFlashForCausalLM" in self.hf_config.architectures or "LongcatFlashForCausalLM" in self.hf_config.architectures
...@@ -282,6 +266,11 @@ class ModelConfig: ...@@ -282,6 +266,11 @@ class ModelConfig:
self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
self.v_head_dim = self.hf_config.v_head_dim self.v_head_dim = self.hf_config.v_head_dim
self.index_head_dim = (
get_nsa_index_head_dim(self.hf_config)
if is_deepseek_nsa(self.hf_config)
else None
)
# Handle rope scaling with yarn # Handle rope scaling with yarn
self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim) self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
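In the MLA branch above, the attention softmax scale is computed from the combined no-RoPE and RoPE query/key head dimensions rather than from `head_dim`. A worked example of that formula (the concrete dimensions are illustrative, roughly matching common DeepSeek MLA checkpoints):

```python
import math

# Illustrative MLA head dims; DeepSeek-style checkpoints commonly use 128 (nope) + 64 (rope).
qk_nope_head_dim = 128
qk_rope_head_dim = 64

scaling = 1 / math.sqrt(qk_nope_head_dim + qk_rope_head_dim)
print(f"attention scale = 1/sqrt({qk_nope_head_dim + qk_rope_head_dim}) = {scaling:.4f}")
```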
...@@ -354,6 +343,45 @@ class ModelConfig: ...@@ -354,6 +343,45 @@ class ModelConfig:
) )
self.vocab_size = self.hf_text_config.vocab_size self.vocab_size = self.hf_text_config.vocab_size
# Verify quantization
self._verify_quantization()
# Verify dual-chunk attention config
self._verify_dual_chunk_attention_config()
# Cache attributes
self.hf_eos_token_id = self.get_hf_eos_token_id()
# multimodal
self.image_token_id = getattr(
self.hf_config, "image_token_id", None
) or getattr(self.hf_config, "image_token_index", None)
@staticmethod
def from_server_args(
server_args: ServerArgs,
model_path: str = None,
model_revision: str = None,
**kwargs,
):
return ModelConfig(
model_path=model_path or server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=model_revision or server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
model_impl=server_args.model_impl,
remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
**kwargs,
)
def get_total_num_attention_heads(self) -> int: def get_total_num_attention_heads(self) -> int:
return self.num_attention_heads return self.num_attention_heads
...@@ -454,31 +482,13 @@ class ModelConfig: ...@@ -454,31 +482,13 @@ class ModelConfig:
from huggingface_hub import HfApi from huggingface_hub import HfApi
hf_api = HfApi() hf_api = HfApi()
if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
def check_hf_quant_config():
return hf_api.file_exists(
self.model_path, "hf_quant_config.json"
)
# Retry HF API call up to 3 times
file_exists = retry(
check_hf_quant_config,
max_retry=2,
initial_delay=1.0,
max_delay=5.0,
)
if file_exists:
quant_cfg = modelopt_quant_config quant_cfg = modelopt_quant_config
except huggingface_hub.errors.OfflineModeIsEnabled: except huggingface_hub.errors.OfflineModeIsEnabled:
logger.warning( logger.warning(
"Offline mode is enabled, skipping hf_quant_config.json check" "Offline mode is enabled, skipping hf_quant_config.json check"
) )
except Exception as e: pass
logger.warning(
f"Failed to check hf_quant_config.json: {self.model_path} {e}"
)
elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")): elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
quant_config_file = os.path.join( quant_config_file = os.path.join(
...@@ -606,7 +616,7 @@ class ModelConfig: ...@@ -606,7 +616,7 @@ class ModelConfig:
"sparse_attention_enabled" "sparse_attention_enabled"
] = True ] = True
def _get_hf_eos_token_id(self) -> Optional[Set[int]]: def get_hf_eos_token_id(self) -> Optional[Set[int]]:
eos_ids = getattr(self.hf_config, "eos_token_id", None) eos_ids = getattr(self.hf_config, "eos_token_id", None)
if eos_ids is not None: if eos_ids is not None:
# it can be either int or list of int # it can be either int or list of int
...@@ -626,7 +636,7 @@ class ModelConfig: ...@@ -626,7 +636,7 @@ class ModelConfig:
eos_ids = eos_ids | generation_eos_ids eos_ids = eos_ids | generation_eos_ids
return eos_ids return eos_ids
def _maybe_pull_model_tokenizer_from_remote(self) -> None: def maybe_pull_model_tokenizer_from_remote(self) -> None:
""" """
Pull the model config files to a temporary Pull the model config files to a temporary
directory in case of remote. directory in case of remote.
...@@ -769,8 +779,6 @@ multimodal_model_archs = [ ...@@ -769,8 +779,6 @@ multimodal_model_archs = [
"Qwen2AudioForConditionalGeneration", "Qwen2AudioForConditionalGeneration",
"Qwen2VLForConditionalGeneration", "Qwen2VLForConditionalGeneration",
"Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
"KimiVLForConditionalGeneration", "KimiVLForConditionalGeneration",
"InternVLChatModel", "InternVLChatModel",
"InternS1ForConditionalGeneration", "InternS1ForConditionalGeneration",
......
from typing import Optional, Union
from transformers import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
class Qwen3VLVisionConfig(PretrainedConfig):
model_type = "qwen3_vl"
base_config_key = "vision_config"
def __init__(
self,
depth=27,
hidden_size=1152,
hidden_act="gelu_pytorch_tanh",
intermediate_size=4304,
num_heads=16,
in_channels=3,
patch_size=16,
spatial_merge_size=2,
temporal_patch_size=2,
out_hidden_size=3584,
num_position_embeddings=2304,
deepstack_visual_indexes=[8, 16, 24],
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.depth = depth
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.num_heads = num_heads
self.in_channels = in_channels
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
self.out_hidden_size = out_hidden_size
self.num_position_embeddings = num_position_embeddings
self.initializer_range = initializer_range
self.deepstack_visual_indexes = deepstack_visual_indexes
class Qwen3VLTextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen3VLTextModel`]. It is used to instantiate a
Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the Qwen3VL model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Qwen3VLModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 22016):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 32):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
head_dim (`int`, *optional*, defaults to 128):
The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 128000):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 5000000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`list[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`list[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import Qwen3VLTextModel, Qwen3VLTextConfig
>>> # Initializing a Qwen3VL style configuration
>>> configuration = Qwen3VLTextConfig()
>>> # Initializing a model from the Qwen3-VL-7B style configuration
>>> model = Qwen3VLTextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen3_vl_text"
base_config_key = "text_config"
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
intermediate_size=22016,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
head_dim=128,
hidden_act="silu",
max_position_embeddings=128000,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=5000000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
class Qwen3VLConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen3VLModel`]. It is used to instantiate a
Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLTextConfig`):
The config object or dictionary of the text backbone.
vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLVisionConfig`):
The config object or dictionary of the vision backbone.
image_token_id (`int`, *optional*, defaults to 151655):
The image token index to encode the image prompt.
video_token_id (`int`, *optional*, defaults to 151656):
The video token index to encode the image prompt.
vision_start_token_id (`int`, *optional*, defaults to 151652):
The start token index to encode the image prompt.
vision_end_token_id (`int`, *optional*, defaults to 151653):
The end token index to encode the image prompt.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie the word embeddings.
```python
>>> from transformers import Qwen3VLForConditionalGeneration, Qwen3VLConfig
>>> # Initializing a Qwen3-VL style configuration
>>> configuration = Qwen3VLConfig()
>>> # Initializing a model from the Qwen3-VL-4B style configuration
>>> model = Qwen3VLForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen3_vl"
sub_configs = {
"vision_config": Qwen3VLVisionConfig,
"text_config": Qwen3VLTextConfig,
}
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
text_config=None,
vision_config=None,
image_token_id=151655,
video_token_id=151656,
vision_start_token_id=151652,
vision_end_token_id=151653,
tie_word_embeddings=False,
**kwargs,
):
if isinstance(vision_config, dict):
self.vision_config = self.sub_configs["vision_config"](**vision_config)
elif vision_config is None:
self.vision_config = self.sub_configs["vision_config"]()
if isinstance(text_config, dict):
self.text_config = self.sub_configs["text_config"](**text_config)
elif text_config is None:
self.text_config = self.sub_configs["text_config"]()
self.image_token_id = image_token_id
self.video_token_id = video_token_id
self.vision_start_token_id = vision_start_token_id
self.vision_end_token_id = vision_end_token_id
super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
class Qwen3VLMoeTextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen3VLMoeTextModel`]. It is used to instantiate a
Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Qwen2MoeModel`]
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 5632):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 128000):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 5000000.0):
The base period of the RoPE embeddings.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
decoder_sparse_step (`int`, *optional*, defaults to 1):
The frequency of the MoE layer.
moe_intermediate_size (`int`, *optional*, defaults to 1408):
Intermediate size of the routed expert.
num_experts_per_tok (`int`, *optional*, defaults to 4):
Number of selected experts.
num_experts (`int`, *optional*, defaults to 60):
Number of routed experts.
norm_topk_prob (`bool`, *optional*, defaults to `True`):
Whether to normalize the topk probabilities.
mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
Indicate which layers use Qwen3VLMoeMLP rather than Qwen3VLMoeSparseMoeBlock.
The list contains layer indices from 0 to num_layers-1.
If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity (see the sketch after this class).
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
and expect the model to work with a longer `max_position_embeddings`, we recommend updating this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied to the attention
computation. If unspecified, it defaults to the value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
head_dim (`int`, *optional*):
The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
```python
>>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
>>> # Initializing a Qwen3VLMoe style configuration
>>> configuration = Qwen3VLMoeConfig()
>>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration
>>> model = Qwen3VLMoeForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen3_vl_moe_text"
base_config_key = "text_config"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen3VLMoe`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=151936,
hidden_size=2048,
intermediate_size=5632,
num_hidden_layers=24,
num_attention_heads=16,
num_key_value_heads=16,
hidden_act="silu",
max_position_embeddings=128000,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=5000000.0,
attention_bias=False,
attention_dropout=0.0,
decoder_sparse_step=1,
moe_intermediate_size=1408,
num_experts_per_tok=4,
num_experts=60,
norm_topk_prob=True,
mlp_only_layers=None,
rope_scaling=None,
head_dim=None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.rope_scaling = rope_scaling
self.head_dim = head_dim or hidden_size // num_attention_heads
rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
# MoE arguments
self.decoder_sparse_step = decoder_sparse_step
self.moe_intermediate_size = moe_intermediate_size
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.norm_topk_prob = norm_topk_prob
self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
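A minimal sketch of how a decoder layer could choose between the dense MLP and the sparse MoE block from these fields; the helper below is hypothetical (not part of this file) and simply mirrors the rule stated in the `mlp_only_layers` docstring above.
```python
# Hypothetical helper, for illustration only.
def uses_sparse_moe(layer_idx: int, config: "Qwen3VLMoeTextConfig") -> bool:
    if layer_idx in config.mlp_only_layers:
        return False  # explicitly forced to a dense Qwen3VLMoeMLP
    # otherwise every `decoder_sparse_step`-th layer becomes a sparse MoE block
    return config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
```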
class Qwen3VLMoeVisionConfig(PretrainedConfig):
model_type = "qwen3_vl_moe"
base_config_key = "vision_config"
def __init__(
self,
depth=27,
hidden_size=1152,
hidden_act="gelu_pytorch_tanh",
intermediate_size=4304,
num_heads=16,
in_channels=3,
patch_size=16,
spatial_merge_size=2,
temporal_patch_size=2,
out_hidden_size=3584,
num_position_embeddings=2304,
deepstack_visual_indexes=[8, 16, 24],
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.depth = depth
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.num_heads = num_heads
self.in_channels = in_channels
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
self.out_hidden_size = out_hidden_size
self.num_position_embeddings = num_position_embeddings
self.initializer_range = initializer_range
self.deepstack_visual_indexes = deepstack_visual_indexes
class Qwen3VLMoeConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen3VLMoeModel`]. It is used to instantiate a
Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeTextConfig`):
The config object or dictionary of the text backbone.
vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeVisionConfig`):
The config object or dictionary of the vision backbone.
image_token_id (`int`, *optional*, defaults to 151655):
The image token index to encode the image prompt.
video_token_id (`int`, *optional*, defaults to 151656):
The video token index to encode the video prompt.
vision_start_token_id (`int`, *optional*, defaults to 151652):
The token index marking the start of a vision (image or video) segment in the prompt.
vision_end_token_id (`int`, *optional*, defaults to 151653):
The token index marking the end of a vision (image or video) segment in the prompt.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie the word embeddings.
```python
>>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
>>> # Initializing a Qwen3-VL-MOE style configuration
>>> configuration = Qwen3VLMoeConfig()
>>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration
>>> model = Qwen3VLMoeForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen3_vl_moe"
sub_configs = {
"vision_config": Qwen3VLMoeVisionConfig,
"text_config": Qwen3VLMoeTextConfig,
}
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
text_config=None,
vision_config=None,
image_token_id=151655,
video_token_id=151656,
vision_start_token_id=151652,
vision_end_token_id=151653,
tie_word_embeddings=False,
**kwargs,
):
if isinstance(vision_config, dict):
self.vision_config = self.sub_configs["vision_config"](**vision_config)
elif vision_config is None:
self.vision_config = self.sub_configs["vision_config"]()
if isinstance(text_config, dict):
self.text_config = self.sub_configs["text_config"](**text_config)
elif text_config is None:
self.text_config = self.sub_configs["text_config"]()
self.image_token_id = image_token_id
self.video_token_id = video_token_id
self.vision_start_token_id = vision_start_token_id
self.vision_end_token_id = vision_end_token_id
super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
__all__ = [
"Qwen3VLMoeConfig",
"Qwen3VLMoeVisionConfig",
"Qwen3VLConfig",
"Qwen3VLVisionConfig",
]
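A short usage sketch of the composite config above (assumed to be exercised exactly as the `__init__` shows): dict sub-configs are promoted to their config classes, and missing sub-configs fall back to defaults.
```python
cfg = Qwen3VLMoeConfig(
    text_config={"num_hidden_layers": 2, "num_experts": 8},
    vision_config={"depth": 4},
)
assert isinstance(cfg.text_config, Qwen3VLMoeTextConfig)
assert isinstance(cfg.vision_config, Qwen3VLMoeVisionConfig)
assert cfg.image_token_id == 151655  # default token id
```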
...@@ -2,9 +2,19 @@ import logging ...@@ -2,9 +2,19 @@ import logging
import os import os
from typing import List, Optional from typing import List, Optional
import torch
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.disaggregation.utils import DisaggregationMode
try:
from mf_adapter import TransferEngine
import_error = None
except ImportError as e:
import_error = e
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -13,12 +23,11 @@ class AscendTransferEngine(MooncakeTransferEngine): ...@@ -13,12 +23,11 @@ class AscendTransferEngine(MooncakeTransferEngine):
def __init__( def __init__(
self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
): ):
try: if import_error is not None:
from mf_adapter import TransferEngine logger.warning(
except ImportError as e:
raise ImportError(
"Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md" "Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
) from e )
raise import_error
self.engine = TransferEngine() self.engine = TransferEngine()
self.hostname = hostname self.hostname = hostname
...@@ -37,12 +46,29 @@ class AscendTransferEngine(MooncakeTransferEngine): ...@@ -37,12 +46,29 @@ class AscendTransferEngine(MooncakeTransferEngine):
self.initialize() self.initialize()
def initialize(self) -> None: def initialize(self) -> None:
from sglang.srt.layers.dp_attention import (
get_tensor_model_parallel_world_size,
get_tp_group,
)
transfer_protocol = self._get_transfer_protocol()
if transfer_protocol is None or transfer_protocol == "sdma":
trans_op_type = TransferEngine.TransDataOpType.SDMA
else:
trans_op_type = TransferEngine.TransDataOpType.DEVICE_RDMA
"""with device RDMA for PD transfer"""
tmp_tensor = torch.zeros(1, device="npu")
output_tensor_list = [
torch.empty_like(tmp_tensor)
for _ in range(get_tensor_model_parallel_world_size())
]
# Initialize hccl in advance through all_gather to avoid conflicts with rdma initialization.
torch.distributed.all_gather(
output_tensor_list, tmp_tensor, group=get_tp_group().device_group
)
"""Initialize the ascend transfer instance.""" """Initialize the ascend transfer instance."""
ret_value = self.engine.initialize( ret_value = self.engine.initialize(
self.store_url, self.store_url, self.session_id, self.role, self.npu_id, trans_op_type
self.session_id,
self.role,
self.npu_id,
) )
if ret_value != 0: if ret_value != 0:
logger.error("Ascend Transfer Engine initialization failed.") logger.error("Ascend Transfer Engine initialization failed.")
...@@ -56,3 +82,15 @@ class AscendTransferEngine(MooncakeTransferEngine): ...@@ -56,3 +82,15 @@ class AscendTransferEngine(MooncakeTransferEngine):
ret_value = -1 ret_value = -1
if ret_value != 0: if ret_value != 0:
logger.debug(f"Ascend memory registration for ptr {ptrs} failed.") logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")
@staticmethod
def _get_transfer_protocol():
protocol = os.getenv("ASCEND_MF_TRANSFER_PROTOCOL")
allowed_protocols = {"device_rdma", "sdma"}
if protocol and protocol.lower() in allowed_protocols:
return protocol.lower()
else:
logger.warning(
"Invalid or no transfer protocol specified, using default protocol."
)
return None
\ No newline at end of file
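Deployment-side sketch of selecting the protocol that `_get_transfer_protocol` reads (workflow assumed): set the environment variable before the engine initializes; unrecognized values fall back to the default (SDMA).
```python
import os

# "device_rdma" or "sdma"; anything else triggers the warning and the default path.
os.environ["ASCEND_MF_TRANSFER_PROTOCOL"] = "device_rdma"
```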
...@@ -95,6 +95,14 @@ class CommonKVManager(BaseKVManager): ...@@ -95,6 +95,14 @@ class CommonKVManager(BaseKVManager):
def _bind_server_socket(self): def _bind_server_socket(self):
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port)) self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
@cache
def _connect(self, endpoint: str, is_ipv6: bool = False):
socket = zmq.Context().socket(zmq.PUSH)
if is_ipv6:
socket.setsockopt(zmq.IPV6, 1)
socket.connect(endpoint)
return socket
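A self-contained sketch of the caching behavior the `@cache` decorator provides here: one PUSH socket is created per distinct (endpoint, is_ipv6) pair and then reused. The standalone function below is an assumption for illustration, not this class's method.
```python
from functools import cache

import zmq


@cache
def connect(endpoint: str, is_ipv6: bool = False) -> zmq.Socket:
    # ZMQ connects lazily, so this succeeds even before the peer is up.
    sock = zmq.Context.instance().socket(zmq.PUSH)
    if is_ipv6:
        sock.setsockopt(zmq.IPV6, 1)
    sock.connect(endpoint)
    return sock


s1 = connect("tcp://127.0.0.1:17000")
s2 = connect("tcp://127.0.0.1:17000")
assert s1 is s2  # the same cached socket object is reused
```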
def _register_to_bootstrap(self): def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST.""" """Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr: if self.dist_init_addr:
...@@ -148,33 +156,6 @@ class CommonKVManager(BaseKVManager): ...@@ -148,33 +156,6 @@ class CommonKVManager(BaseKVManager):
socket.connect(endpoint) socket.connect(endpoint)
return socket return socket
def get_mha_kv_ptrs_with_pp(
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
) -> Tuple[List[int], List[int], List[int], List[int], int]:
# pp is not supported on the decode side yet
start_layer = self.kv_args.prefill_start_layer
num_kv_layers = len(src_kv_ptrs) // 2
end_layer = start_layer + num_kv_layers
dst_num_total_layers = len(dst_kv_ptrs) // 2
src_k_ptrs = src_kv_ptrs[:num_kv_layers]
src_v_ptrs = src_kv_ptrs[num_kv_layers:]
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
]
layers_current_pp_stage = len(src_k_ptrs)
return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage
def get_mla_kv_ptrs_with_pp(
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
) -> Tuple[List[int], List[int], int]:
# pp is not supported on the decode side yet
start_layer = self.kv_args.prefill_start_layer
end_layer = start_layer + len(src_kv_ptrs)
sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
layers_current_pp_stage = len(src_kv_ptrs)
return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage
class CommonKVSender(BaseKVSender): class CommonKVSender(BaseKVSender):
......
...@@ -609,21 +609,15 @@ class DecodeTransferQueue: ...@@ -609,21 +609,15 @@ class DecodeTransferQueue:
idx = decode_req.metadata_buffer_index idx = decode_req.metadata_buffer_index
( (
output_id, output_id,
cached_tokens,
output_token_logprobs_val, output_token_logprobs_val,
output_token_logprobs_idx, output_token_logprobs_idx,
output_top_logprobs_val, output_top_logprobs_val,
output_top_logprobs_idx, output_top_logprobs_idx,
output_topk_p,
output_topk_index,
output_hidden_states, output_hidden_states,
) = self.metadata_buffers.get_buf(idx) ) = self.metadata_buffers.get_buf(idx)
decode_req.req.output_ids.append(output_id[0].item()) decode_req.req.output_ids.append(output_id[0].item())
decode_req.req.cached_tokens = cached_tokens[0].item()
if not self.spec_algorithm.is_none(): if not self.spec_algorithm.is_none():
decode_req.req.output_topk_p = output_topk_p
decode_req.req.output_topk_index = output_topk_index
decode_req.req.hidden_states_tensor = output_hidden_states decode_req.req.hidden_states_tensor = output_hidden_states
if decode_req.req.return_logprob: if decode_req.req.return_logprob:
decode_req.req.output_token_logprobs_val.append( decode_req.req.output_token_logprobs_val.append(
...@@ -713,15 +707,12 @@ class SchedulerDisaggregationDecodeMixin: ...@@ -713,15 +707,12 @@ class SchedulerDisaggregationDecodeMixin:
elif prepare_mlp_sync_flag: elif prepare_mlp_sync_flag:
batch, _ = self._prepare_idle_batch_and_run(None) batch, _ = self._prepare_idle_batch_and_run(None)
queue_size = ( if batch is None and (
len(self.waiting_queue) len(self.waiting_queue)
+ len(self.disagg_decode_transfer_queue.queue) + len(self.disagg_decode_transfer_queue.queue)
+ len(self.disagg_decode_prealloc_queue.queue) + len(self.disagg_decode_prealloc_queue.queue)
) == 0
if self.server_args.disaggregation_decode_enable_offload_kvcache: ):
queue_size += len(self.decode_offload_manager.ongoing_offload)
if batch is None and queue_size == 0:
self.self_check_during_idle() self.self_check_during_idle()
self.last_batch = batch self.last_batch = batch
...@@ -790,15 +781,12 @@ class SchedulerDisaggregationDecodeMixin: ...@@ -790,15 +781,12 @@ class SchedulerDisaggregationDecodeMixin:
) )
self.process_batch_result(tmp_batch, tmp_result) self.process_batch_result(tmp_batch, tmp_result)
queue_size = ( if batch is None and (
len(self.waiting_queue) len(self.waiting_queue)
+ len(self.disagg_decode_transfer_queue.queue) + len(self.disagg_decode_transfer_queue.queue)
+ len(self.disagg_decode_prealloc_queue.queue) + len(self.disagg_decode_prealloc_queue.queue)
) == 0
if self.server_args.disaggregation_decode_enable_offload_kvcache: ):
queue_size += len(self.decode_offload_manager.ongoing_offload)
if batch is None and queue_size == 0:
self.self_check_during_idle() self.self_check_during_idle()
self.last_batch = batch self.last_batch = batch
...@@ -917,6 +905,3 @@ class SchedulerDisaggregationDecodeMixin: ...@@ -917,6 +905,3 @@ class SchedulerDisaggregationDecodeMixin:
self.disagg_decode_transfer_queue.pop_transferred() self.disagg_decode_transfer_queue.pop_transferred()
) # the requests which kv has arrived ) # the requests which kv has arrived
self.waiting_queue.extend(alloc_reqs) self.waiting_queue.extend(alloc_reqs)
if self.server_args.disaggregation_decode_enable_offload_kvcache:
self.decode_offload_manager.check_offload_progress()
import logging
import threading
import time
import torch
from sglang.srt.server_args import ServerArgs
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MLATokenToKVPool,
ReqToTokenPool,
)
from sglang.srt.mem_cache.memory_pool_host import (
MHATokenToKVPoolHost,
MLATokenToKVPoolHost,
)
logger = logging.getLogger(__name__)
class DecodeKVCacheOffloadManager:
"""Manage decode-side KV cache offloading lifecycle and operations."""
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
tp_group: torch.distributed.ProcessGroup,
tree_cache: BasePrefixCache,
server_args: ServerArgs,
) -> None:
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.page_size = server_args.page_size
self.server_args = server_args
self.request_counter = 0
self.tree_cache = tree_cache
kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
if isinstance(kv_cache, MHATokenToKVPool):
self.decode_host_mem_pool = MHATokenToKVPoolHost(
kv_cache,
server_args.hicache_ratio,
server_args.hicache_size,
self.page_size,
server_args.hicache_mem_layout,
)
elif isinstance(kv_cache, MLATokenToKVPool):
self.decode_host_mem_pool = MLATokenToKVPoolHost(
kv_cache,
server_args.hicache_ratio,
server_args.hicache_size,
self.page_size,
server_args.hicache_mem_layout,
)
else:
raise ValueError("Unsupported KV cache type for decode offload")
self.tp_group = tp_group
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
self.cache_controller = HiCacheController(
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
mem_pool_host=self.decode_host_mem_pool,
page_size=self.page_size,
tp_group=tp_group,
io_backend=server_args.hicache_io_backend,
load_cache_event=threading.Event(),
storage_backend=server_args.hicache_storage_backend,
model_name=server_args.served_model_name,
storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
)
self.ongoing_offload = {}
self.ongoing_backup = {}
logger.info("Enable offload kv cache for decode side")
def offload_kv_cache(self, req) -> bool:
"""Offload a finished request's KV cache to storage."""
if self.cache_controller is None or self.decode_host_mem_pool is None:
return False
if req.req_pool_idx == -1:
return False
token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
if token_indices.dim() == 0 or token_indices.numel() == 0:
logger.debug(
f"Request {req.rid} has invalid token_indices: {token_indices}"
)
return False
tokens = req.origin_input_ids + req.output_ids
aligned_len = (len(tokens) // self.page_size) * self.page_size
if aligned_len == 0:
return False
token_indices = token_indices[:aligned_len]
tokens = tokens[:aligned_len]
# Asynchronously offload KV cache from device to host by cache controller
self.request_counter += 1
ack_id = self.request_counter
host_indices = self.cache_controller.write(
device_indices=token_indices.long(),
node_id=ack_id,
)
if host_indices is None:
logger.error(f"Not enough host memory for request {req.rid}")
return False
self.ongoing_offload[ack_id] = (req, host_indices, tokens, time.time())
return True
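A worked example of the page alignment above (page_size value assumed): only whole pages are offloaded, and a request shorter than one page is skipped entirely.
```python
page_size = 64
tokens = list(range(150))                       # prompt + generated tokens
aligned_len = (len(tokens) // page_size) * page_size
assert aligned_len == 128                       # the trailing 22 tokens stay on device
tokens = tokens[:aligned_len]
```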
def check_offload_progress(self):
"""Check the progress of offload from device to host and backup from host to storage."""
cc = self.cache_controller
qsizes = torch.tensor(
[
len(cc.ack_write_queue),
cc.ack_backup_queue.qsize(),
],
dtype=torch.int,
)
if self.tp_world_size > 1:
torch.distributed.all_reduce(
qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
)
n_write, n_backup = map(int, qsizes.tolist())
self._check_offload_progress(n_write)
self._check_backup_progress(n_backup)
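The element-wise MIN reduction above makes every TP rank act only on progress that all ranks have reached. A self-contained single-rank sketch of the same call (gloo backend, address assumed):
```python
import torch
import torch.distributed as dist

if not dist.is_initialized():
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
    )
qsizes = torch.tensor([3, 5], dtype=torch.int)   # local (write, backup) queue sizes
dist.all_reduce(qsizes, op=dist.ReduceOp.MIN)    # with more ranks, the slowest rank wins
n_write, n_backup = map(int, qsizes.tolist())
```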
def _check_offload_progress(self, finish_count):
"""Check the progress of offload from device to host."""
while finish_count > 0:
_, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
finish_event.synchronize()
for ack_id in ack_list:
req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)
# Release device
self.tree_cache.cache_finished_req(req)
# Trigger async backup from host to storage by cache controller
self._trigger_backup(req.rid, host_indices, tokens, start_time)
finish_count -= 1
def _check_backup_progress(self, finish_count):
"""Check the progress of backup from host to storage."""
for _ in range(finish_count):
storage_operation = self.cache_controller.ack_backup_queue.get()
ack_id = storage_operation.id
req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)
# Release host memory
self.decode_host_mem_pool.free(host_indices)
logger.debug(
f"Finished backup request {req_id}, free host memory, len:{len(host_indices)}, cost time:{time.time() - start_time:.2f} seconds."
)
def _trigger_backup(self, req_id, host_indices, tokens, start_time):
"""Trigger async backup from host to storage by cache controller."""
# Generate page hashes and write to storage
page_hashes = self._compute_prefix_hash(tokens)
ack_id = self.cache_controller.write_storage(
host_indices,
tokens,
hash_value=page_hashes,
)
self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)
def _compute_prefix_hash(self, tokens):
last_hash = ""
page_hashes = []
for offset in range(0, len(tokens), self.page_size):
page_tokens = tokens[offset : offset + self.page_size]
last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
page_hashes.append(last_hash)
return page_hashes
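A toy version of the chained hashing idea in `_compute_prefix_hash`: each page hash folds in the previous hash, so a page's key encodes the whole token prefix. The real hash comes from `HiCacheController.get_hash_str`; the stand-in below is an assumption for illustration only.
```python
import hashlib


def toy_get_hash_str(page_tokens, prior_hash=""):
    # Chain the previous hash into this page's hash (illustrative scheme only).
    h = hashlib.sha256(prior_hash.encode())
    h.update(",".join(map(str, page_tokens)).encode())
    return h.hexdigest()


page_size = 4
tokens = [1, 2, 3, 4, 5, 6, 7, 8]
last_hash, page_hashes = "", []
for offset in range(0, len(tokens), page_size):
    last_hash = toy_get_hash_str(tokens[offset : offset + page_size], last_hash)
    page_hashes.append(last_hash)
```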
...@@ -125,33 +125,25 @@ class ScheduleBatchDisaggregationDecodeMixin: ...@@ -125,33 +125,25 @@ class ScheduleBatchDisaggregationDecodeMixin:
req.grammar.finished = req.finished() req.grammar.finished = req.finished()
self.output_ids = torch.tensor(self.output_ids, device=self.device) self.output_ids = torch.tensor(self.output_ids, device=self.device)
# Simulate the eagle run. # Simulate the eagle run. We add mock data to hidden states for the
if self.spec_algorithm.is_eagle(): # ease of implementation now meaning the first token will have acc rate
# of 0.
if not self.spec_algorithm.is_none():
b = len(self.reqs) b = len(self.reqs)
topk = server_args.speculative_eagle_topk topk_p = torch.arange(
topk_p = torch.stack( b * server_args.speculative_eagle_topk,
[ 0,
torch.as_tensor( -1,
req.output_topk_p[:topk],
device=self.device, device=self.device,
dtype=torch.float32, dtype=torch.float32,
) )
for req in self.reqs topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
], topk_p /= b * server_args.speculative_eagle_topk
dim=0, topk_index = torch.arange(
) b * server_args.speculative_eagle_topk, device=self.device
topk_index = torch.stack(
[
torch.as_tensor(
req.output_topk_index[:topk],
device=self.device,
dtype=torch.int64,
)
for req in self.reqs
],
dim=0,
) )
topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
hidden_states_list = [req.hidden_states_tensor for req in self.reqs] hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device) hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
......
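A small numeric check of the mocked top-k distribution built in the old code above (shapes assumed: 2 requests, top-k of 4); the values are placeholders, which is why the simulated first draft token is effectively never accepted.
```python
import torch

b, topk = 2, 4
topk_p = torch.arange(b * topk, 0, -1, dtype=torch.float32)
topk_p = topk_p.reshape(b, topk) / (b * topk)        # row 0: [1.0, 0.875, 0.75, 0.625]
topk_index = torch.arange(b * topk).reshape(b, topk)  # placeholder indices
```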
...@@ -264,10 +264,12 @@ class MooncakeKVManager(CommonKVManager): ...@@ -264,10 +264,12 @@ class MooncakeKVManager(CommonKVManager):
layers_params = None layers_params = None
# pp is not supported on the decode side yet # pp is not supported on the decode side yet
start_layer = self.kv_args.prefill_start_layer
end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
if self.is_mla_backend: if self.is_mla_backend:
src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = ( src_kv_ptrs = self.kv_args.kv_data_ptrs
self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) layers_per_pp_stage = len(src_kv_ptrs)
) dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
kv_item_len = self.kv_args.kv_item_lens[0] kv_item_len = self.kv_args.kv_item_lens[0]
layers_params = [ layers_params = [
( (
...@@ -275,12 +277,18 @@ class MooncakeKVManager(CommonKVManager): ...@@ -275,12 +277,18 @@ class MooncakeKVManager(CommonKVManager):
dst_kv_ptrs[layer_id], dst_kv_ptrs[layer_id],
kv_item_len, kv_item_len,
) )
for layer_id in range(layers_current_pp_stage) for layer_id in range(layers_per_pp_stage)
] ]
else: else:
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = ( num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) dst_num_total_layers = num_kv_layers * self.pp_size
) src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
layers_per_pp_stage = len(src_k_ptrs)
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
]
kv_item_len = self.kv_args.kv_item_lens[0] kv_item_len = self.kv_args.kv_item_lens[0]
layers_params = [ layers_params = [
( (
...@@ -288,14 +296,14 @@ class MooncakeKVManager(CommonKVManager): ...@@ -288,14 +296,14 @@ class MooncakeKVManager(CommonKVManager):
dst_k_ptrs[layer_id], dst_k_ptrs[layer_id],
kv_item_len, kv_item_len,
) )
for layer_id in range(layers_current_pp_stage) for layer_id in range(layers_per_pp_stage)
] + [ ] + [
( (
src_v_ptrs[layer_id], src_v_ptrs[layer_id],
dst_v_ptrs[layer_id], dst_v_ptrs[layer_id],
kv_item_len, kv_item_len,
) )
for layer_id in range(layers_current_pp_stage) for layer_id in range(layers_per_pp_stage)
] ]
assert layers_params is not None assert layers_params is not None
...@@ -393,9 +401,18 @@ class MooncakeKVManager(CommonKVManager): ...@@ -393,9 +401,18 @@ class MooncakeKVManager(CommonKVManager):
num_heads_to_send = dst_heads_per_rank num_heads_to_send = dst_heads_per_rank
dst_head_start_offset = 0 dst_head_start_offset = 0
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = ( # pp is not supported on the decode side yet
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
) dst_num_total_layers = num_kv_layers * self.pp_size
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
layers_per_pp_stage = len(src_k_ptrs)
start_layer = self.pp_rank * layers_per_pp_stage
end_layer = start_layer + layers_per_pp_stage
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
]
# Calculate precise byte offset and length for the sub-slice within the token # Calculate precise byte offset and length for the sub-slice within the token
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
...@@ -421,7 +438,7 @@ class MooncakeKVManager(CommonKVManager): ...@@ -421,7 +438,7 @@ class MooncakeKVManager(CommonKVManager):
dst_head_slice_offset, dst_head_slice_offset,
heads_bytes_per_token_to_send, heads_bytes_per_token_to_send,
) )
for layer_id in range(layers_current_pp_stage) for layer_id in range(layers_per_pp_stage)
] + [ ] + [
( (
src_v_ptrs[layer_id], src_v_ptrs[layer_id],
...@@ -432,7 +449,7 @@ class MooncakeKVManager(CommonKVManager): ...@@ -432,7 +449,7 @@ class MooncakeKVManager(CommonKVManager):
dst_head_slice_offset, dst_head_slice_offset,
heads_bytes_per_token_to_send, heads_bytes_per_token_to_send,
) )
for layer_id in range(layers_current_pp_stage) for layer_id in range(layers_per_pp_stage)
] ]
def process_layer_tp_aware(layer_params): def process_layer_tp_aware(layer_params):
......
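A toy illustration of the pointer layout assumed by the slicing above: the decode side publishes all K pointers followed by all V pointers, and a prefill pp stage selects its own layer window from each half (all values below are made up).
```python
# Hypothetical setup: 2 pp stages on prefill, 4 layers per stage, 8 layers total.
dst_kv_ptrs = [f"K{i}" for i in range(8)] + [f"V{i}" for i in range(8)]
num_kv_layers, pp_size, pp_rank = 4, 2, 1
dst_num_total_layers = num_kv_layers * pp_size            # 8
start_layer = pp_rank * num_kv_layers                     # 4
end_layer = start_layer + num_kv_layers                   # 8
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]           # K4..K7
dst_v_ptrs = dst_kv_ptrs[
    dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
]                                                         # V4..V7
```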
...@@ -421,8 +421,6 @@ class SchedulerDisaggregationPrefillMixin: ...@@ -421,8 +421,6 @@ class SchedulerDisaggregationPrefillMixin:
last_hidden_index = ( last_hidden_index = (
hidden_state_offset + extend_input_len_per_req[i] - 1 hidden_state_offset + extend_input_len_per_req[i] - 1
) )
req.output_topk_p = batch.spec_info.topk_p[i]
req.output_topk_index = batch.spec_info.topk_index[i]
if self.spec_algorithm.is_eagle3(): if self.spec_algorithm.is_eagle3():
req.hidden_states_tensor = ( req.hidden_states_tensor = (
batch.spec_info.hidden_states[i].cpu().clone() batch.spec_info.hidden_states[i].cpu().clone()
......
...@@ -85,7 +85,7 @@ class MetadataBuffers: ...@@ -85,7 +85,7 @@ class MetadataBuffers:
self, self,
size: int, size: int,
hidden_size: int, hidden_size: int,
hidden_states_dtype: torch.dtype, dtype: torch.dtype,
max_top_logprobs_num: int = 128, max_top_logprobs_num: int = 128,
custom_mem_pool: torch.cuda.MemPool = None, custom_mem_pool: torch.cuda.MemPool = None,
): ):
...@@ -107,9 +107,7 @@ class MetadataBuffers: ...@@ -107,9 +107,7 @@ class MetadataBuffers:
# We transfer the metadata of first output token to decode # We transfer the metadata of first output token to decode
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device) self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
self.cached_tokens = torch.zeros(
(size, 16), dtype=torch.int32, device=device
)
self.output_token_logprobs_val = torch.zeros( self.output_token_logprobs_val = torch.zeros(
(size, 16), dtype=torch.float32, device=device (size, 16), dtype=torch.float32, device=device
) )
...@@ -122,49 +120,33 @@ class MetadataBuffers: ...@@ -122,49 +120,33 @@ class MetadataBuffers:
self.output_top_logprobs_idx = torch.zeros( self.output_top_logprobs_idx = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.int32, device=device (size, max_top_logprobs_num), dtype=torch.int32, device=device
) )
# For PD + spec decode
self.output_topk_p = torch.zeros(
(size, 16), dtype=torch.float32, device=device
)
self.output_topk_index = torch.zeros(
(size, 16), dtype=torch.int64, device=device
)
self.output_hidden_states = torch.zeros( self.output_hidden_states = torch.zeros(
(size, hidden_size), dtype=hidden_states_dtype, device=device (size, hidden_size), dtype=dtype, device=device
) )
def get_buf_infos(self): def get_buf_infos(self):
ptrs = [ ptrs = [
self.output_ids.data_ptr(), self.output_ids.data_ptr(),
self.cached_tokens.data_ptr(),
self.output_token_logprobs_val.data_ptr(), self.output_token_logprobs_val.data_ptr(),
self.output_token_logprobs_idx.data_ptr(), self.output_token_logprobs_idx.data_ptr(),
self.output_top_logprobs_val.data_ptr(), self.output_top_logprobs_val.data_ptr(),
self.output_top_logprobs_idx.data_ptr(), self.output_top_logprobs_idx.data_ptr(),
self.output_topk_p.data_ptr(),
self.output_topk_index.data_ptr(),
self.output_hidden_states.data_ptr(), self.output_hidden_states.data_ptr(),
] ]
data_lens = [ data_lens = [
self.output_ids.nbytes, self.output_ids.nbytes,
self.cached_tokens.nbytes,
self.output_token_logprobs_val.nbytes, self.output_token_logprobs_val.nbytes,
self.output_token_logprobs_idx.nbytes, self.output_token_logprobs_idx.nbytes,
self.output_top_logprobs_val.nbytes, self.output_top_logprobs_val.nbytes,
self.output_top_logprobs_idx.nbytes, self.output_top_logprobs_idx.nbytes,
self.output_topk_p.nbytes,
self.output_topk_index.nbytes,
self.output_hidden_states.nbytes, self.output_hidden_states.nbytes,
] ]
item_lens = [ item_lens = [
self.output_ids[0].nbytes, self.output_ids[0].nbytes,
self.cached_tokens[0].nbytes,
self.output_token_logprobs_val[0].nbytes, self.output_token_logprobs_val[0].nbytes,
self.output_token_logprobs_idx[0].nbytes, self.output_token_logprobs_idx[0].nbytes,
self.output_top_logprobs_val[0].nbytes, self.output_top_logprobs_val[0].nbytes,
self.output_top_logprobs_idx[0].nbytes, self.output_top_logprobs_idx[0].nbytes,
self.output_topk_p[0].nbytes,
self.output_topk_index[0].nbytes,
self.output_hidden_states[0].nbytes, self.output_hidden_states[0].nbytes,
] ]
return ptrs, data_lens, item_lens return ptrs, data_lens, item_lens
...@@ -172,20 +154,16 @@ class MetadataBuffers: ...@@ -172,20 +154,16 @@ class MetadataBuffers:
def get_buf(self, idx: int): def get_buf(self, idx: int):
return ( return (
self.output_ids[idx], self.output_ids[idx],
self.cached_tokens[idx],
self.output_token_logprobs_val[idx], self.output_token_logprobs_val[idx],
self.output_token_logprobs_idx[idx], self.output_token_logprobs_idx[idx],
self.output_top_logprobs_val[idx], self.output_top_logprobs_val[idx],
self.output_top_logprobs_idx[idx], self.output_top_logprobs_idx[idx],
self.output_topk_p[idx],
self.output_topk_index[idx],
self.output_hidden_states[idx], self.output_hidden_states[idx],
) )
def set_buf(self, req: Req): def set_buf(self, req: Req):
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0] self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
if req.return_logprob: if req.return_logprob:
if req.output_token_logprobs_val: # not none or empty list if req.output_token_logprobs_val: # not none or empty list
self.output_token_logprobs_val[req.metadata_buffer_index][0] = ( self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
...@@ -208,17 +186,8 @@ class MetadataBuffers: ...@@ -208,17 +186,8 @@ class MetadataBuffers:
] = torch.tensor( ] = torch.tensor(
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu" req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
) )
# For PD + spec decode # for PD + spec decode
if req.hidden_states_tensor is not None: if req.hidden_states_tensor is not None:
# speculative_eagle_topk should not be greater than 16 currently
topk = req.output_topk_p.size(0)
self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
req.output_topk_p
)
self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
req.output_topk_index
)
self.output_hidden_states[req.metadata_buffer_index].copy_( self.output_hidden_states[req.metadata_buffer_index].copy_(
req.hidden_states_tensor req.hidden_states_tensor
) )
......
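A quick check of the padding mentioned in the comment above ("the minimal size for RDMA is 64 bytes"): each per-request metadata row of shape (16,) in int32 occupies exactly 64 bytes.
```python
import torch

row = torch.zeros(16, dtype=torch.int32)   # one per-request metadata row
assert row.nbytes == 64                    # 16 * 4 bytes meets the stated 64-byte floor
```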
...@@ -711,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs): ...@@ -711,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs):
if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"): if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
assert_pkg_version( assert_pkg_version(
"sgl-kernel", "sgl-kernel",
"0.3.12", "0.3.11",
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`", "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
) )
......