Unverified Commit e074d84e authored by Lianmin Zheng, committed by GitHub

[Minor] more code cleanup (#4077)

parent 4725e3f6
......@@ -40,6 +40,7 @@ runtime_common = [
"transformers==4.48.3",
"llguidance>=0.6.15"
]
srt = [
"sglang[runtime_common]",
"sgl-kernel==0.0.3.post6",
......
......@@ -39,6 +39,7 @@ from transformers import (
)
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
ASSISTANT_SUFFIX = "Assistant:"
global args
......@@ -635,7 +636,11 @@ def sample_sharegpt_requests(
# Tokenize the prompts and completions.
prompt = dataset[i][0]
if prompt_suffix:
prompt = prompt
prompt = (
remove_suffix(prompt, ASSISTANT_SUFFIX)
+ prompt_suffix
+ ASSISTANT_SUFFIX
)
if apply_chat_template:
prompt = tokenizer.apply_chat_template(
......
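The hunk above rebuilds the prompt by stripping the trailing "Assistant:" marker, appending prompt_suffix, and re-appending the marker. A minimal sketch of the remove_suffix helper it relies on, assuming it only strips the suffix when it is actually present:

def remove_suffix(text: str, suffix: str) -> str:
    # Assumed helper: drop `suffix` from the end of `text` if present, else return unchanged.
    if suffix and text.endswith(suffix):
        return text[: -len(suffix)]
    return text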
import json
import logging
import re
from abc import ABC, abstractmethod
from json import JSONDecodeError, JSONDecoder
from typing import Any, Dict, List, Optional, Tuple
......
import triton
import triton.language as tl
@triton.jit
def create_flashinfer_kv_indices_triton(
req_to_token_ptr, # [max_batch, max_context_len]
req_pool_indices_ptr,
page_kernel_lens_ptr,
kv_indptr,
kv_start_idx,
kv_indices_ptr,
req_to_token_ptr_stride: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(axis=0)
req_pool_index = tl.load(req_pool_indices_ptr + pid)
kv_indices_offset = tl.load(kv_indptr + pid)
kv_start = 0
kv_end = 0
if kv_start_idx:
kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
kv_end = kv_start
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = offset < kv_end - kv_start
data = tl.load(
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ kv_start
+ offset,
mask=mask,
)
tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
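A hedged usage sketch of the kernel above (tensor names, shapes, and dtypes are illustrative assumptions, not taken from the commit): one Triton program is launched per request, and each program copies that request's token slots from req_to_token into the flat kv_indices array at the offset given by kv_indptr.

import torch

# Illustrative inputs (assumed shapes/dtypes).
batch_size = 4
req_to_token = torch.zeros((64, 2048), dtype=torch.int32, device="cuda")
req_pool_indices = torch.arange(batch_size, dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([5, 17, 8, 33], dtype=torch.int32, device="cuda")
kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
kv_indices = torch.empty(int(kv_indptr[-1].item()), dtype=torch.int32, device="cuda")

# One program per request; passing None for kv_start_idx skips the per-request start offset.
create_flashinfer_kv_indices_triton[(batch_size,)](
    req_to_token,
    req_pool_indices,
    seq_lens,
    kv_indptr,
    None,
    kv_indices,
    req_to_token.stride(0),
)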
......@@ -33,6 +33,7 @@ from sglang.srt.layers.dp_attention import (
get_attention_dp_size,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
......@@ -152,6 +153,13 @@ class LogitsMetadata:
token_ids_logprobs=forward_batch.token_ids_logprobs,
extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
padded_static_len=forward_batch.padded_static_len,
global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
dp_local_start_pos=forward_batch.dp_local_start_pos,
dp_local_num_tokens=forward_batch.dp_local_num_tokens,
gathered_buffer=forward_batch.gathered_buffer,
forward_batch_gathered_buffer=forward_batch.gathered_buffer,
global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
)
def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
......@@ -204,8 +212,6 @@ class LogitsProcessor(nn.Module):
):
self.final_logit_softcapping = None
from sglang.srt.managers.schedule_batch import global_server_args_dict
self.debug_tensor_dump_output_folder = global_server_args_dict.get(
"debug_tensor_dump_output_folder", None
)
......
......@@ -212,6 +212,7 @@ class DetokenizerManager:
rids=recv_obj.rids,
finished_reasons=recv_obj.finished_reasons,
output_strs=output_strs,
output_ids=None,
prompt_tokens=recv_obj.prompt_tokens,
completion_tokens=recv_obj.completion_tokens,
cached_tokens=recv_obj.cached_tokens,
......
......@@ -414,6 +414,12 @@ class BatchTokenIDOut:
class BatchMultimodalDecodeReq:
# The request id
rids: List[str]
finished_reasons: List[BaseFinishReason]
# Token counts
prompt_tokens: List[int]
completion_tokens: List[int]
cached_tokens: List[int]
@dataclass
......@@ -424,6 +430,8 @@ class BatchStrOut:
finished_reasons: List[dict]
# The output decoded strings
output_strs: List[str]
# The token ids
output_ids: Optional[List[int]]
# Token counts
prompt_tokens: List[int]
......@@ -453,6 +461,15 @@ class BatchStrOut:
class BatchMultimodalOut:
# The request id
rids: List[str]
# The finish reason
finished_reasons: List[dict]
# The outputs
outputs: List[List[Dict]]
# Token counts
prompt_tokens: List[int]
completion_tokens: List[int]
cached_tokens: List[int]
@dataclass
......
......@@ -1141,7 +1141,7 @@ async def print_exception_wrapper(func):
class SignalHandler:
def __init__(self, tokenizer_manager):
def __init__(self, tokenizer_manager: TokenizerManager):
self.tokenizer_manager = tokenizer_manager
def signal_handler(self, signum=None, frame=None):
......
......@@ -192,7 +192,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
k_size, v_size = self.get_kv_size_bytes()
logger.info(
f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
)
def _create_buffers(self):
......
......@@ -238,6 +238,9 @@ class CudaGraphRunner:
),
dtype=self.model_runner.dtype,
)
self.global_num_tokens_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
# Capture
try:
......@@ -266,9 +269,9 @@ class CudaGraphRunner:
def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention:
min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
forward_batch.global_num_tokens
)
min_num_tokens, max_num_tokens = min(
forward_batch.global_num_tokens_cpu
), max(forward_batch.global_num_tokens_cpu)
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
(min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
if self.disable_padding
......@@ -360,7 +363,7 @@ class CudaGraphRunner:
encoder_lens=encoder_lens,
return_logprob=False,
positions=positions,
global_num_tokens=global_num_tokens,
global_num_tokens_cpu=global_num_tokens,
gathered_buffer=gathered_buffer,
mrope_positions=mrope_positions,
spec_algorithm=self.model_runner.spec_algorithm,
......@@ -430,7 +433,7 @@ class CudaGraphRunner:
# Pad
if self.enable_dp_attention:
index = bisect.bisect_left(
self.capture_bs, max(forward_batch.global_num_tokens)
self.capture_bs, max(forward_batch.global_num_tokens_cpu)
)
else:
index = bisect.bisect_left(self.capture_bs, raw_bs)
......
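A short illustration of the padding lookup above (the capture_bs values are assumptions, not from the commit): bisect.bisect_left returns the index of the smallest captured size that can hold the incoming batch, and the batch is then padded up to that size.

import bisect

capture_bs = [1, 2, 4, 8, 16, 32]      # assumed sorted list of captured batch sizes
raw_bs = 5
index = bisect.bisect_left(capture_bs, raw_bs)
padded_bs = capture_bs[index]          # 8: the batch is padded up to this size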
......@@ -190,7 +190,16 @@ class ForwardBatch:
attn_backend: AttentionBackend = None
# For DP attention
global_num_tokens: Optional[List[int]] = None
global_num_tokens_cpu: Optional[List[int]] = None
global_num_tokens_gpu: Optional[torch.Tensor] = None
# Has to be None when cuda graph is captured.
global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
# for extend, local start pos and num tokens is different in logits processor
# this will be computed in get_dp_local_info
# this will be recomputed in LogitsMetadata.from_forward_batch
dp_local_start_pos: Optional[torch.Tensor] = None # cached info at runtime
dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
gathered_buffer: Optional[torch.Tensor] = None
can_run_dp_cuda_graph: bool = False
......@@ -234,7 +243,6 @@ class ForwardBatch:
return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums,
token_ids_logprobs=batch.token_ids_logprobs,
global_num_tokens=batch.global_num_tokens,
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
lora_paths=batch.lora_paths,
sampling_info=batch.sampling_info,
......@@ -248,8 +256,9 @@ class ForwardBatch:
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
)
if ret.global_num_tokens is not None:
max_len = max(ret.global_num_tokens)
if batch.global_num_tokens is not None:
ret.global_num_tokens_cpu = batch.global_num_tokens
max_len = max(ret.global_num_tokens_cpu)
ret.gathered_buffer = torch.zeros(
(max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
dtype=model_runner.dtype,
......
......@@ -13,7 +13,6 @@
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""
import collections
import datetime
import gc
import json
......@@ -269,6 +268,7 @@ class ModelRunner:
elif self.device == "cpu":
backend = "gloo"
before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
if not self.server_args.enable_p2p_check:
monkey_patch_p2p_access_check()
......@@ -299,20 +299,24 @@ class ModelRunner:
min_per_gpu_memory = get_available_gpu_memory(
self.device, self.gpu_id, distributed=self.tp_size > 1
)
local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
self.tp_group = get_tp_group()
self.attention_tp_group = get_attention_tp_group()
# Check memory for tensor parallelism
if self.tp_size > 1:
local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
if min_per_gpu_memory < local_gpu_memory * 0.9:
raise ValueError(
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
)
logger.info(
f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
)
return min_per_gpu_memory
def load_model(self):
before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
......@@ -382,11 +386,13 @@ class ModelRunner:
)
self.dtype = self.model_config.dtype
after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Load weight end. "
f"type={type(self.model).__name__}, "
f"dtype={self.dtype}, "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
f"avail mem={after_avail_memory:.2f} GB, "
f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
)
def update_weights_from_disk(
......@@ -785,12 +791,15 @@ class ModelRunner:
return
tic = time.time()
before_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
)
self.cuda_graph_runner = CudaGraphRunner(self)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
)
def apply_torch_tp(self):
......@@ -806,8 +815,12 @@ class ModelRunner:
forward_batch.input_ids, forward_batch.positions, forward_batch
)
def forward_extend(self, forward_batch: ForwardBatch):
self.attn_backend.init_forward_metadata(forward_batch)
def forward_extend(
self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
):
if not skip_attn_backend_init:
self.attn_backend.init_forward_metadata(forward_batch)
if self.is_generation:
if forward_batch.input_embeds is None:
return self.model.forward(
......
......@@ -818,8 +818,8 @@ def all_gather(
if world_size == 1:
return input_tensor
all_lens = forward_batch.global_num_tokens
max_len = max(forward_batch.global_num_tokens)
all_lens = forward_batch.global_num_tokens_cpu
max_len = max(forward_batch.global_num_tokens_cpu)
padded_tensor = torch.nn.functional.pad(
input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
......
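A small standalone illustration of the padding call above (not from the commit): for a 2-D tensor, F.pad with (0, 0, 0, k) leaves the last dimension untouched and appends k zero rows, which is how each rank's tensor is brought up to max_len before the all-gather.

import torch
import torch.nn.functional as F

x = torch.ones(3, 4)                   # 3 local tokens, hidden size 4
padded = F.pad(x, (0, 0, 0, 2))        # pad last dim by (0, 0), first dim by (0, 2)
print(padded.shape)                    # torch.Size([5, 4])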
......@@ -741,13 +741,6 @@ def pytorch_profile(name, func, *args, data_size=-1):
return result
def first_rank_print(*args, **kwargs):
if torch.cuda.current_device() == 0:
print(*args, **kwargs)
else:
pass
def get_zmq_socket(
context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
):
......@@ -1177,6 +1170,11 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:
return value.lower() in ("true", "1")
@lru_cache(maxsize=2)
def disable_request_logging() -> bool:
return get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")
@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
# Note: cuda_visible_devices is not used, but we keep it as an argument for
......
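A hedged usage sketch of the helper above (the values are illustrative assumptions): because of lru_cache, the environment variable is read once on the first call and the cached boolean is reused afterwards.

import os

os.environ["SGLANG_DISABLE_REQUEST_LOGGING"] = "1"
assert disable_request_logging() is True    # first call reads the env var
os.environ["SGLANG_DISABLE_REQUEST_LOGGING"] = "0"
assert disable_request_logging() is True    # cached result; the env change is not re-read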
......@@ -85,6 +85,7 @@ nvcc_flags = [
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1",
"-DCUTLASS_DEBUG_TRACE_LEVEL=0",
"--ptxas-options=-v",
"--expt-relaxed-constexpr",
"-Xcompiler=-Wconversion",
"-Xcompiler=-fno-strict-aliasing",
]
......