Unverified Commit 22352d47 authored by Lianmin Zheng, committed by GitHub

Improve streaming, log_level, memory report, weight loading, and benchmark script (#7632)


Co-authored-by: Kan Wu <wukanustc@gmail.com>
parent c5131f7a
@@ -116,7 +116,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--log-level` | The logging level of all loggers. | info |
 | `--log-level-http` | The logging level of HTTP server. If not set, reuse --log-level by default. | None |
 | `--log-requests` | Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level. | False |
-| `--log-requests-level` | 0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output. | 0 |
+| `--log-requests-level` | 0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output. | 0 |
 | `--show-time-cost` | Show time cost of custom marks. | False |
 | `--enable-metrics` | Enable log prometheus metrics. | False |
 | `--bucket-time-to-first-token` | The buckets of time to first token, specified as a list of floats. | None |
...
@@ -38,6 +38,7 @@ class BenchArgs:
     output_len: Tuple[int] = (16,)
     temperature: float = 0.0
     return_logprob: bool = False
+    client_stream_interval: int = 1
     input_len_step_percentage: float = 0.0
     result_filename: str = "result.jsonl"
     base_url: str = ""
@@ -60,6 +61,11 @@ class BenchArgs:
         )
         parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
         parser.add_argument("--return-logprob", action="store_true")
+        parser.add_argument(
+            "--client-stream-interval",
+            type=int,
+            default=BenchArgs.client_stream_interval,
+        )
         parser.add_argument(
             "--input-len-step-percentage",
             type=float,
@@ -120,6 +126,7 @@ def run_one_case(
     output_len: int,
     temperature: float,
     return_logprob: bool,
+    stream_interval: int,
     input_len_step_percentage: float,
     run_name: str,
     result_filename: str,
@@ -168,6 +175,7 @@ def run_one_case(
             "max_new_tokens": output_len,
             "ignore_eos": True,
             "json_schema": json_schema,
+            "stream_interval": stream_interval,
         },
         "return_logprob": return_logprob,
         "stream": True,
@@ -245,8 +253,9 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     else:
         proc, base_url = launch_server_process(server_args)

-    tokenizer_id = server_args.tokenizer_path or server_args.model_path
-    tokenizer = get_tokenizer(tokenizer_id)
+    server_info = requests.get(base_url + "/get_server_info")
+    tokenizer_path = server_info.json()["tokenizer_path"]
+    tokenizer = get_tokenizer(tokenizer_path)

     # warmup
     if not bench_args.skip_warmup:
@@ -258,6 +267,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             output_len=16,
             temperature=bench_args.temperature,
             return_logprob=bench_args.return_logprob,
+            stream_interval=bench_args.client_stream_interval,
             input_len_step_percentage=bench_args.input_len_step_percentage,
             run_name="",
             result_filename="",
@@ -280,6 +290,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                 ol,
                 temperature=bench_args.temperature,
                 return_logprob=bench_args.return_logprob,
+                stream_interval=bench_args.client_stream_interval,
                 input_len_step_percentage=bench_args.input_len_step_percentage,
                 run_name=bench_args.run_name,
                 result_filename=bench_args.result_filename,
@@ -301,6 +312,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                 ol,
                 temperature=bench_args.temperature,
                 return_logprob=bench_args.return_logprob,
+                stream_interval=bench_args.client_stream_interval,
                 input_len_step_percentage=bench_args.input_len_step_percentage,
                 run_name=bench_args.run_name,
                 result_filename=bench_args.result_filename,
...
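For reference, a hedged sketch (not part of the commit) of where the new client-side stream interval ends up in the request that run_one_case sends; the prompt text and server address are placeholders:

import requests

payload = {
    "text": "The capital of France is",  # placeholder prompt
    "sampling_params": {
        "temperature": 0.0,
        "max_new_tokens": 16,
        "ignore_eos": True,
        "stream_interval": 4,  # what --client-stream-interval feeds through
    },
    "return_logprob": False,
    "stream": True,
}
# Stream the response; chunks are flushed according to stream_interval.
with requests.post("http://localhost:30000/generate", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if line:
            print(line)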
@@ -1678,7 +1678,6 @@ def run_benchmark(args_: argparse.Namespace):
             if args.base_url
             else f"http://{args.host}:{args.port}/generate"
         )
-        args.apply_chat_template = True
     elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
         api_url = (
             f"{args.base_url}/v1/completions"
...
@@ -147,12 +147,11 @@ class InternLM2Config(PretrainedConfig):
             )
         if (
             rope_scaling_factor is None
-            or not isinstance(rope_scaling_factor, float)
-            or not isinstance(rope_scaling_factor, int)
+            or not isinstance(rope_scaling_factor, (float, int))
             or rope_scaling_factor < 1.0
         ):
             raise ValueError(
-                f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor}"
+                f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor=}, {type(rope_scaling_factor)=}"
             )
         if isinstance(rope_scaling_factor, int):
             rope_scaling_factor = float(rope_scaling_factor)
...
@@ -126,8 +126,6 @@ def set_global_state(global_state: _GlobalState):

 @asynccontextmanager
 async def lifespan(fast_api_app: FastAPI):
-    server_args: ServerArgs = fast_api_app.server_args
-
     # Initialize OpenAI serving handlers
     fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
         _global_state.tokenizer_manager, _global_state.template_manager
@@ -145,9 +143,12 @@ async def lifespan(fast_api_app: FastAPI):
         _global_state.tokenizer_manager
     )

+    server_args: ServerArgs = fast_api_app.server_args
     if server_args.warmups is not None:
         await execute_warmups(
-            server_args.warmups.split(","), _global_state.tokenizer_manager
+            server_args.disaggregation_mode,
+            server_args.warmups.split(","),
+            _global_state.tokenizer_manager,
         )
         logger.info("Warmup ended")
@@ -280,13 +281,17 @@ async def get_model_info():
         "model_path": _global_state.tokenizer_manager.model_path,
         "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
         "is_generation": _global_state.tokenizer_manager.is_generation,
+        "preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params,
     }
     return result

 @app.get("/get_server_info")
 async def get_server_info():
-    internal_states = await _global_state.tokenizer_manager.get_internal_state()
+    # Returns interna states per DP.
+    internal_states: List[Dict[Any, Any]] = (
+        await _global_state.tokenizer_manager.get_internal_state()
+    )
     return {
         **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
         **_global_state.scheduler_info,
@@ -300,6 +305,8 @@ async def get_load():
     return await _global_state.tokenizer_manager.get_load()

+# example usage:
+# curl -s -X POST http://localhost:30000/set_internal_state -H "Content-Type: application/json" -d '{"server_args": {"max_micro_batch_size": 8}}'
 @app.api_route("/set_internal_state", methods=["POST", "PUT"])
 async def set_internal_state(obj: SetInternalStateReq, request: Request):
     res = await _global_state.tokenizer_manager.set_internal_state(obj)
@@ -886,6 +893,15 @@ def launch_server(
         add_prometheus_middleware(app)
         enable_func_timer()

+    image_token_text = None
+    if (
+        tokenizer_manager.image_token_id is not None
+        and not server_args.skip_tokenizer_init
+    ):
+        image_token_text = tokenizer_manager.tokenizer.decode(
+            [tokenizer_manager.image_token_id]
+        )
+
     # Send a warmup request - we will create the thread launch it
     # in the lifespan after all other warmups have fired.
     warmup_thread = threading.Thread(
@@ -893,7 +909,7 @@ def launch_server(
         args=(
             server_args,
             pipe_finish_writer,
-            _global_state.tokenizer_manager.image_token_id,
+            image_token_text,
             launch_callback,
         ),
     )
@@ -1022,9 +1038,10 @@ def _wait_and_warmup(
             return

     # Debug print
-    # logger.info(f"{res.json()=}")
+    # logger.info(f"warmup request returns: {res.json()=}")

     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")
...
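As a hedged illustration of the endpoint changes above (not part of the diff), the new fields can be read back from a running server; only keys visible in the handlers are relied on, and the address is a placeholder:

import requests

base_url = "http://localhost:30000"  # placeholder address

model_info = requests.get(base_url + "/get_model_info").json()
print(model_info["preferred_sampling_params"])  # newly exposed field

server_info = requests.get(base_url + "/get_server_info").json()
print(server_info["tokenizer_path"])  # the same field the benchmark script now reads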
@@ -8,6 +8,7 @@ from sglang.srt.utils import is_hip

 _is_hip = is_hip()

 fused_softcap_autotune = triton.autotune(
     configs=[
         triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -189,21 +190,16 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
     assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
-    min_num_warps = 16 if _is_hip else 32
     if autotune:
         fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
             output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
         )
     else:
+        max_warps = 16 if _is_hip else 32
         config = {
             "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
             "num_warps": max(
-                min(
-                    triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
-                ),
-                4,
+                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
             ),
         }
@@ -260,13 +256,11 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
     else:
         output = torch.empty_like(x)
     bs, hidden_dim = x.shape
-    min_num_warps = 16 if _is_hip else 32
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
         ),
     }
@@ -331,6 +325,75 @@ class FusedDualResidualRMSNorm:
         return self.rmsnorm2.forward_native(residual), residual

+
+@triton.jit
+def experts_combine_kernel(
+    out_hidden_states,
+    moe_hidden_states,
+    mlp_hidden_states,
+    combine_k: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    start_index_mlp = pid * hidden_dim
+    start_index_rmoe = pid * hidden_dim * combine_k
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+    combine_k_offsets = tl.arange(0, combine_k)
+
+    moe_x = tl.load(
+        moe_hidden_states
+        + start_index_rmoe
+        + combine_k_offsets[:, None] * hidden_dim
+        + offsets[None, :],
+        mask=mask[None, :],
+        other=0.0,
+    )
+    moe_x = tl.sum(moe_x, axis=0)
+    mlp_x = tl.load(mlp_hidden_states + start_index_mlp + offsets, mask=mask, other=0.0)
+    combined_x = (moe_x + mlp_x) / 1.4142135623730951
+
+    tl.store(out_hidden_states + start_index_mlp + offsets, combined_x, mask=mask)
+
+
+def experts_combine_triton(moe_hidden_states, mlp_hidden_states, output_buffer=None):
+    assert moe_hidden_states.is_contiguous()
+    assert mlp_hidden_states.is_contiguous()
+
+    if len(moe_hidden_states.shape) == 2:
+        combine_k = 1  # pre-combined
+    else:
+        combine_k = moe_hidden_states.shape[1]
+
+    if output_buffer is None:
+        out_hidden_states = torch.empty_like(mlp_hidden_states)
+    else:
+        flat_output_buffer = output_buffer.view(mlp_hidden_states.dtype).reshape(-1)
+        assert flat_output_buffer.numel() >= mlp_hidden_states.numel()
+        out_hidden_states = flat_output_buffer[: mlp_hidden_states.numel()].reshape(
+            mlp_hidden_states.shape
+        )
+
+    bs, hidden_dim = mlp_hidden_states.shape
+
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 1024)), 8), 4
+        ),
+    }
+
+    experts_combine_kernel[(bs,)](
+        out_hidden_states,
+        moe_hidden_states,
+        mlp_hidden_states,
+        combine_k,
+        hidden_dim,
+        **config,
+    )
+    return out_hidden_states
+
+
 # gelu on first half of vector
 @triton.jit
 def gelu_and_mul_kernel(
@@ -400,10 +463,11 @@ def gelu_and_mul_triton(
         out_scales = scales
         static_scale = True

+    max_warps = 16 if _is_hip else 32
     config = {
         # 8 ele per thread (not tuned)
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
         ),
     }
...
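A hedged usage sketch of the new experts_combine_triton helper above; the shapes are illustrative and the import path is an assumption, since the file name is not shown in this diff:

import torch
from sglang.srt.layers.elementwise import experts_combine_triton  # assumed module path

bs, hidden_dim, combine_k = 8, 4096, 2
moe_out = torch.randn(bs, combine_k, hidden_dim, device="cuda", dtype=torch.bfloat16).contiguous()
mlp_out = torch.randn(bs, hidden_dim, device="cuda", dtype=torch.bfloat16).contiguous()

# Sums the combine_k expert outputs per token, adds the MLP branch,
# and scales by 1/sqrt(2), matching the kernel above.
combined = experts_combine_triton(moe_out, mlp_out)
assert combined.shape == (bs, hidden_dim)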
-from typing import Tuple
+from typing import Optional, Tuple

 import torch
 import triton
@@ -16,6 +16,8 @@ def fused_moe_router_kernel(
     moe_router_weight_ptr,  # input (num_experts, hidden_dim)
     topk_weights_ptr,  # output (bs, topk)
     topk_ids_ptr,  # output (bs, topk)
+    correction_bias_ptr,
+    is_correction_bias: tl.constexpr,
     num_experts: tl.constexpr,
     topk: tl.constexpr,
     moe_softcapping: tl.constexpr,
@@ -49,6 +51,11 @@ def fused_moe_router_kernel(
     bottom = exped + 1
     logits_softcapped = top / bottom * moe_softcapping

+    # Add bias after softcapping
+    if is_correction_bias:
+        bias = tl.load(correction_bias_ptr + tl.arange(0, num_experts))
+        logits_softcapped = logits_softcapped + bias
+
     # topk
     # assert 1 <= topk <= num_experts
@@ -109,6 +116,7 @@ def fused_moe_router_impl(
     router_weight: torch.Tensor,
     topk: int,
     moe_softcapping: float,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
     bs, hidden_dim = x.shape
@@ -117,23 +125,23 @@ def fused_moe_router_impl(
     # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+    is_correction_bias = correction_bias is not None

-    grid = lambda meta: (bs,)
-    min_num_warps = 16 if _is_hip else 32
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
         ),
     }

-    fused_moe_router_kernel[grid](
+    fused_moe_router_kernel[(bs,)](
         x,
         router_weight,
         topk_weights,
         topk_ids,
+        correction_bias,
+        is_correction_bias=is_correction_bias,
         num_experts=num_experts,
         topk=topk,
         moe_softcapping=moe_softcapping,
@@ -153,7 +161,7 @@ def fused_moe_router_large_bs_kernel(
     topk_ids_ptr,  # output (bs, topk)
     bs,
     num_experts: tl.constexpr,
-    topk: tl.constexpr,  # only support topk == 1
+    topk: tl.constexpr,  # only support topk <= 2
     moe_softcapping: tl.constexpr,
     moe_renormalize: tl.constexpr,  # not supported
     K: tl.constexpr,
@@ -204,25 +212,53 @@ def fused_moe_router_large_bs_kernel(
     logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping

     # 5. top1
-    cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
-    top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
+    arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
+    cond_top1 = arange_block_size_n < num_experts
+    top1 = tl.argmax(tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1)
     top1_v = tl.max(
-        tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
+        tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1, keep_dims=True
     )
-    invsumexp = 1.0 / tl.sum(
-        tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
+    top1_invsumexp = 1.0 / tl.sum(
+        tl.where(cond_top1, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
     )

-    # 6. store to output
-    offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-    topk_mask = offs_topk < bs
-    tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
+    # 6. store top1 to output
+    offs_top1 = pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)
+    top1_mask = offs_top1 < bs * topk
+    tl.store(topk_ids_ptr + offs_top1, top1, mask=top1_mask)
     tl.store(
-        topk_weights_ptr + offs_topk,
-        invsumexp,
-        mask=topk_mask,
+        topk_weights_ptr + offs_top1,
+        top1_invsumexp,
+        mask=top1_mask,
     )

+    # 7. handle topk == 2
+    if topk == 2:
+        cond_top2 = (arange_block_size_n < num_experts) and (
+            arange_block_size_n != top1[:, None]
+        )
+        top2 = tl.argmax(
+            tl.where(cond_top2, logits_softcapped, float("-inf")),
+            axis=1,
+            keep_dims=True,
+        )
+        top2_v = tl.sum(
+            logits_softcapped * (arange_block_size_n == top2), axis=1, keep_dims=True
+        )
+        top2_invsumexp = tl.exp(top2_v - top1_v) * top1_invsumexp[:, None]
+
+        # store top2
+        offs_top2 = (
+            pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)[:, None] + 1
+        )
+        top2_mask = offs_top2 < bs * topk
+        tl.store(topk_ids_ptr + offs_top2, top2, mask=top2_mask)
+        tl.store(
+            topk_weights_ptr + offs_top2,
+            top2_invsumexp,
+            mask=top2_mask,
+        )

 def fused_moe_router_large_bs_impl(
     x: torch.Tensor,
@@ -239,7 +275,7 @@ def fused_moe_router_large_bs_impl(
     assert num_experts <= BLOCK_SIZE_N
     assert hidden_dim % BLOCK_SIZE_K == 0
-    assert topk == 1
+    assert topk <= 2

     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
@@ -273,6 +309,7 @@ def fused_moe_router_shim(
     gating_output,
     topk,
     renormalize,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert not renormalize
     assert (
@@ -286,7 +323,7 @@ def fused_moe_router_shim(
     BLOCK_SIZE_K = 256
     if (
         bs >= 512
-        and topk == 1
+        and topk <= 2
         and num_experts <= BLOCK_SIZE_N
         and hidden_dim % BLOCK_SIZE_K == 0
     ):
@@ -305,6 +342,7 @@ def fused_moe_router_shim(
         router_weight=gating_output,
         topk=topk,
         moe_softcapping=moe_softcapping,
+        correction_bias=correction_bias,
     )
...
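A heavily hedged sketch of calling the router shim with the new correction_bias argument; the leading positional parameters and the import path are assumptions, since this hunk only shows the tail of the signature:

import torch
from sglang.srt.layers.moe.router import fused_moe_router_shim  # assumed module path

bs, hidden_dim, num_experts, topk = 4, 4096, 8, 2
moe_softcapping = 30.0  # placeholder softcapping constant
hidden_states = torch.randn(bs, hidden_dim, device="cuda", dtype=torch.bfloat16)
router_weight = torch.randn(num_experts, hidden_dim, device="cuda", dtype=torch.bfloat16)
correction_bias = torch.zeros(num_experts, device="cuda", dtype=torch.float32)

# Assumed call shape: (moe_softcapping, hidden_states, gating weight, topk, renormalize);
# renormalize must be False per the assert in the shim, and the return pair is assumed
# to be the top-k weights and expert ids created in fused_moe_router_impl above.
topk_weights, topk_ids = fused_moe_router_shim(
    moe_softcapping,
    hidden_states,
    router_weight,
    topk,
    False,
    correction_bias=correction_bias,
)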
@@ -28,7 +28,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--url", type=str, default="http://localhost:30000")
     parser.add_argument("--log-requests", action="store_true")
-    parser.add_argument("--log-requests-level", type=int, default=2)
+    parser.add_argument("--log-requests-level", type=int, default=3)
     parser.add_argument(
         "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
     )
...
@@ -516,9 +516,6 @@ class EmbeddingReqInput:
     # For cross-encoder requests
     is_cross_encoder_request: bool = False

-    def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
-
     def normalize_batch_and_arguments(self):
         # at least one of text, input_ids, or image should be provided
         if self.text is None and self.input_ids is None and self.image_data is None:
@@ -572,6 +569,9 @@ class EmbeddingReqInput:
             self.rid = uuid.uuid4().hex
         return self.rid

+    def contains_mm_input(self) -> bool:
+        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+
     def __getitem__(self, i):
         if self.is_cross_encoder_request:
             return EmbeddingReqInput(
...
@@ -2,12 +2,15 @@
 Multi-modality utils
 """

+import hashlib
 from abc import abstractmethod
 from typing import Callable, List, Optional, Tuple

+import numpy as np
 import torch
 from torch import nn

+from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.managers.schedule_batch import (
     Modality,
     MultimodalDataItem,
@@ -678,3 +681,52 @@ def get_multimodal_data_bounds(
     # Convert valid pairs to tensor
     valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
     return valid_pairs_tensor
+
+
+def data_hash(data) -> int:
+    hash_bytes = hashlib.sha256(data).digest()[:8]
+    return int.from_bytes(hash_bytes, byteorder="big", signed=False)
+
+
+def tensor_hash(tensor_list) -> int:
+    """
+    hash a tensor or a tensor list
+    """
+    tensor = tensor_list
+    if isinstance(tensor_list, list):
+        tensor_list = flatten_nested_list(tensor_list)
+        tensor_list = [
+            x.flatten() if isinstance(x, torch.Tensor) else x for x in tensor_list
+        ]
+        tensor = torch.concat(tensor_list)
+    if tensor.is_cuda:
+        return gpu_tensor_hash(tensor)
+    tensor = tensor.detach().contiguous()
+
+    if tensor.dtype == torch.bfloat16:
+        # memoryview() doesn't support PyTorch's BFloat16 dtype
+        tensor = tensor.float()
+
+    assert isinstance(tensor, torch.Tensor)
+    if tensor.is_cuda:
+        # TODO: improve this
+        tensor_cpu = tensor.cpu()
+    else:
+        tensor_cpu = tensor
+
+    mv = memoryview(tensor_cpu.numpy())
+    return data_hash(mv.tobytes())
+
+
+def hash_feature(f):
+    if isinstance(f, list):
+        if isinstance(f[0], torch.Tensor):
+            return tensor_hash(f)
+        return data_hash(tuple(flatten_nested_list(f)))
+    elif isinstance(f, np.ndarray):
+        arr = np.ascontiguousarray(f)
+        arr_bytes = arr.tobytes()
+        return data_hash(arr_bytes)
+    elif isinstance(f, torch.Tensor):
+        return tensor_hash([f])
+    return data_hash(f)
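A short hedged example of the hashing helpers now living in mm_utils; the import matches the one added to schedule_batch.py further down:

import numpy as np
import torch
from sglang.srt.managers.mm_utils import hash_feature

features = torch.randn(3, 8)               # e.g. precomputed image features
print(hash_feature(features))              # 64-bit hash from the first 8 bytes of sha256
print(hash_feature(np.arange(16)))         # ndarray path hashes the raw bytes
print(hash_feature([features, features]))  # a list of tensors is flattened and concatenated first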
@@ -3,7 +3,6 @@ import importlib
 import inspect
 import logging
 import pkgutil
-from functools import lru_cache

 from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
 from sglang.srt.server_args import ServerArgs
...
@@ -33,7 +33,6 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i

 import copy
 import dataclasses
-import hashlib
 import logging
 import threading
 from enum import Enum, auto
@@ -53,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
-from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
@@ -96,8 +94,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "max_micro_batch_size",
     "disable_shared_experts_fusion",
     "sampling_backend",
-    "speculative_accept_threshold_acc",
     "speculative_accept_threshold_single",
+    "speculative_accept_threshold_acc",
     "torchao_config",
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
@@ -180,7 +178,9 @@ class Modality(Enum):

 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-    A single multimodal data, from a single image/video/audio or others.
+    One MultimodalDataItem contains all inputs for one modality.
+    For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
+    One for images and one for audio.

     We put the common fields first and the model-specific fields last.
     """
@@ -232,53 +232,7 @@ class MultimodalDataItem:
         """
         Set the pad value after first hashing the data
         """
-
-        def data_hash(data) -> int:
-            hash_bytes = hashlib.sha256(data).digest()[:8]
-            return int.from_bytes(hash_bytes, byteorder="big", signed=False)
-
-        def tensor_hash(tensor_list) -> int:
-            """
-            hash a tensor or a tensor list
-            """
-            tensor = tensor_list
-            if isinstance(tensor_list, list):
-                tensor_list = flatten_nested_list(tensor_list)
-                tensor_list = [
-                    x.flatten() if isinstance(x, torch.Tensor) else x
-                    for x in tensor_list
-                ]
-                tensor = torch.concat(tensor_list)
-            if tensor.is_cuda:
-                return gpu_tensor_hash(tensor)
-            tensor = tensor.detach().contiguous()
-
-            if tensor.dtype == torch.bfloat16:
-                # memoryview() doesn't support PyTorch's BFloat16 dtype
-                tensor = tensor.float()
-
-            assert isinstance(tensor, torch.Tensor)
-            if tensor.is_cuda:
-                # TODO: improve this
-                tensor_cpu = tensor.cpu()
-            else:
-                tensor_cpu = tensor
-
-            mv = memoryview(tensor_cpu.numpy())
-            return data_hash(mv.tobytes())
-
-        def hash_feature(f):
-            if isinstance(f, list):
-                if isinstance(f[0], torch.Tensor):
-                    return tensor_hash(f)
-                return data_hash(tuple(flatten_nested_list(f)))
-            elif isinstance(f, np.ndarray):
-                arr = np.ascontiguousarray(f)
-                arr_bytes = arr.tobytes()
-                return data_hash(arr_bytes)
-            elif isinstance(f, torch.Tensor):
-                return tensor_hash([f])
-            return data_hash(f)
+        from sglang.srt.managers.mm_utils import hash_feature

         if self.precomputed_features is not None:
             self.hash = hash_feature(self.precomputed_features)
...
@@ -418,14 +418,16 @@ class Scheduler(
         self.last_decode_stats_tic = time.perf_counter()
         self.last_prefill_stats_tic = time.perf_counter()
         self.return_health_check_ct = 0
+        self.num_retracted_reqs: int = 0
+        self.num_paused_reqs: int = 0
+        self.kv_transfer_speed_gb_s: float = 0.0
+        self.kv_transfer_latency_ms: float = 0.0
+        self.sessions: Dict[str, Session] = {}
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
             self.current_stream.synchronize = lambda: None  # No-op for CPU
         self.forward_sleep_time = None

-        # Init session info
-        self.sessions: Dict[str, Session] = {}
-
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
         if self.chunked_prefill_size <= 0:  # -1 means disable
@@ -473,26 +475,12 @@ class Scheduler(
             t = threading.Thread(target=self.watchdog_thread, daemon=True)
             t.start()
             self.parent_process = psutil.Process().parent()

+        # Init memory saver, profiler and metric stats
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
-
-        # Init profiler
-        self.torch_profiler = None
-        self.torch_profiler_output_dir: Optional[str] = None
-        self.profiler_activities: Optional[List[str]] = None
-        self.profile_id: Optional[str] = None
-        self.profiler_target_forward_ct: Optional[int] = None
-        self.profiler_target_prefill_ct: Optional[int] = None
-        self.profiler_target_decode_ct: Optional[int] = None
-        self.profiler_prefill_ct: Optional[int] = None
-        self.profiler_decode_ct: Optional[int] = None
-        self.profile_by_stage: bool = False
-        self.profile_steps: Optional[int] = None
-        self.profile_in_progress: bool = False
-        self.rpd_profiler = None
-
-        # Init metrics stats
+        self.init_profier()
         self.init_metrics()
         self.init_kv_events(server_args.kv_events_config)
@@ -526,6 +514,7 @@ class Scheduler(
             ]
         )

+        # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
@@ -624,6 +613,21 @@ class Scheduler(
             )
         )

+    def init_profier(self):
+        self.torch_profiler = None
+        self.torch_profiler_output_dir: Optional[str] = None
+        self.profiler_activities: Optional[List[str]] = None
+        self.profile_id: Optional[str] = None
+        self.profiler_target_forward_ct: Optional[int] = None
+        self.profiler_target_prefill_ct: Optional[int] = None
+        self.profiler_target_decode_ct: Optional[int] = None
+        self.profiler_prefill_ct: Optional[int] = None
+        self.profiler_decode_ct: Optional[int] = None
+        self.profile_by_stage: bool = False
+        self.profile_steps: Optional[int] = None
+        self.profile_in_progress: bool = False
+        self.rpd_profiler = None
+
     def init_metrics(self):
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
@@ -2107,6 +2111,18 @@ class Scheduler(
     def get_internal_state(self, recv_req: GetInternalStateReq):
         ret = dict(global_server_args_dict)
         ret["last_gen_throughput"] = self.last_gen_throughput
+        ret["memory_usage"] = {
+            "weight": round(
+                self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
+            ),
+            "kvcache": round(
+                self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
+            ),
+            "cuda_graph": round(
+                self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
+            ),
+            "token_capacity": int(self.max_total_num_tokens),
+        }
         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
             ret["avg_spec_accept_length"] = (
                 self.cum_spec_accept_length / self.cum_spec_accept_count
...
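For illustration, the new memory report can be read from a running server; it is assumed (hedged) that the per-DP internal states returned by get_internal_state are exposed under an internal_states key of /get_server_info, as in the handler shown earlier:

import requests

info = requests.get("http://localhost:30000/get_server_info").json()
# weight / kvcache / cuda_graph are in GB; token_capacity is a token count.
print(info["internal_states"][0]["memory_usage"])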
@@ -521,11 +521,17 @@ class SchedulerOutputProcessorMixin:
                     stream_interval = (
                         req.sampling_params.stream_interval or self.stream_interval
                     )
-                    should_output = len(req.output_ids) % stream_interval == 0
+                    should_output = (
+                        len(req.output_ids) % stream_interval == 1
+                        if not self.model_config.is_multimodal_gen
+                        and stream_interval > 1
+                        else len(req.output_ids) % stream_interval == 0
+                    )
                 else:
                     should_output = (
                         len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
-                        and not self.model_config.is_multimodal_gen
+                        if not self.model_config.is_multimodal_gen
+                        else False
                     )

                 if should_output:
...
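To make the streaming change concrete, a small self-contained sketch of the new cadence: with stream_interval = 4, chunks are now flushed at output lengths 1, 5, 9, ... instead of 4, 8, 12, ..., so the first token reaches the client immediately.

stream_interval = 4
is_multimodal_gen = False

def should_output(num_output_ids: int) -> bool:
    # Mirrors the non-multimodal, interval > 1 branch above.
    if not is_multimodal_gen and stream_interval > 1:
        return num_output_ids % stream_interval == 1
    return num_output_ids % stream_interval == 0

print([n for n in range(1, 13) if should_output(n)])  # [1, 5, 9]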
@@ -111,11 +111,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
-from sglang.srt.managers.multimodal_processor import (
-    get_dummy_processor,
-    get_mm_processor,
-    import_processors,
-)
+from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -187,6 +183,8 @@ class TokenizerManager:
             if server_args.preferred_sampling_params
             else None
         )
+        self.crash_dump_folder = server_args.crash_dump_folder
+        self.crash_dump_performed = False  # Flag to ensure dump is only called once

         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -251,10 +249,11 @@ class TokenizerManager:
         self.dump_requests_folder = ""  # By default do not dump
         self.dump_requests_threshold = 1000
         self.dump_request_list: List[Tuple] = []
+        self.crash_dump_request_list: deque[Tuple] = deque()
         self.log_request_metadata = self.get_log_request_metadata()
+        self.asyncio_tasks = set()
         self.session_futures = {}  # session_id -> asyncio event
         self.max_req_input_len = None
-        self.asyncio_tasks = set()

         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
@@ -266,14 +265,14 @@ class TokenizerManager:
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
-        self.transfer_backend = TransferBackend(
+        self.disaggregation_transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend
         )
         # Start kv boostrap server on prefill
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             # only start bootstrap server on prefill tm
             kv_bootstrap_server_class = get_kv_class(
-                self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
+                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
             )
             self.bootstrap_server = kv_bootstrap_server_class(
                 self.server_args.disaggregation_bootstrap_port
@@ -324,7 +323,6 @@ class TokenizerManager:
         self.profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
-        self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -484,7 +482,7 @@ class TokenizerManager:
             token_type_ids = encoded.get("token_type_ids", [None])[0]

         if self.mm_processor and obj.contains_mm_input():
-            image_inputs = await self.mm_processor.process_mm_data_async(
+            image_inputs: Dict = await self.mm_processor.process_mm_data_async(
                 image_data=obj.image_data,
                 input_text=input_text or input_ids,
                 request_obj=obj,
@@ -547,6 +545,14 @@ class TokenizerManager:
                 "Please set `--enable-custom-logits-processor` to enable this feature."
             )

+    def _validate_input_ids_in_vocab(
+        self, input_ids: List[int], vocab_size: int
+    ) -> None:
+        if any(id >= vocab_size for id in input_ids):
+            raise ValueError(
+                f"The input_ids {input_ids} contains values greater than the vocab size ({vocab_size})."
+            )
+
     def _create_tokenized_object(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -1096,12 +1102,36 @@ class TokenizerManager:
                     "image_data",
                     "audio_data",
                     "lora_path",
+                    "sampling_params",
                 ]
             )
-            out_skip_names = set(["text", "output_ids", "embedding"])
+            out_skip_names = set(
+                [
+                    "text",
+                    "output_ids",
+                ]
+            )
         elif self.log_requests_level == 1:
-            max_length = 2048
+            max_length = 1 << 30
+            skip_names = set(
+                [
+                    "text",
+                    "input_ids",
+                    "input_embeds",
+                    "image_data",
+                    "audio_data",
+                    "lora_path",
+                ]
+            )
+            out_skip_names = set(
+                [
+                    "text",
+                    "output_ids",
+                ]
+            )
         elif self.log_requests_level == 2:
+            max_length = 2048
+        elif self.log_requests_level == 3:
             max_length = 1 << 30
         else:
             raise ValueError(
@@ -1118,6 +1148,8 @@ class TokenizerManager:
             self.dump_requests_folder = obj.dump_requests_folder
         if obj.dump_requests_threshold is not None:
             self.dump_requests_threshold = obj.dump_requests_threshold
+        if obj.crash_dump_folder is not None:
+            self.crash_dump_folder = obj.crash_dump_folder

         logging.info(f"Config logging: {obj=}")
         self.log_request_metadata = self.get_log_request_metadata()
@@ -1166,6 +1198,52 @@ class TokenizerManager:
             loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
         )

+    def dump_requests_before_crash(self):
+        if self.crash_dump_performed:
+            logger.info(
+                "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
+            )
+            return
+        logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
+        self.crash_dump_performed = True
+        if not self.crash_dump_folder:
+            return
+
+        data_to_dump = []
+        if self.crash_dump_request_list:
+            data_to_dump.extend(self.crash_dump_request_list)
+
+        # Add unfinished requests from rid_to_state
+        unfinished_requests = []
+        for rid, state in self.rid_to_state.items():
+            if not state.finished:
+                unfinished_requests.append(
+                    (state.obj, {}, state.created_time, time.time())
+                )
+        if unfinished_requests:
+            data_to_dump.extend(unfinished_requests)
+
+        if not data_to_dump:
+            return
+
+        filename = os.path.join(
+            self.crash_dump_folder,
+            os.getenv("HOSTNAME", None),
+            f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl',
+        )
+        os.makedirs(os.path.dirname(filename), exist_ok=True)
+        # Include server_args in the dump
+        data_to_dump_with_server_args = {
+            "server_args": self.server_args,
+            "requests": data_to_dump,
+        }
+        with open(filename, "wb") as f:
+            pickle.dump(data_to_dump_with_server_args, f)
+
+        logger.error(
+            f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
+        )
+
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
             await asyncio.sleep(5)
@@ -1175,11 +1253,12 @@ class TokenizerManager:
             remain_num_req = len(self.rid_to_state)

             if self.health_check_failed:
-                # if health check failed, exit immediately
+                # if health check failed, we should exit immediately
                 logger.error(
                     "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
                     remain_num_req,
                 )
+                self.dump_requests_before_crash()
                 break

             elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
@@ -1196,6 +1275,7 @@ class TokenizerManager:
             if remain_num_req > 0:
                 await asyncio.sleep(5)
             else:
+                self.dump_requests_before_crash()
                 break

         kill_process_tree(os.getpid(), include_parent=True)
@@ -1273,16 +1353,7 @@ class TokenizerManager:
                     "meta_info": meta_info,
                 }
             elif isinstance(recv_obj, BatchMultimodalOut):
-                if isinstance(recv_obj.outputs[i], str):
-                    out_dict = {
-                        "text": recv_obj.outputs[i],
-                        "meta_info": meta_info,
-                    }
-                else:
-                    out_dict = {
-                        "outputs": json.dumps(recv_obj.outputs[i]),
-                        "meta_info": meta_info,
-                    }
+                raise NotImplementedError("BatchMultimodalOut not implemented")
             else:
                 assert isinstance(recv_obj, BatchEmbeddingOut)
                 out_dict = {
@@ -1306,6 +1377,8 @@ class TokenizerManager:
                 self.collect_metrics(state, recv_obj, i)
             if self.dump_requests_folder and state.finished and state.obj.log_metrics:
                 self.dump_requests(state, out_dict)
+            if self.crash_dump_folder and state.finished and state.obj.log_metrics:
+                self.record_request_for_crash_dump(state, out_dict)

     def convert_logprob_style(
         self,
@@ -1317,6 +1390,9 @@ class TokenizerManager:
         recv_obj: BatchStrOut,
         recv_obj_index: int,
     ):
+        if recv_obj.input_token_logprobs_val is None:
+            return
+
         if len(recv_obj.input_token_logprobs_val) > 0:
             state.input_token_logprobs_val.extend(
                 recv_obj.input_token_logprobs_val[recv_obj_index]
@@ -1436,7 +1512,10 @@ class TokenizerManager:
                 else 0
             )

-            if state.first_token_time == 0.0:
+            if (
+                state.first_token_time == 0.0
+                and self.disaggregation_mode != DisaggregationMode.PREFILL
+            ):
                 state.first_token_time = state.last_time = time.time()
                 state.last_completion_tokens = completion_tokens
                 self.metrics_collector.observe_time_to_first_token(
@@ -1484,14 +1563,31 @@ class TokenizerManager:
         to_dump = self.dump_request_list
         self.dump_request_list = []

+        to_dump_with_server_args = {
+            "server_args": self.server_args,
+            "requests": to_dump,
+        }
+
         def background_task():
             os.makedirs(self.dump_requests_folder, exist_ok=True)
             with open(filename, "wb") as f:
-                pickle.dump(to_dump, f)
+                pickle.dump(to_dump_with_server_args, f)

         # Schedule the task to run in the background without awaiting it
         asyncio.create_task(asyncio.to_thread(background_task))

+    def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
+        current_time = time.time()
+        self.crash_dump_request_list.append(
+            (state.obj, out_dict, state.created_time, current_time)
+        )
+        # Remove requests older than 5 minutes based on finish time
+        while (
+            self.crash_dump_request_list
+            and current_time - self.crash_dump_request_list[0][3] >= 300
+        ):
+            self.crash_dump_request_list.popleft()
+
     def _handle_abort_req(self, recv_obj):
         self.rid_to_state.pop(recv_obj.rid, None)
@@ -1614,6 +1710,8 @@ async def print_exception_wrapper(func):
     except Exception:
         traceback = get_exception_traceback()
         logger.error(f"TokenizerManager hit an exception: {traceback}")
+        if hasattr(func, "__self__") and isinstance(func.__self__, TokenizerManager):
+            func.__self__.dump_requests_before_crash()
         kill_process_tree(os.getpid(), include_parent=True)
         sys.exit(1)
@@ -1632,6 +1730,7 @@ class SignalHandler:
         logger.error(
             "Received sigquit from a child process. It usually means the child failed."
         )
+        self.tokenizer_manager.dump_requests_before_crash()
         kill_process_tree(os.getpid())
...
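A hedged sketch of inspecting a crash dump written by dump_requests_before_crash above; the folder layout follows the os.path.join call (<crash_dump_folder>/<HOSTNAME>/crash_dump_<timestamp>.pkl) and the path is a placeholder:

import glob
import pickle

dump_files = sorted(glob.glob("/tmp/sglang_crash_dump/*/crash_dump_*.pkl"))
with open(dump_files[-1], "rb") as f:
    dump = pickle.load(f)  # requires sglang installed, since ServerArgs is pickled

print(dump["server_args"])
for obj, out_dict, created_time, finished_time in dump["requests"]:
    print(type(obj).__name__, round(finished_time - created_time, 3))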
@@ -123,6 +123,7 @@ class KVCache(abc.ABC):
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
+        self.mem_usage = 0

         # used for chunked cpu-offloading
         self.cpu_offloading_chunk_size = 8192
@@ -219,6 +220,7 @@ class MHATokenToKVPool(KVCache):
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
         )
+        self.mem_usage = (k_size + v_size) / GB

     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
@@ -695,6 +697,7 @@ class MLATokenToKVPool(KVCache):
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
         )
+        self.mem_usage = kv_size / GB

     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
...
@@ -604,12 +604,13 @@ class ModelRunner:
         self.dtype = self.model_config.dtype
         after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
+        self.weight_load_mem_usage = before_avail_memory - after_avail_memory
         logger.info(
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
             f"avail mem={after_avail_memory:.2f} GB, "
-            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
+            f"mem usage={self.weight_load_mem_usage:.2f} GB."
         )

         # Handle the case where some ranks do not finish loading.
@@ -1250,6 +1251,7 @@ class ModelRunner:
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
+        self.cuda_graph_mem_usage = 0

         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
@@ -1265,9 +1267,10 @@ class ModelRunner:
         )
         self.cuda_graph_runner = CudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
+        self.cuda_graph_mem_usage = before_mem - after_mem
         logger.info(
             f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
-            f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
+            f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
         )

     def apply_torch_tp(self):
...
@@ -534,6 +534,12 @@ class DummyModelLoader(BaseModelLoader):
         model_config: ModelConfig,
         device_config: DeviceConfig,
     ) -> nn.Module:
+        if get_bool_env_var("SGL_CPU_QUANTIZATION"):
+            return load_model_with_cpu_quantization(
+                self, model_config=model_config, device_config=device_config
+            )
+
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(
@@ -1464,6 +1470,38 @@ class RemoteModelLoader(BaseModelLoader):
         return model.eval()

+
+def load_model_with_cpu_quantization(
+    self,
+    *,
+    model_config: ModelConfig,
+    device_config: DeviceConfig,
+) -> nn.Module:
+    target_device = torch.device(device_config.device)
+    with set_default_torch_dtype(model_config.dtype):
+        model = _initialize_model(
+            model_config,
+            self.load_config,
+        )
+
+        if not isinstance(self, DummyModelLoader):
+            model.load_weights(self._get_all_weights(model_config, model))
+
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                # When quant methods need to process weights after loading
+                # (for repacking, quantizing, etc), they expect parameters
+                # to be on the global target device. This scope is for the
+                # case where cpu offloading is used, where we will move the
+                # parameters onto device for processing and back off after.
+                with device_loading_context(module, target_device):
+                    quant_method.process_weights_after_loading(module)
+
+    model.to(target_device)
+    return model.eval()
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
...
@@ -13,7 +13,7 @@
 # ==============================================================================
 """Inference-only Mistral model."""
-from typing import List, Union
+from typing import List

 import torch
 from transformers.models.mistral3.modeling_mistral3 import Mistral3MultiModalProjector
...
@@ -99,6 +99,7 @@ class ServerArgs:
     log_level_http: Optional[str] = None
     log_requests: bool = False
     log_requests_level: int = 0
+    crash_dump_folder: Optional[str] = None
     show_time_cost: bool = False
     enable_metrics: bool = False
     bucket_time_to_first_token: Optional[List[float]] = None
@@ -927,8 +928,14 @@ class ServerArgs:
             "--log-requests-level",
             type=int,
             default=0,
-            help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
-            choices=[0, 1, 2],
+            help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
+            choices=[0, 1, 2, 3],
+        )
+        parser.add_argument(
+            "--crash-dump-folder",
+            type=str,
+            default=ServerArgs.crash_dump_folder,
+            help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
         )
         parser.add_argument(
             "--show-time-cost",
...
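A hedged sketch of enabling the new options programmatically; the ServerArgs fields mirror the flags added above, and the model path and folder are placeholders:

from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    log_requests=True,
    log_requests_level=3,                           # most verbose: log every input/output
    crash_dump_folder="/tmp/sglang_crash_dump",     # dump recent requests on crash
)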