Unverified commit 66301e12, authored by Lianmin Zheng, committed by GitHub

Improve code styles (#4021)

parent ac238727
@@ -30,6 +30,11 @@ def get_model_config(model_name: str, tp_size: int):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
+    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
+        E = config.n_routed_experts
+        topk = config.num_experts_per_tok
+        intermediate_size = config.intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // tp_size
     elif config.architectures[0] in [
         "Grok1ForCausalLM",
         "Grok1ImgGen",
@@ -39,11 +44,6 @@ def get_model_config(model_name: str, tp_size: int):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
-    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
-        E = config.n_routed_experts
-        topk = config.num_experts_per_tok
-        intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // tp_size
     else:
         # Default: Mixtral
         E = config.num_local_experts
...
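For context, the hunk above only reorders the architecture branches so that the DeepSeek-V2/V3 case is checked before the Grok and default (Mixtral) cases. A minimal, self-contained sketch of the same branch-on-architecture pattern is below; the `describe_moe` helper and the `SimpleNamespace` config are illustrative stand-ins, not part of the benchmark script.

```python
from types import SimpleNamespace

def describe_moe(config, tp_size: int):
    """Illustrative stand-in for get_model_config's branch-on-architecture logic."""
    arch = config.architectures[0]
    if arch in ("DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"):
        # DeepSeek MoE configs expose the routed-expert count under a different name.
        E = config.n_routed_experts
        intermediate_size = config.intermediate_size
    else:
        # Default: Mixtral-style configs.
        E = config.num_local_experts
        intermediate_size = config.intermediate_size
    topk = config.num_experts_per_tok
    # Each tensor-parallel rank holds a shard of the gate/up projection (hence 2 *).
    shard_intermediate_size = 2 * intermediate_size // tp_size
    return E, topk, shard_intermediate_size

# Example with a fabricated DeepSeek-V2-like config.
cfg = SimpleNamespace(
    architectures=["DeepseekV2ForCausalLM"],
    n_routed_experts=160,
    num_experts_per_tok=6,
    intermediate_size=1536,
)
print(describe_moe(cfg, tp_size=8))  # -> (160, 6, 384)
```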
@@ -393,7 +393,7 @@ async def async_request_sglang_generate(
                         output.itl.extend([adjust_itl] * num_new_tokens)

                     most_recent_timestamp = timestamp
-                    generated_text = data["text"]
+                    last_output_len = output_len

             output.generated_text = generated_text
             output.success = True
...
@@ -329,12 +329,7 @@ class RuntimeEndpoint(BaseBackend):
 def compute_normalized_prompt_logprobs(input_logprobs):
     values = [x[0] for x in input_logprobs if x[0]]
-    try:
-        return sum(values) / len(values)
-    except TypeError:
-        print(f"{input_logprobs=}", flush=True)
-        print(f"{input_logprobs[0]=}", flush=True)
-        exit(-1)
+    return sum(values) / len(values)


 class Runtime:
...
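The hunk above drops a temporary try/except debugging block around the mean computation. If one wanted the same normalization with an explicit guard instead of a crash on empty input, a small sketch (not the code in this commit) could look like this; note it uses an `is not None` check, a slight variation on the original truthiness test, so exact-zero logprobs are not dropped.

```python
from typing import List, Optional, Tuple

def normalized_prompt_logprob(
    input_logprobs: List[Tuple[Optional[float], int]]
) -> Optional[float]:
    """Mean token logprob over the prompt, skipping positions with no logprob.

    Each element is (logprob, token_id); the first position typically has
    logprob None, which the filter below removes.
    """
    values = [lp for lp, _ in input_logprobs if lp is not None]
    if not values:
        return None  # nothing to normalize over
    return sum(values) / len(values)

# Example: first token has no logprob, the rest do.
print(normalized_prompt_logprob([(None, 11), (-1.2, 42), (-0.8, 7)]))  # -> -1.0
```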
@@ -64,7 +64,7 @@ class LogitsProcessorOutput:
     ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The logprobs of input tokens. shape: [#token]
-    input_token_logprobs: torch.Tensor = None
+    input_token_logprobs: Optional[torch.Tensor] = None
     # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
     input_top_logprobs_val: List = None
     input_top_logprobs_idx: List = None
...
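The only change here is tightening the annotation from `torch.Tensor = None` to `Optional[torch.Tensor] = None`. As a generic illustration of why that matters for dataclass-style containers (a sketch, not the actual `LogitsProcessorOutput` definition; the `PrefillLogprobs` name is invented):

```python
from dataclasses import dataclass
from typing import List, Optional

import torch

@dataclass
class PrefillLogprobs:
    # Hypothetical container mirroring the Optional-annotation style above.
    # None means "not requested for this batch"; a tensor means [#token] logprobs.
    input_token_logprobs: Optional[torch.Tensor] = None
    input_top_logprobs_val: Optional[List[List[float]]] = None
    input_top_logprobs_idx: Optional[List[List[int]]] = None

out = PrefillLogprobs()
if out.input_token_logprobs is not None:  # type checkers now require this guard
    print(out.input_token_logprobs.shape)
else:
    print("prefill logprobs were not requested")
```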
@@ -181,7 +181,6 @@ class EPMoE(torch.nn.Module):
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
-        assert self.activation == "silu"

         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
...
@@ -198,8 +198,6 @@ class DataParallelController:
         self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
         self.max_req_input_len = scheduler_info[0]["max_req_input_len"]

-        print(f"{scheduler_info=}")
-
     def round_robin_scheduler(self, req):
         self.workers[self.round_robin_counter].send_pyobj(req)
         self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
@@ -222,7 +220,6 @@ class DataParallelController:
                     TokenizedEmbeddingReqInput,
                 ),
             ):
-                logger.info("dispatching")
                 self.dispatching(recv_req)
             else:
                 # Send other control messages to first worker of tp group
...
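The controller edits above only delete a stray debug print and a per-request `logger.info` call; the dispatch logic itself is untouched. A stripped-down sketch of that round-robin dispatch, with a plain list of queues standing in for the ZMQ worker sockets:

```python
from typing import Any, List

class RoundRobinDispatcher:
    """Toy stand-in for DataParallelController.round_robin_scheduler."""

    def __init__(self, num_workers: int):
        self.workers: List[List[Any]] = [[] for _ in range(num_workers)]  # fake queues
        self.round_robin_counter = 0

    def dispatch(self, req: Any) -> None:
        # Send to the current worker, then advance the counter modulo #workers.
        self.workers[self.round_robin_counter].append(req)
        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)

d = RoundRobinDispatcher(num_workers=3)
for i in range(7):
    d.dispatch(f"req-{i}")
print([len(q) for q in d.workers])  # [3, 2, 2]
```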
@@ -158,7 +158,7 @@ class GenerateReqInput:
             # Expand parallel_sample_num
             num = self.batch_size * self.parallel_sample_num

-            if self.image_data is None:
+            if not self.image_data:
                 self.image_data = [None] * num
             elif not isinstance(self.image_data, list):
                 self.image_data = [self.image_data] * num
...
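Switching `if self.image_data is None:` to `if not self.image_data:` makes the expansion treat an empty list (or empty string) the same as `None`. A self-contained sketch of that normalization pattern; `expand_image_data` is an invented helper that only mirrors the branches shown above, not the full method:

```python
from typing import Any, List, Optional, Union

def expand_image_data(
    image_data: Optional[Union[Any, List[Any]]], num: int
) -> List[Any]:
    """Normalize image_data to a list with one entry per expanded request."""
    if not image_data:                    # None, [] and "" all mean "no images"
        return [None] * num
    if not isinstance(image_data, list):  # a single image shared by every sample
        return [image_data] * num
    return image_data                     # already a per-sample list

print(expand_image_data(None, 3))        # [None, None, None]
print(expand_image_data([], 3))          # [None, None, None]
print(expand_image_data("img.png", 3))   # ['img.png', 'img.png', 'img.png']
```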
@@ -282,6 +282,8 @@ class Req:
         # If we want to abort the request in the middle of the event loop, set this to true
         # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
         self.to_abort = False
+        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
+        self.to_abort_message: str = "Unknown error"

         self.stream = stream
         self.eos_token_ids = eos_token_ids
@@ -359,8 +361,6 @@ class Req:
         # The tokens is prefilled but need to be considered as decode tokens
         # and should be updated for the decode logprobs
         self.last_update_decode_tokens = 0
-        # The relative logprob_start_len in an extend batch
-        self.extend_logprob_start_len = 0

         # Embedding (return values)
         self.embedding = None
@@ -377,9 +377,6 @@ class Req:
         self.spec_verify_ct = 0
         self.lora_path = lora_path

-        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
-        self.to_abort_message: str = "Unknown error"
-
     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
...
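The `Req` edits move the `to_abort_message` initialization so it sits next to `to_abort`, matching the comment about attaching the message to `finished_reason` at the end of the event loop. A hedged sketch of that abort pattern; everything other than the `to_abort`/`to_abort_message` fields is invented for illustration:

```python
class ToyRequest:
    def __init__(self, rid: str):
        self.rid = rid
        # Set to True to abort mid-loop; never set finished_reason directly here.
        self.to_abort = False
        # Message attached to the finished_reason when the abort is processed.
        self.to_abort_message: str = "Unknown error"
        self.finished_reason = None

def process_step(req: ToyRequest) -> None:
    # At the end of an event-loop iteration, turn a pending abort into a result.
    if req.to_abort and req.finished_reason is None:
        req.finished_reason = {"type": "abort", "message": req.to_abort_message}

r = ToyRequest("req-1")
r.to_abort = True
r.to_abort_message = "client disconnected"
process_step(r)
print(r.finished_reason)  # {'type': 'abort', 'message': 'client disconnected'}
```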
@@ -358,7 +358,6 @@ class Scheduler:
         self.cum_spec_accept_count = 0
         self.last_decode_stats_tic = time.time()
         self.return_health_check_ct = 0
-        self.stream_interval = server_args.stream_interval
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
             self.current_stream.synchronize = lambda: None  # No-op for CPU
@@ -444,11 +443,6 @@ class Scheduler:
             },
         )

-        # The largest prefill length of a single request
-        self._largest_prefill_len: int = 0
-        # The largest context length (prefill + generation) of a single request
-        self._largest_prefill_decode_len: int = 0
-
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -2309,8 +2303,6 @@ def run_scheduler_process(
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

-    parent_process = psutil.Process().parent()
-
     # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
...
@@ -238,120 +238,6 @@ class TokenizerMetricsCollector:
             ],
         )

-        self.histogram_prefill_prealloc_duration = Histogram(
-            name="sglang:prefill_prealloc_duration_seconds",
-            documentation="Histogram of prefill prealloc duration in seconds.",
-            labelnames=labels.keys(),
-            buckets=[0.1, 0.3, 0.5, 0.7, 0.9, 1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 120, 160],
-        )
-        self.histogram_prefill_queue_duration = Histogram(
-            name="sglang:prefill_queue_duration_seconds",
-            documentation="Histogram of prefill queue duration in seconds.",
-            labelnames=labels.keys(),
-            buckets=[0.1, 0.3, 0.5, 0.7, 0.9, 2, 4, 8, 16, 64],
-        )
-        self.histogram_prefill_forward_duration = Histogram(
-            name="sglang:prefill_forward_duration_seconds",
-            documentation="Histogram of prefill forward duration in seconds.",
-            labelnames=labels.keys(),
-            buckets=[0.1, 0.3, 0.5, 0.7, 0.9, 2, 4, 8, 16, 64],
-        )
-        self.histogram_prefill_transfer_duration = Histogram(
-            name="sglang:prefill_transfer_duration_seconds",
-            documentation="Histogram of prefill transfer duration in seconds.",
-            labelnames=labels.keys(),
-            buckets=[0.050, 0.100, 0.150, 0.200, 0.300, 0.400, 0.500, 1.000, 2.000],
-        )
-        self.histogram_decode_prealloc_duration = Histogram(
-            name="sglang:decode_prealloc_duration_seconds",
-            documentation="Histogram of decode prealloc duration in seconds.",
-            labelnames=labels.keys(),
-            buckets=[0.1, 0.3, 0.5, 0.7, 0.9, 2, 4, 8, 16, 64],
-        )
-        self.histogram_decode_queue_duration = Histogram(
-            name="sglang:decode_queue_duration_seconds",
-            documentation="Histogram of decode queue duration in seconds.",
-            labelnames=labels.keys(),
-            buckets=[0.1, 0.3, 0.5, 0.7, 0.9, 2, 4, 8, 16, 64],
-        )
-
     def _log_histogram(self, histogram, data: Union[int, float]) -> None:
         histogram.labels(**self.labels).observe(data)
...
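The large removal above deletes six prefill/decode timing histograms; the remaining `_log_histogram` helper shows how observations are recorded. A minimal standalone example of the same `prometheus_client` pattern (the metric name, buckets, and label set here are made up, not the ones removed above):

```python
from prometheus_client import Histogram

labels = {"model_name": "demo-model"}

# One histogram with explicit buckets and the same label names as `labels`.
histogram_demo_duration = Histogram(
    name="demo:request_duration_seconds",
    documentation="Histogram of request duration in seconds (illustrative only).",
    labelnames=labels.keys(),
    buckets=[0.1, 0.3, 0.5, 0.7, 0.9, 2, 4, 8, 16, 64],
)

def log_histogram(histogram: Histogram, data: float) -> None:
    # Mirrors TokenizerMetricsCollector._log_histogram: bind labels, then observe.
    histogram.labels(**labels).observe(data)

log_histogram(histogram_demo_duration, 0.42)
```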
@@ -284,7 +284,9 @@ class ForwardBatch:
         ):
             ret.extend_num_tokens = batch.extend_num_tokens
             positions, ret.extend_start_loc = compute_position_triton(
-                ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+                ret.extend_prefix_lens,
+                ret.extend_seq_lens,
+                ret.extend_num_tokens,
             )
         else:
             positions, ret.extend_start_loc = compute_position_torch(
...
@@ -62,7 +62,6 @@ class ServerArgs:
     chat_template: Optional[str] = None
     is_embedding: bool = False
     revision: Optional[str] = None
-    skip_tokenizer_init: bool = False

     # Port for the HTTP server
     host: str = "127.0.0.1"
@@ -563,7 +562,7 @@ class ServerArgs:
             "--download-dir",
             type=str,
             default=ServerArgs.download_dir,
-            help="Model download directory.",
+            help="Model download directory for huggingface.",
         )
         parser.add_argument(
             "--base-gpu-id",
...
@@ -93,9 +93,11 @@ def run_eval(args):
     tic = time.time()
     states = few_shot_gsm8k.run_batch(
         arguments,
-        temperature=0,
+        temperature=args.temperature if hasattr(args, "temperature") else 0,
         num_threads=args.parallel,
         progress_bar=True,
+        return_logprob=getattr(args, "return_logprob", None),
+        logprob_start_len=getattr(args, "logprob_start_len", None),
     )
     latency = time.time() - tic
@@ -141,5 +143,6 @@ if __name__ == "__main__":
     parser.add_argument("--parallel", type=int, default=128)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=30000)
+    parser.add_argument("--temperature", type=float, default=0.0)
     args = parser.parse_args()
     run_eval(args)
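The benchmark change above threads an optional `--temperature` flag, plus optional logprob settings read defensively with `hasattr`/`getattr`, through to `run_batch`. A small standalone sketch of that defensive-argument pattern; `run_batch_stub` and its parameters are invented for illustration:

```python
import argparse

def run_batch_stub(temperature: float = 0.0, return_logprob=None, logprob_start_len=None):
    # Stand-in for few_shot_gsm8k.run_batch: just echo what it received.
    print(f"{temperature=} {return_logprob=} {logprob_start_len=}")

def run_eval(args: argparse.Namespace) -> None:
    run_batch_stub(
        # Fall back gracefully when the caller's namespace lacks these attributes.
        temperature=args.temperature if hasattr(args, "temperature") else 0,
        return_logprob=getattr(args, "return_logprob", None),
        logprob_start_len=getattr(args, "logprob_start_len", None),
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--temperature", type=float, default=0.0)
    run_eval(parser.parse_args())
```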
@@ -8,16 +8,19 @@ if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
         "/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12",
         mode=ctypes.RTLD_GLOBAL,
     )

-from .version import __version__
+from sgl_kernel.version import __version__

-if torch.version.hip is not None:
+if torch.version.cuda:
     from sgl_kernel.ops import (
-        all_reduce_reg,
-        all_reduce_unreg,
-        allocate_meta_buffer,
         apply_rope_with_cos_sin_cache_inplace,
         bmm_fp8,
-        dispose,
+        build_tree_kernel,
+        build_tree_kernel_efficient,
+        cublas_grouped_gemm,
+        custom_dispose,
+        custom_reduce,
+        fp8_blockwise_scaled_mm,
         fp8_scaled_mm,
         fused_add_rmsnorm,
         gelu_and_mul,
@@ -25,63 +28,32 @@ if torch.version.hip is not None:
         gemma_fused_add_rmsnorm,
         gemma_rmsnorm,
         get_graph_buffer_ipc_meta,
-        get_meta_buffer_ipc_handle,
-        init_custom_ar,
+        init_custom_reduce,
         int8_scaled_mm,
         lightning_attention_decode,
-        meta_size,
         min_p_sampling_from_probs,
         moe_align_block_size,
-        register_buffer,
         register_graph_buffers,
         rmsnorm,
         sampling_scaling_penalties,
+        sgl_per_token_group_quant_fp8,
         silu_and_mul,
         top_k_renorm_prob,
         top_k_top_p_sampling_from_probs,
         top_p_renorm_prob,
+        tree_speculative_sampling_target_only,
     )
-
-    __all__ = [
-        "all_reduce_reg",
-        "all_reduce_unreg",
-        "allocate_meta_buffer",
-        "apply_rope_with_cos_sin_cache_inplace",
-        "bmm_fp8",
-        "dispose",
-        "fp8_scaled_mm",
-        "fused_add_rmsnorm",
-        "gelu_and_mul",
-        "gelu_tanh_and_mul",
-        "gemma_fused_add_rmsnorm",
-        "gemma_rmsnorm",
-        "get_graph_buffer_ipc_meta",
-        "get_meta_buffer_ipc_handle",
-        "init_custom_ar",
-        "int8_scaled_mm",
-        "lightning_attention_decode",
-        "meta_size",
-        "min_p_sampling_from_probs",
-        "moe_align_block_size",
-        "register_buffer",
-        "register_graph_buffers",
-        "rmsnorm",
-        "sampling_scaling_penalties",
-        "silu_and_mul",
-        "top_k_renorm_prob",
-        "top_k_top_p_sampling_from_probs",
-        "top_p_renorm_prob",
-    ]
 else:
+    assert torch.version.hip
     from sgl_kernel.ops import (
+        all_reduce_reg,
+        all_reduce_unreg,
+        allocate_meta_buffer,
         apply_rope_with_cos_sin_cache_inplace,
         bmm_fp8,
-        build_tree_kernel,
-        build_tree_kernel_efficient,
-        cublas_grouped_gemm,
-        custom_dispose,
-        custom_reduce,
-        fp8_blockwise_scaled_mm,
+        dispose,
         fp8_scaled_mm,
         fused_add_rmsnorm,
         gelu_and_mul,
@@ -89,23 +61,26 @@ else:
         gemma_fused_add_rmsnorm,
         gemma_rmsnorm,
         get_graph_buffer_ipc_meta,
-        init_custom_reduce,
+        get_meta_buffer_ipc_handle,
+        init_custom_ar,
         int8_scaled_mm,
         lightning_attention_decode,
+        meta_size,
         min_p_sampling_from_probs,
         moe_align_block_size,
+        register_buffer,
         register_graph_buffers,
         rmsnorm,
         sampling_scaling_penalties,
-        sgl_per_token_group_quant_fp8,
         silu_and_mul,
         top_k_renorm_prob,
         top_k_top_p_sampling_from_probs,
         top_p_renorm_prob,
-        tree_speculative_sampling_target_only,
     )

-    __all__ = [
+__all__ = [
+    "__version__",
     "apply_rope_with_cos_sin_cache_inplace",
     "bmm_fp8",
     "cublas_grouped_gemm",
@@ -135,4 +110,4 @@ else:
     "top_k_top_p_sampling_from_probs",
     "top_p_renorm_prob",
     "tree_speculative_sampling_target_only",
-    ]
+]
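The `__init__.py` rewrite flips the platform check so the CUDA import list comes first (`if torch.version.cuda:`), asserts HIP in the `else:` branch, and replaces the per-branch `__all__` lists with a single module-level list that also exports `__version__`. The backend-selection pattern itself can be illustrated with a small hedged sketch; `fast_op` is a placeholder, not an sgl_kernel symbol, and the example assumes a CUDA or ROCm build of torch:

```python
import torch

# Pick a backend-specific implementation at import time, the way the
# sgl-kernel __init__ chooses between its CUDA and HIP op sets.
if torch.version.cuda:
    def fast_op(x: torch.Tensor) -> torch.Tensor:
        # Placeholder for a CUDA-only kernel binding.
        return x * 2
else:
    assert torch.version.hip, "expected either a CUDA or a ROCm build of torch"
    def fast_op(x: torch.Tensor) -> torch.Tensor:
        # Placeholder for the ROCm/HIP variant of the same op.
        return x + x

__all__ = ["fast_op"]

print(fast_op(torch.ones(2)))  # tensor([2., 2.]) on either backend
```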