Unverified Commit 22352d47 authored by Lianmin Zheng, committed by GitHub

Improve streaming, log_level, memory report, weight loading, and benchmark script (#7632)


Co-authored-by: Kan Wu <wukanustc@gmail.com>
parent c5131f7a
......@@ -116,7 +116,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--log-level` | The logging level of all loggers. | info |
| `--log-level-http` | The logging level of HTTP server. If not set, reuse --log-level by default. | None |
| `--log-requests` | Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level. | False |
| `--log-requests-level` | 0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output. | 0 |
| `--log-requests-level` | 0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output. | 0 |
| `--show-time-cost` | Show time cost of custom marks. | False |
| `--enable-metrics` | Enable log prometheus metrics. | False |
| `--bucket-time-to-first-token` | The buckets of time to first token, specified as a list of floats. | None |
......
......@@ -38,6 +38,7 @@ class BenchArgs:
output_len: Tuple[int] = (16,)
temperature: float = 0.0
return_logprob: bool = False
client_stream_interval: int = 1
input_len_step_percentage: float = 0.0
result_filename: str = "result.jsonl"
base_url: str = ""
......@@ -60,6 +61,11 @@ class BenchArgs:
)
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
parser.add_argument("--return-logprob", action="store_true")
parser.add_argument(
"--client-stream-interval",
type=int,
default=BenchArgs.client_stream_interval,
)
parser.add_argument(
"--input-len-step-percentage",
type=float,
......@@ -120,6 +126,7 @@ def run_one_case(
output_len: int,
temperature: float,
return_logprob: bool,
stream_interval: int,
input_len_step_percentage: float,
run_name: str,
result_filename: str,
......@@ -168,6 +175,7 @@ def run_one_case(
"max_new_tokens": output_len,
"ignore_eos": True,
"json_schema": json_schema,
"stream_interval": stream_interval,
},
"return_logprob": return_logprob,
"stream": True,
......@@ -245,8 +253,9 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
else:
proc, base_url = launch_server_process(server_args)
tokenizer_id = server_args.tokenizer_path or server_args.model_path
tokenizer = get_tokenizer(tokenizer_id)
server_info = requests.get(base_url + "/get_server_info")
tokenizer_path = server_info.json()["tokenizer_path"]
tokenizer = get_tokenizer(tokenizer_path)
# warmup
if not bench_args.skip_warmup:
......@@ -258,6 +267,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
output_len=16,
temperature=bench_args.temperature,
return_logprob=bench_args.return_logprob,
stream_interval=bench_args.client_stream_interval,
input_len_step_percentage=bench_args.input_len_step_percentage,
run_name="",
result_filename="",
......@@ -280,6 +290,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
ol,
temperature=bench_args.temperature,
return_logprob=bench_args.return_logprob,
stream_interval=bench_args.client_stream_interval,
input_len_step_percentage=bench_args.input_len_step_percentage,
run_name=bench_args.run_name,
result_filename=bench_args.result_filename,
......@@ -301,6 +312,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
ol,
temperature=bench_args.temperature,
return_logprob=bench_args.return_logprob,
stream_interval=bench_args.client_stream_interval,
input_len_step_percentage=bench_args.input_len_step_percentage,
run_name=bench_args.run_name,
result_filename=bench_args.result_filename,
......
......@@ -1678,7 +1678,6 @@ def run_benchmark(args_: argparse.Namespace):
if args.base_url
else f"http://{args.host}:{args.port}/generate"
)
args.apply_chat_template = True
elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
api_url = (
f"{args.base_url}/v1/completions"
......
......@@ -147,12 +147,11 @@ class InternLM2Config(PretrainedConfig):
)
if (
rope_scaling_factor is None
or not isinstance(rope_scaling_factor, float)
or not isinstance(rope_scaling_factor, int)
or not isinstance(rope_scaling_factor, (float, int))
or rope_scaling_factor < 1.0
):
raise ValueError(
f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor}"
f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor=}, {type(rope_scaling_factor)=}"
)
if isinstance(rope_scaling_factor, int):
rope_scaling_factor = float(rope_scaling_factor)
......
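A quick illustration (my own, not part of the diff) of why the old type check rejected every value: no number is an instance of both `float` and `int`, so the negated checks OR'd together were always true.

```python
x = 2.0
old_check = (x is None) or not isinstance(x, float) or not isinstance(x, int) or x < 1.0
new_check = (x is None) or not isinstance(x, (float, int)) or x < 1.0
assert old_check       # a valid factor was spuriously rejected by the old condition
assert not new_check   # accepted after the fix
```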
......@@ -126,8 +126,6 @@ def set_global_state(global_state: _GlobalState):
@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
server_args: ServerArgs = fast_api_app.server_args
# Initialize OpenAI serving handlers
fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
_global_state.tokenizer_manager, _global_state.template_manager
......@@ -145,9 +143,12 @@ async def lifespan(fast_api_app: FastAPI):
_global_state.tokenizer_manager
)
server_args: ServerArgs = fast_api_app.server_args
if server_args.warmups is not None:
await execute_warmups(
server_args.warmups.split(","), _global_state.tokenizer_manager
server_args.disaggregation_mode,
server_args.warmups.split(","),
_global_state.tokenizer_manager,
)
logger.info("Warmup ended")
......@@ -280,13 +281,17 @@ async def get_model_info():
"model_path": _global_state.tokenizer_manager.model_path,
"tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
"is_generation": _global_state.tokenizer_manager.is_generation,
"preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params,
}
return result
@app.get("/get_server_info")
async def get_server_info():
internal_states = await _global_state.tokenizer_manager.get_internal_state()
# Returns internal states per DP rank.
internal_states: List[Dict[Any, Any]] = (
await _global_state.tokenizer_manager.get_internal_state()
)
return {
**dataclasses.asdict(_global_state.tokenizer_manager.server_args),
**_global_state.scheduler_info,
......@@ -300,6 +305,8 @@ async def get_load():
return await _global_state.tokenizer_manager.get_load()
# example usage:
# curl -s -X POST http://localhost:30000/set_internal_state -H "Content-Type: application/json" -d '{"server_args": {"max_micro_batch_size": 8}}'
@app.api_route("/set_internal_state", methods=["POST", "PUT"])
async def set_internal_state(obj: SetInternalStateReq, request: Request):
res = await _global_state.tokenizer_manager.set_internal_state(obj)
......@@ -886,6 +893,15 @@ def launch_server(
add_prometheus_middleware(app)
enable_func_timer()
image_token_text = None
if (
tokenizer_manager.image_token_id is not None
and not server_args.skip_tokenizer_init
):
image_token_text = tokenizer_manager.tokenizer.decode(
[tokenizer_manager.image_token_id]
)
# Send a warmup request - we will create the thread and launch it
# in the lifespan after all other warmups have fired.
warmup_thread = threading.Thread(
......@@ -893,7 +909,7 @@ def launch_server(
args=(
server_args,
pipe_finish_writer,
_global_state.tokenizer_manager.image_token_id,
image_token_text,
launch_callback,
),
)
......@@ -1022,9 +1038,10 @@ def _wait_and_warmup(
return
# Debug print
# logger.info(f"{res.json()=}")
# logger.info(f"warmup request returns: {res.json()=}")
logger.info("The server is fired up and ready to roll!")
if pipe_finish_writer is not None:
pipe_finish_writer.send("ready")
......
......@@ -8,6 +8,7 @@ from sglang.srt.utils import is_hip
_is_hip = is_hip()
fused_softcap_autotune = triton.autotune(
configs=[
triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
......@@ -189,21 +190,16 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
assert x.shape == residual.shape and x.dtype == residual.dtype
output, mid = torch.empty_like(x), torch.empty_like(x)
bs, hidden_dim = x.shape
min_num_warps = 16 if _is_hip else 32
if autotune:
fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
)
else:
max_warps = 16 if _is_hip else 32
config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(
triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
),
4,
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
),
}
......@@ -260,13 +256,11 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
else:
output = torch.empty_like(x)
bs, hidden_dim = x.shape
min_num_warps = 16 if _is_hip else 32
max_warps = 16 if _is_hip else 32
config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
),
}
......@@ -331,6 +325,75 @@ class FusedDualResidualRMSNorm:
return self.rmsnorm2.forward_native(residual), residual
@triton.jit
def experts_combine_kernel(
out_hidden_states,
moe_hidden_states,
mlp_hidden_states,
combine_k: tl.constexpr,
hidden_dim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
start_index_mlp = pid * hidden_dim
start_index_rmoe = pid * hidden_dim * combine_k
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < hidden_dim
combine_k_offsets = tl.arange(0, combine_k)
moe_x = tl.load(
moe_hidden_states
+ start_index_rmoe
+ combine_k_offsets[:, None] * hidden_dim
+ offsets[None, :],
mask=mask[None, :],
other=0.0,
)
moe_x = tl.sum(moe_x, axis=0)
mlp_x = tl.load(mlp_hidden_states + start_index_mlp + offsets, mask=mask, other=0.0)
combined_x = (moe_x + mlp_x) / 1.4142135623730951
tl.store(out_hidden_states + start_index_mlp + offsets, combined_x, mask=mask)
def experts_combine_triton(moe_hidden_states, mlp_hidden_states, output_buffer=None):
assert moe_hidden_states.is_contiguous()
assert mlp_hidden_states.is_contiguous()
if len(moe_hidden_states.shape) == 2:
combine_k = 1 # pre-combined
else:
combine_k = moe_hidden_states.shape[1]
if output_buffer is None:
out_hidden_states = torch.empty_like(mlp_hidden_states)
else:
flat_output_buffer = output_buffer.view(mlp_hidden_states.dtype).reshape(-1)
assert flat_output_buffer.numel() >= mlp_hidden_states.numel()
out_hidden_states = flat_output_buffer[: mlp_hidden_states.numel()].reshape(
mlp_hidden_states.shape
)
bs, hidden_dim = mlp_hidden_states.shape
config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 1024)), 8), 4
),
}
experts_combine_kernel[(bs,)](
out_hidden_states,
moe_hidden_states,
mlp_hidden_states,
combine_k,
hidden_dim,
**config,
)
return out_hidden_states
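To make the fused kernel's math explicit, here is a rough pure-PyTorch equivalent of `experts_combine_triton` (a sketch for clarity, not part of the change): it sums the `combine_k` routed-expert outputs, adds the MLP branch, and scales by 1/sqrt(2), matching the `1.4142135623730951` divisor above.

```python
import math
import torch

def experts_combine_reference(moe_hidden_states: torch.Tensor,
                              mlp_hidden_states: torch.Tensor) -> torch.Tensor:
    # (bs, combine_k, hidden_dim): sum over combine_k; a 2D input is already pre-combined.
    moe = moe_hidden_states.sum(dim=1) if moe_hidden_states.dim() == 3 else moe_hidden_states
    return (moe + mlp_hidden_states) / math.sqrt(2.0)
```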
# gelu on first half of vector
@triton.jit
def gelu_and_mul_kernel(
......@@ -400,10 +463,11 @@ def gelu_and_mul_triton(
out_scales = scales
static_scale = True
max_warps = 16 if _is_hip else 32
config = {
# 8 ele per thread (not tuned)
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
),
}
......
from typing import Tuple
from typing import Optional, Tuple
import torch
import triton
......@@ -16,6 +16,8 @@ def fused_moe_router_kernel(
moe_router_weight_ptr, # input (num_experts, hidden_dim)
topk_weights_ptr, # output (bs, topk)
topk_ids_ptr, # output (bs, topk)
correction_bias_ptr,
is_correction_bias: tl.constexpr,
num_experts: tl.constexpr,
topk: tl.constexpr,
moe_softcapping: tl.constexpr,
......@@ -49,6 +51,11 @@ def fused_moe_router_kernel(
bottom = exped + 1
logits_softcapped = top / bottom * moe_softcapping
# Add bias after softcapping
if is_correction_bias:
bias = tl.load(correction_bias_ptr + tl.arange(0, num_experts))
logits_softcapped = logits_softcapped + bias
# topk
# assert 1 <= topk <= num_experts
......@@ -109,6 +116,7 @@ def fused_moe_router_impl(
router_weight: torch.Tensor,
topk: int,
moe_softcapping: float,
correction_bias: Optional[torch.Tensor] = None,
):
assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
bs, hidden_dim = x.shape
......@@ -117,23 +125,23 @@ def fused_moe_router_impl(
# router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
is_correction_bias = correction_bias is not None
grid = lambda meta: (bs,)
min_num_warps = 16 if _is_hip else 32
max_warps = 16 if _is_hip else 32
config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
),
}
fused_moe_router_kernel[grid](
fused_moe_router_kernel[(bs,)](
x,
router_weight,
topk_weights,
topk_ids,
correction_bias,
is_correction_bias=is_correction_bias,
num_experts=num_experts,
topk=topk,
moe_softcapping=moe_softcapping,
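For readers less familiar with the Triton code, a rough PyTorch sketch of the routing math implied by the kernels above: logits are tanh-softcapped, the new `correction_bias` is added after softcapping, and top-k experts are picked with softmax-style weights. This is my reading of the excerpt, not an exact drop-in; in particular, whether the bias also enters the returned weights (rather than only expert selection) is not visible here.

```python
from typing import Optional
import torch

def moe_router_reference(x, router_weight, topk, moe_softcapping,
                         correction_bias: Optional[torch.Tensor] = None):
    logits = x.float() @ router_weight.float().t()                    # (bs, num_experts)
    logits = torch.tanh(logits / moe_softcapping) * moe_softcapping   # softcapping
    if correction_bias is not None:
        logits = logits + correction_bias                             # bias after softcapping
    probs = torch.softmax(logits, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
    return topk_weights, topk_ids.to(torch.int32)
```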
......@@ -153,7 +161,7 @@ def fused_moe_router_large_bs_kernel(
topk_ids_ptr, # output (bs, topk)
bs,
num_experts: tl.constexpr,
topk: tl.constexpr, # only support topk == 1
topk: tl.constexpr, # only support topk <= 2
moe_softcapping: tl.constexpr,
moe_renormalize: tl.constexpr, # not supported
K: tl.constexpr,
......@@ -204,25 +212,53 @@ def fused_moe_router_large_bs_kernel(
logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
# 5. top1
cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
cond_top1 = arange_block_size_n < num_experts
top1 = tl.argmax(tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1)
top1_v = tl.max(
tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1, keep_dims=True
)
invsumexp = 1.0 / tl.sum(
tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
top1_invsumexp = 1.0 / tl.sum(
tl.where(cond_top1, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
)
# 6. store to output
offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
topk_mask = offs_topk < bs
tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
# 6. store top1 to output
offs_top1 = pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)
top1_mask = offs_top1 < bs * topk
tl.store(topk_ids_ptr + offs_top1, top1, mask=top1_mask)
tl.store(
topk_weights_ptr + offs_topk,
invsumexp,
mask=topk_mask,
topk_weights_ptr + offs_top1,
top1_invsumexp,
mask=top1_mask,
)
# 7. handle topk == 2
if topk == 2:
cond_top2 = (arange_block_size_n < num_experts) and (
arange_block_size_n != top1[:, None]
)
top2 = tl.argmax(
tl.where(cond_top2, logits_softcapped, float("-inf")),
axis=1,
keep_dims=True,
)
top2_v = tl.sum(
logits_softcapped * (arange_block_size_n == top2), axis=1, keep_dims=True
)
top2_invsumexp = tl.exp(top2_v - top1_v) * top1_invsumexp[:, None]
# store top2
offs_top2 = (
pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)[:, None] + 1
)
top2_mask = offs_top2 < bs * topk
tl.store(topk_ids_ptr + offs_top2, top2, mask=top2_mask)
tl.store(
topk_weights_ptr + offs_top2,
top2_invsumexp,
mask=top2_mask,
)
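A small illustration (pure Python, mirroring the pointer arithmetic above) of how the top-1 and top-2 results are interleaved in the output buffers when `topk == 2`:

```python
topk, BLOCK_SIZE_M, pid = 2, 4, 0
offs_top1 = [pid * topk * BLOCK_SIZE_M + topk * m for m in range(BLOCK_SIZE_M)]  # [0, 2, 4, 6]
offs_top2 = [o + 1 for o in offs_top1]                                           # [1, 3, 5, 7]
# topk_ids / topk_weights end up row-interleaved: [t1_r0, t2_r0, t1_r1, t2_r1, ...]
```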
def fused_moe_router_large_bs_impl(
x: torch.Tensor,
......@@ -239,7 +275,7 @@ def fused_moe_router_large_bs_impl(
assert num_experts <= BLOCK_SIZE_N
assert hidden_dim % BLOCK_SIZE_K == 0
assert topk == 1
assert topk <= 2
topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
......@@ -273,6 +309,7 @@ def fused_moe_router_shim(
gating_output,
topk,
renormalize,
correction_bias: Optional[torch.Tensor] = None,
):
assert not renormalize
assert (
......@@ -286,7 +323,7 @@ def fused_moe_router_shim(
BLOCK_SIZE_K = 256
if (
bs >= 512
and topk == 1
and topk <= 2
and num_experts <= BLOCK_SIZE_N
and hidden_dim % BLOCK_SIZE_K == 0
):
......@@ -305,6 +342,7 @@ def fused_moe_router_shim(
router_weight=gating_output,
topk=topk,
moe_softcapping=moe_softcapping,
correction_bias=correction_bias,
)
......
......@@ -28,7 +28,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--url", type=str, default="http://localhost:30000")
parser.add_argument("--log-requests", action="store_true")
parser.add_argument("--log-requests-level", type=int, default=2)
parser.add_argument("--log-requests-level", type=int, default=3)
parser.add_argument(
"--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
)
......
......@@ -516,9 +516,6 @@ class EmbeddingReqInput:
# For cross-encoder requests
is_cross_encoder_request: bool = False
def contains_mm_input(self) -> bool:
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
def normalize_batch_and_arguments(self):
# at least one of text, input_ids, or image should be provided
if self.text is None and self.input_ids is None and self.image_data is None:
......@@ -572,6 +569,9 @@ class EmbeddingReqInput:
self.rid = uuid.uuid4().hex
return self.rid
def contains_mm_input(self) -> bool:
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
def __getitem__(self, i):
if self.is_cross_encoder_request:
return EmbeddingReqInput(
......
......@@ -2,12 +2,15 @@
Multi-modality utils
"""
import hashlib
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
......@@ -678,3 +681,52 @@ def get_multimodal_data_bounds(
# Convert valid pairs to tensor
valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
return valid_pairs_tensor
def data_hash(data) -> int:
hash_bytes = hashlib.sha256(data).digest()[:8]
return int.from_bytes(hash_bytes, byteorder="big", signed=False)
def tensor_hash(tensor_list) -> int:
"""
hash a tensor or a tensor list
"""
tensor = tensor_list
if isinstance(tensor_list, list):
tensor_list = flatten_nested_list(tensor_list)
tensor_list = [
x.flatten() if isinstance(x, torch.Tensor) else x for x in tensor_list
]
tensor = torch.concat(tensor_list)
if tensor.is_cuda:
return gpu_tensor_hash(tensor)
tensor = tensor.detach().contiguous()
if tensor.dtype == torch.bfloat16:
# memoryview() doesn't support PyTorch's BFloat16 dtype
tensor = tensor.float()
assert isinstance(tensor, torch.Tensor)
if tensor.is_cuda:
# TODO: improve this
tensor_cpu = tensor.cpu()
else:
tensor_cpu = tensor
mv = memoryview(tensor_cpu.numpy())
return data_hash(mv.tobytes())
def hash_feature(f):
if isinstance(f, list):
if isinstance(f[0], torch.Tensor):
return tensor_hash(f)
return data_hash(tuple(flatten_nested_list(f)))
elif isinstance(f, np.ndarray):
arr = np.ascontiguousarray(f)
arr_bytes = arr.tobytes()
return data_hash(arr_bytes)
elif isinstance(f, torch.Tensor):
return tensor_hash([f])
return data_hash(f)
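A short usage sketch of the hashing helpers above (assuming they are imported from `sglang.srt.managers.mm_utils`, where this diff places them):

```python
import numpy as np
import torch

from sglang.srt.managers.mm_utils import hash_feature

# Each call returns a 64-bit integer derived from the first 8 bytes of a SHA-256 digest.
h_array  = hash_feature(np.zeros((3, 224, 224), dtype=np.float32))   # ndarray path
h_tensor = hash_feature(torch.randn(3, 224, 224))                    # tensor path
h_list   = hash_feature([torch.randn(16), torch.randn(16)])          # list of tensors
```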
......@@ -3,7 +3,6 @@ import importlib
import inspect
import logging
import pkgutil
from functools import lru_cache
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
from sglang.srt.server_args import ServerArgs
......
......@@ -33,7 +33,6 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
import copy
import dataclasses
import hashlib
import logging
import threading
from enum import Enum, auto
......@@ -53,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
......@@ -96,8 +94,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
"max_micro_batch_size",
"disable_shared_experts_fusion",
"sampling_backend",
"speculative_accept_threshold_acc",
"speculative_accept_threshold_single",
"speculative_accept_threshold_acc",
"torchao_config",
"triton_attention_reduce_in_fp32",
"num_reserved_decode_tokens",
......@@ -180,7 +178,9 @@ class Modality(Enum):
@dataclasses.dataclass
class MultimodalDataItem:
"""
A single multimodal data, from a single image/video/audio or others.
One MultimodalDataItem contains all inputs for one modality.
For example, if a request has 3 images and 1 audio input, there will be 2 MultimodalDataItems:
one for the images and one for the audio.
We put the common fields first and the model-specific fields last.
"""
......@@ -232,53 +232,7 @@ class MultimodalDataItem:
"""
Set the pad value after first hashing the data
"""
def data_hash(data) -> int:
hash_bytes = hashlib.sha256(data).digest()[:8]
return int.from_bytes(hash_bytes, byteorder="big", signed=False)
def tensor_hash(tensor_list) -> int:
"""
hash a tensor or a tensor list
"""
tensor = tensor_list
if isinstance(tensor_list, list):
tensor_list = flatten_nested_list(tensor_list)
tensor_list = [
x.flatten() if isinstance(x, torch.Tensor) else x
for x in tensor_list
]
tensor = torch.concat(tensor_list)
if tensor.is_cuda:
return gpu_tensor_hash(tensor)
tensor = tensor.detach().contiguous()
if tensor.dtype == torch.bfloat16:
# memoryview() doesn't support PyTorch's BFloat16 dtype
tensor = tensor.float()
assert isinstance(tensor, torch.Tensor)
if tensor.is_cuda:
# TODO: improve this
tensor_cpu = tensor.cpu()
else:
tensor_cpu = tensor
mv = memoryview(tensor_cpu.numpy())
return data_hash(mv.tobytes())
def hash_feature(f):
if isinstance(f, list):
if isinstance(f[0], torch.Tensor):
return tensor_hash(f)
return data_hash(tuple(flatten_nested_list(f)))
elif isinstance(f, np.ndarray):
arr = np.ascontiguousarray(f)
arr_bytes = arr.tobytes()
return data_hash(arr_bytes)
elif isinstance(f, torch.Tensor):
return tensor_hash([f])
return data_hash(f)
from sglang.srt.managers.mm_utils import hash_feature
if self.precomputed_features is not None:
self.hash = hash_feature(self.precomputed_features)
......
......@@ -418,14 +418,16 @@ class Scheduler(
self.last_decode_stats_tic = time.perf_counter()
self.last_prefill_stats_tic = time.perf_counter()
self.return_health_check_ct = 0
self.num_retracted_reqs: int = 0
self.num_paused_reqs: int = 0
self.kv_transfer_speed_gb_s: float = 0.0
self.kv_transfer_latency_ms: float = 0.0
self.sessions: Dict[str, Session] = {}
self.current_stream = torch.get_device_module(self.device).current_stream()
if self.device == "cpu":
self.current_stream.synchronize = lambda: None # No-op for CPU
self.forward_sleep_time = None
# Init session info
self.sessions: Dict[str, Session] = {}
# Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
if self.chunked_prefill_size <= 0: # -1 means disable
......@@ -473,26 +475,12 @@ class Scheduler(
t = threading.Thread(target=self.watchdog_thread, daemon=True)
t.start()
self.parent_process = psutil.Process().parent()
# Init memory saver, profiler and metric stats
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
)
# Init profiler
self.torch_profiler = None
self.torch_profiler_output_dir: Optional[str] = None
self.profiler_activities: Optional[List[str]] = None
self.profile_id: Optional[str] = None
self.profiler_target_forward_ct: Optional[int] = None
self.profiler_target_prefill_ct: Optional[int] = None
self.profiler_target_decode_ct: Optional[int] = None
self.profiler_prefill_ct: Optional[int] = None
self.profiler_decode_ct: Optional[int] = None
self.profile_by_stage: bool = False
self.profile_steps: Optional[int] = None
self.profile_in_progress: bool = False
self.rpd_profiler = None
# Init metrics stats
self.init_profier()
self.init_metrics()
self.init_kv_events(server_args.kv_events_config)
......@@ -526,6 +514,7 @@ class Scheduler(
]
)
# Init disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
......@@ -624,6 +613,21 @@ class Scheduler(
)
)
def init_profier(self):
self.torch_profiler = None
self.torch_profiler_output_dir: Optional[str] = None
self.profiler_activities: Optional[List[str]] = None
self.profile_id: Optional[str] = None
self.profiler_target_forward_ct: Optional[int] = None
self.profiler_target_prefill_ct: Optional[int] = None
self.profiler_target_decode_ct: Optional[int] = None
self.profiler_prefill_ct: Optional[int] = None
self.profiler_decode_ct: Optional[int] = None
self.profile_by_stage: bool = False
self.profile_steps: Optional[int] = None
self.profile_in_progress: bool = False
self.rpd_profiler = None
def init_metrics(self):
self.last_gen_throughput: float = 0.0
self.last_input_throughput: float = 0.0
......@@ -2107,6 +2111,18 @@ class Scheduler(
def get_internal_state(self, recv_req: GetInternalStateReq):
ret = dict(global_server_args_dict)
ret["last_gen_throughput"] = self.last_gen_throughput
ret["memory_usage"] = {
"weight": round(
self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
),
"kvcache": round(
self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
),
"cuda_graph": round(
self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
),
"token_capacity": int(self.max_total_num_tokens),
}
if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
ret["avg_spec_accept_length"] = (
self.cum_spec_accept_length / self.cum_spec_accept_count
......
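Illustrative shape of the memory report this adds to the scheduler's internal state (surfaced by endpoints such as `/get_server_info`); the numbers below are made up:

```python
# All sizes in GB except token_capacity (number of KV-cache token slots).
example_memory_usage = {
    "weight": 13.42,        # weight_load_mem_usage
    "kvcache": 42.17,       # KV cache allocation
    "cuda_graph": 1.05,     # CUDA graph capture overhead
    "token_capacity": 262144,
}
```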
......@@ -521,11 +521,17 @@ class SchedulerOutputProcessorMixin:
stream_interval = (
req.sampling_params.stream_interval or self.stream_interval
)
should_output = len(req.output_ids) % stream_interval == 0
should_output = (
len(req.output_ids) % stream_interval == 1
if not self.model_config.is_multimodal_gen
and stream_interval > 1
else len(req.output_ids) % stream_interval == 0
)
else:
should_output = (
len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
and not self.model_config.is_multimodal_gen
if not self.model_config.is_multimodal_gen
else False
)
if should_output:
......
......@@ -111,11 +111,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.multimodal_processor import (
get_dummy_processor,
get_mm_processor,
import_processors,
)
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
......@@ -187,6 +183,8 @@ class TokenizerManager:
if server_args.preferred_sampling_params
else None
)
self.crash_dump_folder = server_args.crash_dump_folder
self.crash_dump_performed = False # Flag to ensure dump is only called once
# Init inter-process communication
context = zmq.asyncio.Context(2)
......@@ -251,10 +249,11 @@ class TokenizerManager:
self.dump_requests_folder = "" # By default do not dump
self.dump_requests_threshold = 1000
self.dump_request_list: List[Tuple] = []
self.crash_dump_request_list: deque[Tuple] = deque()
self.log_request_metadata = self.get_log_request_metadata()
self.asyncio_tasks = set()
self.session_futures = {} # session_id -> asyncio event
self.max_req_input_len = None
self.asyncio_tasks = set()
# The event to notify the weight sync is finished.
self.model_update_lock = RWLock()
......@@ -266,14 +265,14 @@ class TokenizerManager:
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.transfer_backend = TransferBackend(
self.disaggregation_transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
# Start kv bootstrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# only start bootstrap server on prefill tm
kv_bootstrap_server_class = get_kv_class(
self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
self.bootstrap_server = kv_bootstrap_server_class(
self.server_args.disaggregation_bootstrap_port
......@@ -324,7 +323,6 @@ class TokenizerManager:
self.profile_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
self.get_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
......@@ -484,7 +482,7 @@ class TokenizerManager:
token_type_ids = encoded.get("token_type_ids", [None])[0]
if self.mm_processor and obj.contains_mm_input():
image_inputs = await self.mm_processor.process_mm_data_async(
image_inputs: Dict = await self.mm_processor.process_mm_data_async(
image_data=obj.image_data,
input_text=input_text or input_ids,
request_obj=obj,
......@@ -547,6 +545,14 @@ class TokenizerManager:
"Please set `--enable-custom-logits-processor` to enable this feature."
)
def _validate_input_ids_in_vocab(
self, input_ids: List[int], vocab_size: int
) -> None:
if any(id >= vocab_size for id in input_ids):
raise ValueError(
f"The input_ids {input_ids} contains values greater than the vocab size ({vocab_size})."
)
def _create_tokenized_object(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
......@@ -1096,12 +1102,36 @@ class TokenizerManager:
"image_data",
"audio_data",
"lora_path",
"sampling_params",
]
)
out_skip_names = set(
[
"text",
"output_ids",
]
)
out_skip_names = set(["text", "output_ids", "embedding"])
elif self.log_requests_level == 1:
max_length = 2048
max_length = 1 << 30
skip_names = set(
[
"text",
"input_ids",
"input_embeds",
"image_data",
"audio_data",
"lora_path",
]
)
out_skip_names = set(
[
"text",
"output_ids",
]
)
elif self.log_requests_level == 2:
max_length = 2048
elif self.log_requests_level == 3:
max_length = 1 << 30
else:
raise ValueError(
......@@ -1118,6 +1148,8 @@ class TokenizerManager:
self.dump_requests_folder = obj.dump_requests_folder
if obj.dump_requests_threshold is not None:
self.dump_requests_threshold = obj.dump_requests_threshold
if obj.crash_dump_folder is not None:
self.crash_dump_folder = obj.crash_dump_folder
logging.info(f"Config logging: {obj=}")
self.log_request_metadata = self.get_log_request_metadata()
......@@ -1166,6 +1198,52 @@ class TokenizerManager:
loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
)
def dump_requests_before_crash(self):
if self.crash_dump_performed:
logger.info(
"SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
)
return
logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
self.crash_dump_performed = True
if not self.crash_dump_folder:
return
data_to_dump = []
if self.crash_dump_request_list:
data_to_dump.extend(self.crash_dump_request_list)
# Add unfinished requests from rid_to_state
unfinished_requests = []
for rid, state in self.rid_to_state.items():
if not state.finished:
unfinished_requests.append(
(state.obj, {}, state.created_time, time.time())
)
if unfinished_requests:
data_to_dump.extend(unfinished_requests)
if not data_to_dump:
return
filename = os.path.join(
self.crash_dump_folder,
os.getenv("HOSTNAME", None),
f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl',
)
os.makedirs(os.path.dirname(filename), exist_ok=True)
# Include server_args in the dump
data_to_dump_with_server_args = {
"server_args": self.server_args,
"requests": data_to_dump,
}
with open(filename, "wb") as f:
pickle.dump(data_to_dump_with_server_args, f)
logger.error(
f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
)
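Once a dump exists, it can be inspected offline; a minimal sketch (the path is illustrative) that unpickles the structure written above:

```python
import pickle

# Structure written by dump_requests_before_crash:
#   {"server_args": ServerArgs, "requests": [(req_obj, out_dict, created_time, finished_time), ...]}
with open("/tmp/crash_dumps/myhost/crash_dump_2025-01-01_00-00-00.pkl", "rb") as f:
    dump = pickle.load(f)

print(dump["server_args"])
for req_obj, out, created, finished in dump["requests"]:
    print(type(req_obj).__name__, f"{finished - created:.2f}s")
```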
async def sigterm_watchdog(self):
while not self.gracefully_exit:
await asyncio.sleep(5)
......@@ -1175,11 +1253,12 @@ class TokenizerManager:
remain_num_req = len(self.rid_to_state)
if self.health_check_failed:
# if health check failed, exit immediately
# if health check failed, we should exit immediately
logger.error(
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
remain_num_req,
)
self.dump_requests_before_crash()
break
elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
......@@ -1196,6 +1275,7 @@ class TokenizerManager:
if remain_num_req > 0:
await asyncio.sleep(5)
else:
self.dump_requests_before_crash()
break
kill_process_tree(os.getpid(), include_parent=True)
......@@ -1273,16 +1353,7 @@ class TokenizerManager:
"meta_info": meta_info,
}
elif isinstance(recv_obj, BatchMultimodalOut):
if isinstance(recv_obj.outputs[i], str):
out_dict = {
"text": recv_obj.outputs[i],
"meta_info": meta_info,
}
else:
out_dict = {
"outputs": json.dumps(recv_obj.outputs[i]),
"meta_info": meta_info,
}
raise NotImplementedError("BatchMultimodalOut not implemented")
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
......@@ -1306,6 +1377,8 @@ class TokenizerManager:
self.collect_metrics(state, recv_obj, i)
if self.dump_requests_folder and state.finished and state.obj.log_metrics:
self.dump_requests(state, out_dict)
if self.crash_dump_folder and state.finished and state.obj.log_metrics:
self.record_request_for_crash_dump(state, out_dict)
def convert_logprob_style(
self,
......@@ -1317,6 +1390,9 @@ class TokenizerManager:
recv_obj: BatchStrOut,
recv_obj_index: int,
):
if recv_obj.input_token_logprobs_val is None:
return
if len(recv_obj.input_token_logprobs_val) > 0:
state.input_token_logprobs_val.extend(
recv_obj.input_token_logprobs_val[recv_obj_index]
......@@ -1436,7 +1512,10 @@ class TokenizerManager:
else 0
)
if state.first_token_time == 0.0:
if (
state.first_token_time == 0.0
and self.disaggregation_mode != DisaggregationMode.PREFILL
):
state.first_token_time = state.last_time = time.time()
state.last_completion_tokens = completion_tokens
self.metrics_collector.observe_time_to_first_token(
......@@ -1484,14 +1563,31 @@ class TokenizerManager:
to_dump = self.dump_request_list
self.dump_request_list = []
to_dump_with_server_args = {
"server_args": self.server_args,
"requests": to_dump,
}
def background_task():
os.makedirs(self.dump_requests_folder, exist_ok=True)
with open(filename, "wb") as f:
pickle.dump(to_dump, f)
pickle.dump(to_dump_with_server_args, f)
# Schedule the task to run in the background without awaiting it
asyncio.create_task(asyncio.to_thread(background_task))
def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
current_time = time.time()
self.crash_dump_request_list.append(
(state.obj, out_dict, state.created_time, current_time)
)
# Remove requests older than 5 minutes based on finish time
while (
self.crash_dump_request_list
and current_time - self.crash_dump_request_list[0][3] >= 300
):
self.crash_dump_request_list.popleft()
def _handle_abort_req(self, recv_obj):
self.rid_to_state.pop(recv_obj.rid, None)
......@@ -1614,6 +1710,8 @@ async def print_exception_wrapper(func):
except Exception:
traceback = get_exception_traceback()
logger.error(f"TokenizerManager hit an exception: {traceback}")
if hasattr(func, "__self__") and isinstance(func.__self__, TokenizerManager):
func.__self__.dump_requests_before_crash()
kill_process_tree(os.getpid(), include_parent=True)
sys.exit(1)
......@@ -1632,6 +1730,7 @@ class SignalHandler:
logger.error(
"Received sigquit from a child process. It usually means the child failed."
)
self.tokenizer_manager.dump_requests_before_crash()
kill_process_tree(os.getpid())
......
......@@ -123,6 +123,7 @@ class KVCache(abc.ABC):
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)
self.mem_usage = 0
# used for chunked cpu-offloading
self.cpu_offloading_chunk_size = 8192
......@@ -219,6 +220,7 @@ class MHATokenToKVPool(KVCache):
logger.info(
f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
)
self.mem_usage = (k_size + v_size) / GB
def _create_buffers(self):
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
......@@ -695,6 +697,7 @@ class MLATokenToKVPool(KVCache):
logger.info(
f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
)
self.mem_usage = kv_size / GB
def get_kv_size_bytes(self):
assert hasattr(self, "kv_buffer")
......
......@@ -604,12 +604,13 @@ class ModelRunner:
self.dtype = self.model_config.dtype
after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
self.weight_load_mem_usage = before_avail_memory - after_avail_memory
logger.info(
f"Load weight end. "
f"type={type(self.model).__name__}, "
f"dtype={self.dtype}, "
f"avail mem={after_avail_memory:.2f} GB, "
f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
f"mem usage={self.weight_load_mem_usage:.2f} GB."
)
# Handle the case where some ranks do not finish loading.
......@@ -1250,6 +1251,7 @@ class ModelRunner:
def init_cuda_graphs(self):
"""Capture cuda graphs."""
self.cuda_graph_runner = None
self.cuda_graph_mem_usage = 0
if not self.is_generation:
# TODO: Currently, cuda graph only captures decode steps, which only exist for generation models
......@@ -1265,9 +1267,10 @@ class ModelRunner:
)
self.cuda_graph_runner = CudaGraphRunner(self)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
self.cuda_graph_mem_usage = before_mem - after_mem
logger.info(
f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
)
def apply_torch_tp(self):
......
......@@ -534,6 +534,12 @@ class DummyModelLoader(BaseModelLoader):
model_config: ModelConfig,
device_config: DeviceConfig,
) -> nn.Module:
if get_bool_env_var("SGL_CPU_QUANTIZATION"):
return load_model_with_cpu_quantization(
self, model_config=model_config, device_config=device_config
)
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(
......@@ -1464,6 +1470,38 @@ class RemoteModelLoader(BaseModelLoader):
return model.eval()
def load_model_with_cpu_quantization(
self,
*,
model_config: ModelConfig,
device_config: DeviceConfig,
) -> nn.Module:
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
model = _initialize_model(
model_config,
self.load_config,
)
if not isinstance(self, DummyModelLoader):
model.load_weights(self._get_all_weights(model_config, model))
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
model.to(target_device)
return model.eval()
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
......
......@@ -13,7 +13,7 @@
# ==============================================================================
"""Inference-only Mistral model."""
from typing import List, Union
from typing import List
import torch
from transformers.models.mistral3.modeling_mistral3 import Mistral3MultiModalProjector
......
......@@ -99,6 +99,7 @@ class ServerArgs:
log_level_http: Optional[str] = None
log_requests: bool = False
log_requests_level: int = 0
crash_dump_folder: Optional[str] = None
show_time_cost: bool = False
enable_metrics: bool = False
bucket_time_to_first_token: Optional[List[float]] = None
......@@ -927,8 +928,14 @@ class ServerArgs:
"--log-requests-level",
type=int,
default=0,
help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
choices=[0, 1, 2],
help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
choices=[0, 1, 2, 3],
)
parser.add_argument(
"--crash-dump-folder",
type=str,
default=ServerArgs.crash_dump_folder,
help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
)
parser.add_argument(
"--show-time-cost",
......