Unverified Commit e074d84e, authored by Lianmin Zheng, committed by GitHub

[Minor] more code cleanup (#4077)

parent 4725e3f6
@@ -40,6 +40,7 @@ runtime_common = [
     "transformers==4.48.3",
     "llguidance>=0.6.15"
 ]
 srt = [
     "sglang[runtime_common]",
     "sgl-kernel==0.0.3.post6",
...
@@ -39,6 +39,7 @@ from transformers import (
 )
 AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
+ASSISTANT_SUFFIX = "Assistant:"
 global args
@@ -635,7 +636,11 @@ def sample_sharegpt_requests(
         # Tokenize the prompts and completions.
         prompt = dataset[i][0]
         if prompt_suffix:
-            prompt = prompt
+            prompt = (
+                remove_suffix(prompt, ASSISTANT_SUFFIX)
+                + prompt_suffix
+                + ASSISTANT_SUFFIX
+            )
         if apply_chat_template:
            prompt = tokenizer.apply_chat_template(
...
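For context on the sample_sharegpt_requests change above: the new branch strips a trailing "Assistant:" marker, appends the user-provided prompt_suffix, and then re-appends the marker, so the suffix lands inside the conversation rather than after it. Below is a minimal sketch of the remove_suffix helper it calls, assuming it behaves like Python's str.removesuffix; the actual helper is defined elsewhere in the benchmark script and is not part of this hunk.

```python
def remove_suffix(text: str, suffix: str) -> str:
    # Drop `suffix` from the end of `text` if present; otherwise return `text` unchanged.
    # Equivalent to str.removesuffix (Python 3.9+), spelled out for older interpreters.
    return text[: -len(suffix)] if suffix and text.endswith(suffix) else text
```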
 import json
 import logging
 import re
+from abc import ABC, abstractmethod
 from json import JSONDecodeError, JSONDecoder
 from typing import Any, Dict, List, Optional, Tuple
...
+import triton
+import triton.language as tl
+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    kv_indices_ptr,
+    req_to_token_ptr_stride: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < kv_end - kv_start
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + kv_start
+            + offset,
+            mask=mask,
+        )
+        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
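The kernel above copies, for each request in the batch, that request's KV-cache slot indices from the req_to_token table into a flat kv_indices array, starting at the offset given by kv_indptr. A rough host-side launch sketch follows; the tensor names, shapes, and values (req_to_token, req_pool_indices, seq_lens) are illustrative placeholders and not part of this commit.

```python
import torch

# Hypothetical inputs; dtypes follow the kernel's expectations (kernel defined in the diff above).
req_to_token = torch.zeros((64, 4096), dtype=torch.int32, device="cuda")
req_pool_indices = torch.tensor([3, 7], dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([128, 256], dtype=torch.int32, device="cuda")
bs = seq_lens.numel()

# Exclusive prefix sum gives each request's write offset into kv_indices.
kv_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32, device="cuda")

# One Triton program per request.
create_flashinfer_kv_indices_triton[(bs,)](
    req_to_token,
    req_pool_indices,
    seq_lens,
    kv_indptr,
    None,  # kv_start_idx: no per-request start offset in this sketch
    kv_indices,
    req_to_token.stride(0),
)
```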
@@ -33,6 +33,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
@@ -152,6 +153,13 @@ class LogitsMetadata:
             token_ids_logprobs=forward_batch.token_ids_logprobs,
             extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
             padded_static_len=forward_batch.padded_static_len,
+            global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
+            dp_local_start_pos=forward_batch.dp_local_start_pos,
+            dp_local_num_tokens=forward_batch.dp_local_num_tokens,
+            gathered_buffer=forward_batch.gathered_buffer,
+            forward_batch_gathered_buffer=forward_batch.gathered_buffer,
+            global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
+            global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
         )
     def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
@@ -204,8 +212,6 @@ class LogitsProcessor(nn.Module):
         ):
             self.final_logit_softcapping = None
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
         self.debug_tensor_dump_output_folder = global_server_args_dict.get(
             "debug_tensor_dump_output_folder", None
         )
...
@@ -212,6 +212,7 @@ class DetokenizerManager:
                 rids=recv_obj.rids,
                 finished_reasons=recv_obj.finished_reasons,
                 output_strs=output_strs,
+                output_ids=None,
                 prompt_tokens=recv_obj.prompt_tokens,
                 completion_tokens=recv_obj.completion_tokens,
                 cached_tokens=recv_obj.cached_tokens,
...
@@ -414,6 +414,12 @@ class BatchTokenIDOut:
 class BatchMultimodalDecodeReq:
     # The request id
     rids: List[str]
+    finished_reasons: List[BaseFinishReason]
+    # Token counts
+    prompt_tokens: List[int]
+    completion_tokens: List[int]
+    cached_tokens: List[int]
 @dataclass
@@ -424,6 +430,8 @@ class BatchStrOut:
     finished_reasons: List[dict]
     # The output decoded strings
     output_strs: List[str]
+    # The token ids
+    output_ids: Optional[List[int]]
     # Token counts
     prompt_tokens: List[int]
@@ -453,6 +461,15 @@ class BatchStrOut:
 class BatchMultimodalOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[dict]
+    # The outputs
+    outputs: List[List[Dict]]
+    # Token counts
+    prompt_tokens: List[int]
+    completion_tokens: List[int]
+    cached_tokens: List[int]
 @dataclass
...
@@ -1141,7 +1141,7 @@ async def print_exception_wrapper(func):
 class SignalHandler:
-    def __init__(self, tokenizer_manager):
+    def __init__(self, tokenizer_manager: TokenizerManager):
         self.tokenizer_manager = tokenizer_manager
     def signal_handler(self, signum=None, frame=None):
...
@@ -192,7 +192,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
-            f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
+            f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
         )
     def _create_buffers(self):
...
@@ -238,6 +238,9 @@ class CudaGraphRunner:
                 ),
                 dtype=self.model_runner.dtype,
             )
+            self.global_num_tokens_gpu = torch.zeros(
+                (self.dp_size,), dtype=torch.int32
+            )
         # Capture
         try:
@@ -266,9 +269,9 @@ class CudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention:
-            min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
-                forward_batch.global_num_tokens
-            )
+            min_num_tokens, max_num_tokens = min(
+                forward_batch.global_num_tokens_cpu
+            ), max(forward_batch.global_num_tokens_cpu)
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
                 (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
                 if self.disable_padding
@@ -360,7 +363,7 @@ class CudaGraphRunner:
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
-            global_num_tokens=global_num_tokens,
+            global_num_tokens_cpu=global_num_tokens,
             gathered_buffer=gathered_buffer,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
@@ -430,7 +433,7 @@ class CudaGraphRunner:
         # Pad
         if self.enable_dp_attention:
             index = bisect.bisect_left(
-                self.capture_bs, max(forward_batch.global_num_tokens)
+                self.capture_bs, max(forward_batch.global_num_tokens_cpu)
             )
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
...
@@ -190,7 +190,16 @@ class ForwardBatch:
     attn_backend: AttentionBackend = None
     # For DP attention
-    global_num_tokens: Optional[List[int]] = None
+    global_num_tokens_cpu: Optional[List[int]] = None
+    global_num_tokens_gpu: Optional[torch.Tensor] = None
+    # Has to be None when cuda graph is captured.
+    global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
+    global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
+    # For extend, the local start pos and num tokens are different in the logits processor.
+    # These are computed in get_dp_local_info
+    # and recomputed in LogitsMetadata.from_forward_batch.
+    dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
+    dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False
@@ -234,7 +243,6 @@ class ForwardBatch:
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
-            global_num_tokens=batch.global_num_tokens,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
@@ -248,8 +256,9 @@ class ForwardBatch:
             extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
         )
-        if ret.global_num_tokens is not None:
-            max_len = max(ret.global_num_tokens)
+        if batch.global_num_tokens is not None:
+            ret.global_num_tokens_cpu = batch.global_num_tokens
+            max_len = max(ret.global_num_tokens_cpu)
             ret.gathered_buffer = torch.zeros(
                 (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
...
@@ -13,7 +13,6 @@
 # ==============================================================================
 """ModelRunner runs the forward passes of the models."""
-import collections
 import datetime
 import gc
 import json
@@ -269,6 +268,7 @@ class ModelRunner:
         elif self.device == "cpu":
             backend = "gloo"
+        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if not self.server_args.enable_p2p_check:
             monkey_patch_p2p_access_check()
@@ -299,20 +299,24 @@ class ModelRunner:
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
+        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()
         # Check memory for tensor parallelism
         if self.tp_size > 1:
-            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )
+        logger.info(
+            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
+        )
         return min_per_gpu_memory
     def load_model(self):
+        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
@@ -382,11 +386,13 @@ class ModelRunner:
         )
         self.dtype = self.model_config.dtype
+        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
-            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"avail mem={after_avail_memory:.2f} GB, "
+            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
         )
     def update_weights_from_disk(
@@ -785,12 +791,15 @@ class ModelRunner:
             return
         tic = time.time()
+        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
         self.cuda_graph_runner = CudaGraphRunner(self)
+        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
+            f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )
     def apply_torch_tp(self):
@@ -806,8 +815,12 @@ class ModelRunner:
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
-    def forward_extend(self, forward_batch: ForwardBatch):
-        self.attn_backend.init_forward_metadata(forward_batch)
+    def forward_extend(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ):
+        if not skip_attn_backend_init:
+            self.attn_backend.init_forward_metadata(forward_batch)
         if self.is_generation:
             if forward_batch.input_embeds is None:
                 return self.model.forward(
...
@@ -818,8 +818,8 @@ def all_gather(
     if world_size == 1:
         return input_tensor
-    all_lens = forward_batch.global_num_tokens
-    max_len = max(forward_batch.global_num_tokens)
+    all_lens = forward_batch.global_num_tokens_cpu
+    max_len = max(forward_batch.global_num_tokens_cpu)
     padded_tensor = torch.nn.functional.pad(
         input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
...
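The all_gather hunk above pads each rank's local tensor to the maximum per-rank token count before gathering, since collectives need equal-sized tensors on every rank. Below is a standalone sketch of that pad-then-gather pattern using generic torch.distributed calls rather than sglang's internal helpers; the function name and signature here are illustrative only.

```python
import torch
import torch.distributed as dist


def padded_all_gather(local: torch.Tensor, all_lens: list[int]) -> torch.Tensor:
    """Gather variable-length [num_tokens, hidden] tensors across ranks.

    Each rank pads its tensor to max(all_lens) so the collective sees equal
    shapes; the padding is stripped again when the ranks are concatenated.
    """
    world_size = dist.get_world_size()
    if world_size == 1:
        return local

    max_len = max(all_lens)
    # Pad rows (second-to-last dim) at the end up to max_len.
    padded = torch.nn.functional.pad(local, (0, 0, 0, max_len - local.shape[0]))

    gathered = [torch.empty_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded)

    # Drop the per-rank padding before concatenating.
    return torch.cat([g[:n] for g, n in zip(gathered, all_lens)], dim=0)
```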
@@ -741,13 +741,6 @@ def pytorch_profile(name, func, *args, data_size=-1):
     return result
-def first_rank_print(*args, **kwargs):
-    if torch.cuda.current_device() == 0:
-        print(*args, **kwargs)
-    else:
-        pass
 def get_zmq_socket(
     context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
 ):
@@ -1177,6 +1170,11 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:
     return value.lower() in ("true", "1")
+@lru_cache(maxsize=2)
+def disable_request_logging() -> bool:
+    return get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")
@lru_cache(maxsize=8)
 def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
     # Note: cuda_visible_devices is not used, but we keep it as an argument for
...
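The new disable_request_logging helper just reads an environment variable through get_bool_env_var, and because of lru_cache the value is read once per process, so it must be set before the first call. A small usage sketch, assuming the hunk above comes from sglang/srt/utils.py (the file is not named in this diff):

```python
import os

# Set before the first call; lru_cache freezes the result for the process lifetime.
os.environ["SGLANG_DISABLE_REQUEST_LOGGING"] = "true"  # "1" is also treated as true

from sglang.srt.utils import disable_request_logging  # assumed module path

if disable_request_logging():
    print("Request logging is disabled.")
```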
@@ -85,6 +85,7 @@ nvcc_flags = [
     "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1",
     "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
     "--ptxas-options=-v",
+    "--expt-relaxed-constexpr",
     "-Xcompiler=-Wconversion",
     "-Xcompiler=-fno-strict-aliasing",
 ]
...