Commit ac238727 authored by Lianmin Zheng

Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
parent 0194948f
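To illustrate the logprob-related surface this commit touches, here is a minimal sketch, assuming the public sglang.Engine API; the model path, prompt, and sampling values are placeholders, and the token_ids_logprob argument mirrors the signature added in the diff below.

# Hedged sketch of Engine.generate() with the logprob arguments extended in this commit.
import sglang as sgl

if __name__ == "__main__":
    engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model
    out = engine.generate(
        prompt="The capital city of France is",
        sampling_params={"max_new_tokens": 8, "temperature": 0.0},
        return_logprob=True,        # logprobs are returned even when prefill is chunked
        logprob_start_len=0,        # include logprobs starting from the prompt
        token_ids_logprob=[0, 1],   # new argument: report logprobs for these token ids
    )
    print(out["meta_info"])         # logprob fields are reported under meta_info
    engine.shutdown()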
......@@ -30,11 +30,20 @@ def get_model_config(model_name: str, tp_size: int):
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
shard_intermediate_size = 2 * intermediate_size // tp_size
else:
# Default: Mixtral
E = config.num_local_experts
......
......@@ -35,6 +35,15 @@ def get_model_config(model_name: str, tp_size: int):
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
else:
# Default: Mixtral
E = config.num_local_experts
......
......@@ -397,6 +397,15 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
# Default: Mixtral
E = config.num_local_experts
......
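The three hunks above apply the same shard-size formula; a small arithmetic sketch, assuming Mixtral-8x7B's intermediate size of 14336 purely for illustration:

# The factor of 2 accounts for the fused gate/up projections; the product is split across TP ranks.
intermediate_size = 14336  # assumed value for illustration
tp_size = 4
shard_intermediate_size = 2 * intermediate_size // tp_size
assert shard_intermediate_size == 7168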
......@@ -210,8 +210,7 @@
"response = requests.post(url, json=data)\n",
"print_highlight(response.text)\n",
"assert response.json()[\"success\"] is True\n",
"assert response.json()[\"message\"] == \"Succeeded to update model weights.\"\n",
"assert response.json().keys() == {\"success\", \"message\"}"
"assert response.json()[\"message\"] == \"Succeeded to update model weights.\""
]
},
{
......@@ -411,7 +410,7 @@
" },\n",
")\n",
"output = response.json()\n",
"output_tokens = output[\"token_ids\"]\n",
"output_tokens = output[\"output_ids\"]\n",
"\n",
"output_text = tokenizer.decode(output_tokens, skip_special_tokens=False)\n",
"print_highlight(f\"Tokenized Output: {output_tokens}\")\n",
......
......@@ -96,7 +96,6 @@ Please consult the documentation below to learn more about the parameters you ma
* `schedule_policy`: The scheduling policy to control the processing order of waiting prefill requests in a single engine.
* `schedule_conservativeness`: Can be used to decrease/increase the conservativeness of the server when taking new requests. Highly conservative behavior leads to starvation, but low conservativeness leads to slowed-down performance.
* `cpu_offload_gb`: Reserve this amount of RAM in GB for offloading of model parameters to the CPU.
* `prefill_only_one_req`: When this flag is turned on, the engine prefills only one request at a time.
## Other runtime options
......
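As a hedged sketch of how the documented scheduling and offload options can be set programmatically, assuming the Engine constructor forwards keyword arguments to ServerArgs (values are placeholders):

import sglang as sgl

engine = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    schedule_policy="lpm",          # order in which waiting prefill requests are processed
    schedule_conservativeness=1.0,  # tune how conservatively new requests are admitted
    cpu_offload_gb=4,               # reserve 4 GB of RAM for offloaded model parameters
)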
......@@ -96,7 +96,10 @@ dev_cpu = ["sglang[all_cpu]", "sglang[test]"]
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
[tool.setuptools.package-data]
"sglang" = ["srt/layers/moe/fused_moe_triton/configs/*.json", "srt/layers/quantization/configs/*.json"]
"sglang" = [
"srt/layers/moe/fused_moe_triton/configs/*.json",
"srt/layers/quantization/configs/*.json",
]
[tool.setuptools.packages.find]
exclude = [
......
......@@ -8,8 +8,10 @@
- `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server.
- `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server.
- `bench_serving.py`: Benchmark online serving with dynamic requests.
- `check_env.py`: Check the environment variables.
- `check_env.py`: Check the environment variables and dependencies.
- `global_config.py`: The global configs and constants.
- `launch_server.py`: The entry point for launching the local server.
- `llama3_eval.py`: Evaluation of Llama 3 using the Meta Llama dataset.
- `profiler.py`: Profile a running server.
- `utils.py`: Common utilities.
- `version.py`: Version info.
......@@ -56,6 +56,7 @@ class BenchArgs:
profile: bool = False
skip_warmup: bool = False
do_not_exit: bool = False
prompt_suffix: str = ""
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
......@@ -177,6 +178,12 @@ class BenchArgs:
action="store_true",
help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
)
parser.add_argument(
"--prompt-suffix",
type=str,
default="",
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
......@@ -216,6 +223,10 @@ def throughput_test_once(
]
if profile:
assert (
"SGLANG_TORCH_PROFILER_DIR" in os.environ
), "Please set SGLANG_TORCH_PROFILER_DIR."
os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True)
backend.start_profile()
st = time.perf_counter()
......@@ -229,6 +240,8 @@ def throughput_test_once(
if backend_name == "runtime":
gen_out = json.loads(gen_out)
server_info = backend.get_server_info()
measurement_results["total_latency"] = latency
measurement_results["total_output_tokens"] = sum(
o["meta_info"]["completion_tokens"] for o in gen_out
......@@ -246,6 +259,7 @@ def throughput_test_once(
measurement_results["total_input_tokens"]
+ measurement_results["total_output_tokens"]
) / latency
measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]
return measurement_results
......@@ -361,6 +375,11 @@ def throughput_test(
print(
"{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
)
print(
"{:<40} {:<10.2f}".format(
"Last generation throughput (tok/s):", result["last_gen_throughput"]
)
)
print(
"{:<40} {:<10.2f}".format(
"Request throughput (req/s):", result["request_throughput"]
......
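A hedged invocation sketch for the offline throughput script above, combining the new --prompt-suffix flag with the SGLANG_TORCH_PROFILER_DIR requirement asserted in throughput_test_once; the model path, prompt suffix, trace directory, and prompt count are placeholders, and the flag names are assumed from the BenchArgs fields shown in the diff.

import os
import subprocess

# --profile requires SGLANG_TORCH_PROFILER_DIR to be set before the run starts.
os.environ["SGLANG_TORCH_PROFILER_DIR"] = "/tmp/sglang_traces"
subprocess.run(
    [
        "python3", "-m", "sglang.bench_offline_throughput",
        "--model-path", "meta-llama/Llama-3.1-8B-Instruct",
        "--num-prompts", "64",
        "--profile",
        "--prompt-suffix", " Answer in one sentence.",
    ],
    check=True,
)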
......@@ -8,7 +8,6 @@ Usage:
python3 -m sglang.bench_serving --backend sglang --num-prompt 10
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi
"""
import argparse
......@@ -71,6 +70,10 @@ def remove_prefix(text: str, prefix: str) -> str:
return text[len(prefix) :] if text.startswith(prefix) else text
def remove_suffix(text: str, suffix: str) -> str:
return text[: -len(suffix)] if text.endswith(suffix) else text
def get_auth_headers() -> Dict[str, str]:
api_key = os.environ.get("OPENAI_API_KEY")
if api_key:
......@@ -79,7 +82,7 @@ def get_auth_headers() -> Dict[str, str]:
return {}
# trt llm not support ignore_eos
# trt llm does not support ignore_eos
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
async def async_request_trt_llm(
request_func_input: RequestFuncInput,
......@@ -179,6 +182,7 @@ async def async_request_openai_completions(
output.prompt_len = request_func_input.prompt_len
generated_text = ""
output_len = request_func_input.output_len
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
......@@ -215,11 +219,14 @@ async def async_request_openai_completions(
most_recent_timestamp = timestamp
generated_text += data["choices"][0]["text"]
output_len = data.get("usage", {}).get(
"completion_tokens", output_len
)
output.generated_text = generated_text
output.success = True
output.latency = latency
output.output_len = request_func_input.output_len
output.output_len = output_len
else:
output.error = response.reason or ""
output.success = False
......@@ -339,9 +346,11 @@ async def async_request_sglang_generate(
output.prompt_len = request_func_input.prompt_len
generated_text = ""
output_len = request_func_input.output_len
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
last_output_len = 0
try:
async with session.post(
url=api_url, json=payload, headers=headers
......@@ -365,6 +374,9 @@ async def async_request_sglang_generate(
# want to check a token was generated
if data["text"]:
timestamp = time.perf_counter()
generated_text = data["text"]
output_len = data["meta_info"]["completion_tokens"]
# First token
if ttft == 0.0:
ttft = time.perf_counter() - st
......@@ -372,7 +384,13 @@ async def async_request_sglang_generate(
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
num_new_tokens = output_len - last_output_len
if num_new_tokens == 0:
continue
adjust_itl = (
timestamp - most_recent_timestamp
) / num_new_tokens
output.itl.extend([adjust_itl] * num_new_tokens)
most_recent_timestamp = timestamp
generated_text = data["text"]
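The adjust_itl logic above spreads one streaming chunk's latency evenly over every token that chunk delivered; a self-contained sketch with made-up timings:

itl = []
last_output_len, most_recent_timestamp = 10, 1.00
output_len, timestamp = 13, 1.09  # one chunk delivered 3 new tokens about 90 ms later
num_new_tokens = output_len - last_output_len
if num_new_tokens > 0:
    adjust_itl = (timestamp - most_recent_timestamp) / num_new_tokens
    itl.extend([adjust_itl] * num_new_tokens)  # three entries of roughly 30 ms each
print(itl)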
......@@ -380,7 +398,7 @@ async def async_request_sglang_generate(
output.generated_text = generated_text
output.success = True
output.latency = latency
output.output_len = request_func_input.output_len
output.output_len = output_len
else:
output.error = response.reason or ""
output.success = False
......@@ -388,6 +406,7 @@ async def async_request_sglang_generate(
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
print(f"{output.error=}")
if pbar:
pbar.update(1)
......@@ -461,6 +480,7 @@ def get_dataset(args, tokenizer):
tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
context_len=args.sharegpt_context_len,
prompt_suffix=args.prompt_suffix,
apply_chat_template=args.apply_chat_template,
)
elif args.dataset_name == "random":
......@@ -521,7 +541,9 @@ class BenchmarkMetrics:
mean_itl_ms: float
median_itl_ms: float
std_itl_ms: float
p95_itl_ms: float
p99_itl_ms: float
max_itl_ms: float
mean_e2e_latency_ms: float
median_e2e_latency_ms: float
std_e2e_latency_ms: float
......@@ -572,6 +594,7 @@ def sample_sharegpt_requests(
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
context_len: Optional[int] = None,
prompt_suffix: Optional[str] = "",
apply_chat_template=False,
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
......@@ -584,11 +607,19 @@ def sample_sharegpt_requests(
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
dataset = [
data
for data in dataset
if len(data.get("conversations", data.get("conversation", []))) >= 2
]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
(
data.get("conversations", data.get("conversation", []))[0]["value"],
data.get("conversations", data.get("conversation", []))[1]["value"],
)
for data in dataset
]
......@@ -603,6 +634,8 @@ def sample_sharegpt_requests(
# Tokenize the prompts and completions.
prompt = dataset[i][0]
if prompt_suffix:
prompt = prompt + prompt_suffix
if apply_chat_template:
prompt = tokenizer.apply_chat_template(
......@@ -666,10 +699,17 @@ def sample_random_requests(
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
dataset = [
data
for data in dataset
if len(data.get("conversations", data.get("conversation", []))) >= 2
]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
(
data.get("conversations", data.get("conversation", []))[0]["value"],
data.get("conversations", data.get("conversation", []))[1]["value"],
)
for data in dataset
]
# Shuffle the dataset.
......@@ -895,7 +935,9 @@ def calculate_metrics(
mean_itl_ms=np.mean(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000,
p95_itl_ms=np.percentile(itls or 0, 95) * 1000,
p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
max_itl_ms=np.max(itls or 0) * 1000,
mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
......@@ -919,6 +961,7 @@ async def benchmark(
lora_name: str,
extra_request_body: Dict[str, Any],
profile: bool,
pd_seperated: bool = False,
):
if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend]
......@@ -1004,6 +1047,17 @@ async def benchmark(
if pbar is not None:
pbar.close()
if "sglang" in backend:
server_info = requests.get(base_url + "/get_server_info")
if pd_seperated:
accept_length = server_info.json()["decode"][0].get(
"avg_spec_accept_length", None
)
else:
accept_length = server_info.json().get("avg_spec_accept_length", None)
else:
accept_length = None
# Compute metrics and print results
benchmark_duration = time.perf_counter() - benchmark_start_time
metrics, output_lens = calculate_metrics(
......@@ -1053,6 +1107,8 @@ async def benchmark(
)
)
print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
if accept_length:
print("{:<40} {:<10.2f}".format("Accept length:", accept_length))
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
print(
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
......@@ -1066,16 +1122,12 @@ async def benchmark(
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
print(
"{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
)
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
print("{:<40} {:<10.2f}".format("P95 ITL (ms):", metrics.p95_itl_ms))
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
print("{:<40} {:<10.2f}".format("Max ITL (ms):", metrics.max_itl_ms))
print("=" * 50)
if (
......@@ -1117,8 +1169,10 @@ async def benchmark(
"mean_itl_ms": metrics.mean_itl_ms,
"median_itl_ms": metrics.median_itl_ms,
"std_itl_ms": metrics.std_itl_ms,
"p95_itl_ms": metrics.p95_itl_ms,
"p99_itl_ms": metrics.p99_itl_ms,
"concurrency": metrics.concurrency,
"accept_length": accept_length,
}
else:
print(f"Error running benchmark for request rate: {request_rate}")
......@@ -1151,14 +1205,6 @@ async def benchmark(
return result
def parse_request_rate_range(request_rate_range):
if len(request_rate_range.split(",")) == 3:
start, stop, step = map(int, request_rate_range.split(","))
return list(range(start, stop, step))
else:
return list(map(int, request_rate_range.split(",")))
def check_chat_template(model_path):
try:
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
......@@ -1168,6 +1214,12 @@ def check_chat_template(model_path):
return False
def set_global_args(args_: argparse.Namespace):
"""Set the global args."""
global args
args = args_
def run_benchmark(args_: argparse.Namespace):
global args
args = args_
......@@ -1176,6 +1228,8 @@ def run_benchmark(args_: argparse.Namespace):
if not hasattr(args, "max_concurrency"):
args.max_concurrency = None
print(f"benchmark_args={args}")
# Set global environments
set_ulimit()
random.seed(args.seed)
......@@ -1272,49 +1326,26 @@ def run_benchmark(args_: argparse.Namespace):
backend = args.backend
model_id = args.model
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
tokenizer = get_tokenizer(tokenizer_id)
input_requests = get_dataset(args, tokenizer)
if not args.multi:
return asyncio.run(
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=args.request_rate,
max_concurrency=args.max_concurrency,
disable_tqdm=args.disable_tqdm,
lora_name=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
)
return asyncio.run(
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=args.request_rate,
max_concurrency=args.max_concurrency,
disable_tqdm=args.disable_tqdm,
lora_name=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
pd_seperated=args.pd_seperated,
)
else:
# Benchmark multiple rps. TODO: use a fixed duration to compute num_prompts
request_rates = parse_request_rate_range(args.request_rate_range)
for rate in request_rates:
asyncio.run(
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=rate,
max_concurrency=args.max_concurrency,
disable_tqdm=args.disable_tqdm,
lora_name=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
)
)
)
def set_ulimit(target_soft_limit=65535):
......@@ -1428,17 +1459,6 @@ if __name__ == "__main__":
"actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.",
)
parser.add_argument(
"--multi",
action="store_true",
help="Use request rate range rather than single value.",
)
parser.add_argument(
"--request-rate-range",
type=str,
default="2,34,2",
help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.",
)
parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
parser.add_argument(
"--disable-tqdm",
......@@ -1485,6 +1505,17 @@ if __name__ == "__main__":
default=None,
help="The name of LoRA adapter",
)
parser.add_argument(
"--prompt-suffix",
type=str,
default="",
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
)
parser.add_argument(
"--pd-seperated",
action="store_true",
help="Benchmark PD disaggregation server",
)
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
group.add_argument(
......
......@@ -34,11 +34,9 @@ class GlobalConfig:
self.skip_special_tokens_in_output = True
self.spaces_between_special_tokens_in_out = True
# Interpreter optimization configs
# Language frontend interpreter optimization configs
self.enable_precache_with_tracing = True
self.enable_parallel_encoding = True
self.enable_flashinfer_mla = False
global_config = GlobalConfig()
......@@ -329,7 +329,12 @@ class RuntimeEndpoint(BaseBackend):
def compute_normalized_prompt_logprobs(input_logprobs):
values = [x[0] for x in input_logprobs if x[0]]
return sum(values) / len(values)
try:
return sum(values) / len(values)
except TypeError:
print(f"{input_logprobs=}", flush=True)
print(f"{input_logprobs[0]=}", flush=True)
exit(-1)
class Runtime:
......
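For context on compute_normalized_prompt_logprobs, a tiny sketch of the averaging it performs; the logprob values are made up, and the leading None mirrors the first prompt token, which has no logprob:

input_logprobs = [(None, 101), (-0.8, 102), (-1.2, 103)]
values = [x[0] for x in input_logprobs if x[0]]  # drops the None entry (and any exact zeros)
normalized = sum(values) / len(values)           # (-0.8 + -1.2) / 2 = -1.0
print(normalized)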
......@@ -21,6 +21,7 @@ class LoadFormat(str, enum.Enum):
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
LAYERED = "layered"
JAX = "jax"
@dataclass
......@@ -42,13 +43,15 @@ class LoadConfig:
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
decryption_key_file: If set, decrypts the output files with a password read
from this file (after PBKDF2).
"""
load_format: Union[str, LoadFormat] = LoadFormat.AUTO
download_dir: Optional[str] = None
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
ignore_patterns: Optional[Union[List[str], str]] = None
decryption_key_file: Optional[str] = None
def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
......
......@@ -44,6 +44,7 @@ class ModelConfig:
is_embedding: Optional[bool] = None,
dtype: str = "auto",
quantization: Optional[str] = None,
override_config_file: Optional[str] = None,
) -> None:
self.model_path = model_path
self.revision = revision
......@@ -51,11 +52,16 @@ class ModelConfig:
# Parse args
self.model_override_args = json.loads(model_override_args)
kwargs = {}
if override_config_file and override_config_file.strip():
kwargs["_configuration_file"] = override_config_file.strip()
self.hf_config = get_config(
model_path,
trust_remote_code=trust_remote_code,
revision=revision,
model_override_args=self.model_override_args,
**kwargs,
)
self.hf_text_config = get_hf_text_config(self.hf_config)
......@@ -64,6 +70,9 @@ class ModelConfig:
self.hf_config.architectures, is_embedding
)
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures)
self.is_image_gen = is_image_gen_model(self.hf_config.architectures)
self.is_audio_model = is_audio_model(self.hf_config.architectures)
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
......@@ -71,7 +80,9 @@ class ModelConfig:
derived_context_len = get_context_length(self.hf_text_config)
if context_length is not None:
if context_length > derived_context_len:
if get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"):
if get_bool_env_var(
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="False"
):
logger.warning(
f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
f"This may lead to incorrect model outputs or CUDA errors."
......@@ -416,6 +427,8 @@ def is_multimodal_model(model_architectures: List[str]):
or "LlavaQwenForCausalLM" in model_architectures
or "LlavaMistralForCausalLM" in model_architectures
or "LlavaVidForCausalLM" in model_architectures
or "Grok1VForCausalLM" in model_architectures
or "Grok1AForCausalLM" in model_architectures
or "MllamaForConditionalGeneration" in model_architectures
or "Qwen2VLForConditionalGeneration" in model_architectures
or "Qwen2_5_VLForConditionalGeneration" in model_architectures
......@@ -426,6 +439,18 @@ def is_multimodal_model(model_architectures: List[str]):
return False
def is_multimodal_gen_model(model_architectures: List[str]):
return False
def is_image_gen_model(model_architectures: List[str]):
return False
def is_audio_model(model_architectures: List[str]):
return False
def is_encoder_decoder_model(model_architectures: List[str]):
return "MllamaForConditionalGeneration" in model_architectures
......
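The context-length guard above now reads its escape hatch with an explicit default; a hedged sketch of opting in to a longer user-specified context, assuming the Engine constructor forwards context_length to ServerArgs (values are placeholders):

import os
import sglang as sgl

os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
engine = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    context_length=16384,  # exceeding the derived context length is allowed (with a warning) when the env var is set
)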
......@@ -15,7 +15,7 @@
import json
import logging
from typing import List, Tuple
from typing import List, Optional, Tuple, Union
import torch
from xgrammar import (
......@@ -42,11 +42,16 @@ MAX_ROLLBACK_TOKENS = 200
class XGrammarGrammar(BaseGrammarObject):
def __init__(
self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar
self,
matcher: GrammarMatcher,
vocab_size: int,
ctx: CompiledGrammar,
override_stop_tokens: Optional[Union[List[int], int]],
) -> None:
self.matcher = matcher
self.vocab_size = vocab_size
self.ctx = ctx
self.override_stop_tokens = override_stop_tokens
self.finished = False
def accept_token(self, token: int):
......@@ -96,8 +101,14 @@ class XGrammarGrammar(BaseGrammarObject):
apply_token_bitmask_inplace(logits, vocab_mask)
def copy(self):
matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
matcher = GrammarMatcher(
self.ctx,
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
override_stop_tokens=self.override_stop_tokens,
)
return XGrammarGrammar(
matcher, self.vocab_size, self.ctx, self.override_stop_tokens
)
class XGrammarGrammarBackend(BaseGrammarBackend):
......@@ -111,8 +122,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
tokenizer_info = TokenizerInfo.from_huggingface(
tokenizer, vocab_size=vocab_size
)
override_stop_tokens = None
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
self.vocab_size = vocab_size
self.override_stop_tokens = override_stop_tokens
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
......@@ -161,7 +175,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
raise ValueError(f"Invalid key_type: {key_type}")
matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
return XGrammarGrammar(matcher, self.vocab_size, ctx)
return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
def reset(self):
if self.grammar_compiler:
......
......@@ -121,6 +121,7 @@ class Engine:
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[str], str]] = None,
return_hidden_states: bool = False,
......@@ -142,6 +143,7 @@ class Engine:
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
token_ids_logprob=token_ids_logprob,
lora_path=lora_path,
modalities=modalities_list,
custom_logit_processor=custom_logit_processor,
......@@ -179,6 +181,7 @@ class Engine:
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[str], str]] = None,
stream: bool = False,
......@@ -195,6 +198,7 @@ class Engine:
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
token_ids_logprob=token_ids_logprob,
lora_path=lora_path,
stream=stream,
custom_logit_processor=custom_logit_processor,
......@@ -226,15 +230,22 @@ class Engine:
kill_process_tree(os.getpid(), include_parent=False)
def start_profile(self):
self.tokenizer_manager.start_profile()
loop = asyncio.get_event_loop()
loop.run_until_complete(self.tokenizer_manager.start_profile())
def stop_profile(self):
self.tokenizer_manager.stop_profile()
def get_server_info(self):
loop = asyncio.get_event_loop()
internal_states = loop.run_until_complete(
self.tokenizer_manager.get_internal_state()
)
return {
**dataclasses.asdict(self.tokenizer_manager.server_args), # server args
**dataclasses.asdict(self.tokenizer_manager.server_args),
**self.scheduler_info,
**internal_states,
"version": __version__,
}
......@@ -323,6 +334,7 @@ def _set_envs_and_config(server_args: ServerArgs):
os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
os.environ["CUDA_MODULE_LOADING"] = "AUTO"
# Set prometheus env vars
if server_args.enable_metrics:
......@@ -346,12 +358,23 @@ def _set_envs_and_config(server_args: ServerArgs):
"at https://docs.flashinfer.ai/installation.html.",
)
def sigchld_handler(signum, frame):
pid, exitcode = os.waitpid(0, os.WNOHANG)
if exitcode != 0:
logger.warning(
"Child process unexpectedly failed with an exit code %d. pid=%d",
exitcode,
pid,
)
signal.signal(signal.SIGCHLD, sigchld_handler)
# Register the signal handler.
# The child processes will send SIGQUIT to this process when any error happens
# This process then clean up the whole process tree
def sigquit_handler(signum, frame):
logger.error(
"Received sigquit from a child proces. It usually means the child failed."
"Received sigquit from a child process. It usually means the child failed."
)
kill_process_tree(os.getpid())
......
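Engine.get_server_info() now merges the tokenizer manager's internal state into the returned dict; a hedged sketch of reading it, where the model path is a placeholder and last_gen_throughput is one example of a merged field, as consumed by the benchmark script above:

import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model
info = engine.get_server_info()
print(info["version"])                  # server args, scheduler info, and version as before
print(info.get("last_gen_throughput"))  # internal state merged in by this commit
engine.shutdown()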
......@@ -25,11 +25,14 @@ import os
import threading
import time
from http import HTTPStatus
from typing import AsyncIterator, Dict, Optional
from typing import AsyncIterator, Callable, Dict, Optional
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
from contextlib import asynccontextmanager
import numpy as np
import orjson
import requests
import uvicorn
......@@ -49,8 +52,10 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput,
OpenSessionReqInput,
ParseFunctionCallReq,
ProfileReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
SetInternalStateReq,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
VertexGenerateReqInput,
......@@ -78,22 +83,13 @@ from sglang.srt.utils import (
kill_process_tree,
set_uvicorn_logging_configs,
)
from sglang.srt.warmup import execute_warmups
from sglang.utils import get_exception_traceback
from sglang.version import __version__
logger = logging.getLogger(__name__)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Fast API
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Store global states
@dataclasses.dataclass
......@@ -110,6 +106,34 @@ def set_global_state(global_state: _GlobalState):
_global_state = global_state
@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
server_args: ServerArgs = fast_api_app.server_args
if server_args.warmups is not None:
await execute_warmups(
server_args.warmups.split(","), _global_state.tokenizer_manager
)
logger.info("Warmup ended")
warmup_thread = getattr(fast_api_app, "warmup_thread", None)
if warmup_thread is not None:
warmup_thread.start()
yield
# Fast API
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
##### Native API endpoints #####
......@@ -123,24 +147,48 @@ async def health() -> Response:
async def health_generate(request: Request) -> Response:
"""Check the health of the inference server by generating one token."""
sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
rid = f"HEALTH_CHECK_{time.time()}"
if _global_state.tokenizer_manager.is_generation:
if _global_state.tokenizer_manager.is_image_gen:
raise NotImplementedError()
elif _global_state.tokenizer_manager.is_generation:
gri = GenerateReqInput(
input_ids=[0], sampling_params=sampling_params, log_metrics=False
rid=rid,
input_ids=[0],
sampling_params=sampling_params,
log_metrics=False,
)
else:
gri = EmbeddingReqInput(
input_ids=[0], sampling_params=sampling_params, log_metrics=False
rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False
)
try:
async def gen():
async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
break
return Response(status_code=200)
except Exception as e:
logger.exception(e)
return Response(status_code=503)
tic = time.time()
task = asyncio.create_task(gen())
while time.time() < tic + HEALTH_CHECK_TIMEOUT:
await asyncio.sleep(1)
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
task.cancel()
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
return Response(status_code=200)
task.cancel()
tic_time = time.strftime("%H:%M:%S", time.localtime(tic))
last_receive_time = time.strftime(
"%H:%M:%S", time.localtime(_global_state.tokenizer_manager.last_receive_tstamp)
)
logger.error(
f"Health check failed. Server couldn't get a response from detokenizer for last "
f"{HEALTH_CHECK_TIMEOUT} seconds. tic start time: {tic_time}. "
f"last_heartbeat time: {last_receive_time}"
)
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
return Response(status_code=503)
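Since /health_generate now waits up to SGLANG_HEALTH_CHECK_TIMEOUT seconds for the detokenizer to report progress, a hedged client-side sketch (the base URL is a placeholder):

import requests

base_url = "http://127.0.0.1:30000"  # placeholder
# Give the client slightly more than the server-side budget (default 20 s).
resp = requests.get(f"{base_url}/health_generate", timeout=25)
print(resp.status_code)  # 200 if a token flowed through, 503 if the check timed out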
@app.get("/get_model_info")
......@@ -156,13 +204,21 @@ async def get_model_info():
@app.get("/get_server_info")
async def get_server_info():
internal_states = await _global_state.tokenizer_manager.get_internal_state()
return {
**dataclasses.asdict(_global_state.tokenizer_manager.server_args),
**_global_state.scheduler_info,
**internal_states,
"version": __version__,
}
@app.api_route("/set_internal_state", methods=["POST", "PUT"])
async def set_internal_state(obj: SetInternalStateReq, request: Request):
res = await _global_state.tokenizer_manager.set_internal_state(obj)
return res
# fastapi implicitly converts json in the request to obj (dataclass)
@app.api_route("/generate", methods=["POST", "PUT"])
async def generate_request(obj: GenerateReqInput, request: Request):
......@@ -179,6 +235,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
) + b"\n\n"
except ValueError as e:
out = {"error": {"message": str(e)}}
logger.error(f"Error: {e}")
yield b"data: " + orjson.dumps(
out, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
......@@ -236,9 +293,14 @@ async def flush_cache():
@app.api_route("/start_profile", methods=["GET", "POST"])
async def start_profile_async():
async def start_profile_async(obj: Optional[ProfileReqInput] = None):
"""Start profiling."""
_global_state.tokenizer_manager.start_profile()
if obj is None:
obj = ProfileReqInput()
await _global_state.tokenizer_manager.start_profile(
obj.output_dir, obj.num_steps, obj.activities
)
return Response(
content="Start profiling.\n",
status_code=200,
......@@ -257,11 +319,15 @@ async def stop_profile_async():
@app.post("/update_weights_from_disk")
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
"""Update the weights from disk in-place without re-launching the server."""
success, message = await _global_state.tokenizer_manager.update_weights_from_disk(
obj, request
"""Update the weights from disk inplace without re-launching the server."""
success, message, num_paused_requests = (
await _global_state.tokenizer_manager.update_weights_from_disk(obj, request)
)
content = {"success": success, "message": message}
content = {
"success": success,
"message": message,
"num_paused_requests": num_paused_requests,
}
if success:
return ORJSONResponse(
content,
......@@ -323,7 +389,7 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
async def release_memory_occupation(
obj: ReleaseMemoryOccupationReqInput, request: Request
):
"""Release GPU occupation temporarily"""
"""Release GPU memory occupation temporarily."""
try:
await _global_state.tokenizer_manager.release_memory_occupation(obj, request)
except Exception as e:
......@@ -334,7 +400,7 @@ async def release_memory_occupation(
async def resume_memory_occupation(
obj: ResumeMemoryOccupationReqInput, request: Request
):
"""Resume GPU occupation"""
"""Resume GPU memory occupation."""
try:
await _global_state.tokenizer_manager.resume_memory_occupation(obj, request)
except Exception as e:
......@@ -357,7 +423,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
@app.api_route("/close_session", methods=["GET", "POST"])
async def close_session(obj: CloseSessionReqInput, request: Request):
"""Close the session"""
"""Close the session."""
try:
await _global_state.tokenizer_manager.close_session(obj, request)
return Response(status_code=200)
......@@ -367,7 +433,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
@app.api_route("/configure_logging", methods=["GET", "POST"])
async def configure_logging(obj: ConfigureLoggingReq, request: Request):
"""Close the session"""
"""Configure the request logging options."""
_global_state.tokenizer_manager.configure_logging(obj)
return Response(status_code=200)
......@@ -511,6 +577,7 @@ def _create_error_response(e):
def launch_server(
server_args: ServerArgs,
pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None,
launch_callback: Optional[Callable[[], None]] = None,
):
"""
Launch SRT (SGLang Runtime) Server.
......@@ -544,21 +611,23 @@ def launch_server(
add_prometheus_middleware(app)
enable_func_timer()
# Send a warmup request
t = threading.Thread(
# Send a warmup request - we will create the thread and launch it
# in the lifespan after all other warmups have fired.
warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
server_args,
pipe_finish_writer,
_global_state.tokenizer_manager.image_token_id,
launch_callback,
),
)
t.start()
app.warmup_thread = warmup_thread
try:
# Update logging configs
set_uvicorn_logging_configs()
app.server_args = server_args
# Listen for HTTP requests
uvicorn.run(
app,
......@@ -569,10 +638,15 @@ def launch_server(
loop="uvloop",
)
finally:
t.join()
warmup_thread.join()
def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
def _wait_and_warmup(
server_args: ServerArgs,
pipe_finish_writer: Optional[multiprocessing.connection.Connection],
image_token_text: str,
launch_callback: Optional[Callable[[], None]] = None,
):
headers = {}
url = server_args.url()
if server_args.api_key:
......@@ -614,8 +688,16 @@ def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
else:
json_data["text"] = "The capital city of France is"
# Debug dumping
if server_args.debug_tensor_dump_input_file:
json_data.pop("text", None)
json_data["input_ids"] = np.load(
server_args.debug_tensor_dump_input_file
).tolist()
json_data["sampling_params"]["max_new_tokens"] = 0
try:
for _ in range(server_args.dp_size):
for i in range(server_args.dp_size):
res = requests.post(
url + request_name,
json=json_data,
......@@ -640,3 +722,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
if server_args.delete_ckpt_after_loading:
delete_directory(server_args.model_path)
if server_args.debug_tensor_dump_input_file:
kill_process_tree(os.getpid())
if launch_callback is not None:
launch_callback()
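The new launch_callback hook fires once the warmup request has completed; a hedged sketch of wiring it up, with the import paths assumed from the repository layout and a placeholder model path:

from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs

if __name__ == "__main__":
    server_args = ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder
    launch_server(
        server_args,
        launch_callback=lambda: print("warmup finished, server is ready"),
    )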
......@@ -60,6 +60,7 @@ class VerlEngine:
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[str], str]] = None,
) -> Dict:
......@@ -76,6 +77,7 @@ class VerlEngine:
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
token_ids_logprob=token_ids_logprob,
lora_path=lora_path,
custom_logit_processor=custom_logit_processor,
)
......
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union
import torch
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.spec_info import SpecInfo
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
class AttentionBackend(ABC):
......@@ -31,7 +31,7 @@ class AttentionBackend(ABC):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
"""Init the metadata for a forward pass for capturing a cuda graph."""
raise NotImplementedError()
......@@ -44,7 +44,7 @@ class AttentionBackend(ABC):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
"""Init the metadata for a forward pass for replying a cuda graph."""
raise NotImplementedError()
......@@ -64,7 +64,14 @@ class AttentionBackend(ABC):
):
"""Run forward on an attention layer."""
if forward_batch.forward_mode.is_decode():
return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
return self.forward_decode(
q,
k,
v,
layer,
forward_batch,
save_kv_cache=save_kv_cache,
)
else:
return self.forward_extend(
q,
......@@ -72,7 +79,7 @@ class AttentionBackend(ABC):
v,
layer,
forward_batch,
save_kv_cache,
save_kv_cache=save_kv_cache,
)
def forward_decode(
......
......@@ -68,6 +68,7 @@ class FlashInferAttnBackend(AttentionBackend):
model_runner: ModelRunner,
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
kv_last_page_len_buf: Optional[torch.Tensor] = None,
):
super().__init__()
......@@ -125,9 +126,14 @@ class FlashInferAttnBackend(AttentionBackend):
assert self.num_wrappers == 1
self.kv_indptr = [kv_indptr_buf]
self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)
if kv_last_page_len_buf is None:
self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)
else:
assert self.num_wrappers == 1
self.kv_last_page_len = kv_last_page_len_buf
self.qo_indptr = [
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
for _ in range(self.num_wrappers)
......@@ -922,6 +928,9 @@ class FlashInferMultiStepDraftBackend:
dtype=torch.int32,
device=model_runner.device,
)
self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)
self.attn_backends = []
for i in range(self.speculative_num_steps):
self.attn_backends.append(
......@@ -929,6 +938,7 @@ class FlashInferMultiStepDraftBackend:
model_runner,
skip_prefill=True,
kv_indptr_buf=self.kv_indptr[i],
kv_last_page_len_buf=self.kv_last_page_len,
)
)
self.max_context_len = self.attn_backends[0].max_context_len
......
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union
import torch
import triton
......@@ -15,7 +15,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
class TritonAttnBackend(AttentionBackend):
......@@ -232,7 +232,7 @@ class TritonAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
assert encoder_lens is None, "Not supported"
......@@ -310,7 +310,7 @@ class TritonAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
# NOTE: encoder_lens expected to be zeros or None
if forward_mode.is_decode_or_idle():
......