Unverified Commit 2fc12995 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Remove unnecessary kernels of num_token_non_padded (#6965)

parent 20d3ad3b
...@@ -35,6 +35,7 @@ from sglang.srt.model_executor.forward_batch_info import ( ...@@ -35,6 +35,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch, ForwardBatch,
ForwardMode, ForwardMode,
PPProxyTensors, PPProxyTensors,
enable_num_token_non_padded,
) )
from sglang.srt.patch_torch import monkey_patch_torch_compile from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
...@@ -190,6 +191,9 @@ class CudaGraphRunner: ...@@ -190,6 +191,9 @@ class CudaGraphRunner:
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
self.enable_dp_attention = model_runner.server_args.enable_dp_attention self.enable_dp_attention = model_runner.server_args.enable_dp_attention
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
self.enable_two_batch_overlap = (
model_runner.server_args.enable_two_batch_overlap
)
self.speculative_algorithm = model_runner.server_args.speculative_algorithm self.speculative_algorithm = model_runner.server_args.speculative_algorithm
self.tp_size = model_runner.server_args.tp_size self.tp_size = model_runner.server_args.tp_size
self.dp_size = model_runner.server_args.dp_size self.dp_size = model_runner.server_args.dp_size
...@@ -327,9 +331,7 @@ class CudaGraphRunner: ...@@ -327,9 +331,7 @@ class CudaGraphRunner:
) )
is_tbo_supported = ( is_tbo_supported = (
forward_batch.can_run_tbo forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
if self.model_runner.server_args.enable_two_batch_overlap
else True
) )
return is_bs_supported and is_encoder_lens_supported and is_tbo_supported return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
...@@ -549,13 +551,7 @@ class CudaGraphRunner: ...@@ -549,13 +551,7 @@ class CudaGraphRunner:
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc) self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
self.positions[:raw_num_token].copy_(forward_batch.positions) self.positions[:raw_num_token].copy_(forward_batch.positions)
num_token_non_padded = len(forward_batch.input_ids)
self.num_token_non_padded[...] = num_token_non_padded
self.tbo_plugin.replay_prepare(
forward_mode=forward_batch.forward_mode,
bs=bs,
num_token_non_padded=num_token_non_padded,
)
if forward_batch.seq_lens_cpu is not None: if forward_batch.seq_lens_cpu is not None:
if bs != raw_bs: if bs != raw_bs:
self.seq_lens_cpu.fill_(1) self.seq_lens_cpu.fill_(1)
...@@ -572,6 +568,14 @@ class CudaGraphRunner: ...@@ -572,6 +568,14 @@ class CudaGraphRunner:
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions) self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
if self.enable_dp_attention or self.enable_sp_layernorm: if self.enable_dp_attention or self.enable_sp_layernorm:
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu) self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
if enable_num_token_non_padded(self.model_runner.server_args):
self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
if self.enable_two_batch_overlap:
self.tbo_plugin.replay_prepare(
forward_mode=forward_batch.forward_mode,
bs=bs,
num_token_non_padded=len(forward_batch.input_ids),
)
# Attention backend # Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
......
...@@ -118,6 +118,7 @@ class ForwardMode(IntEnum): ...@@ -118,6 +118,7 @@ class ForwardMode(IntEnum):
class CaptureHiddenMode(IntEnum): class CaptureHiddenMode(IntEnum):
# Do not capture anything.
NULL = auto() NULL = auto()
# Capture hidden states of all tokens. # Capture hidden states of all tokens.
FULL = auto() FULL = auto()
...@@ -253,6 +254,7 @@ class ForwardBatch: ...@@ -253,6 +254,7 @@ class ForwardBatch:
# For Qwen2-VL # For Qwen2-VL
mrope_positions: torch.Tensor = None mrope_positions: torch.Tensor = None
# For two-batch overlap
tbo_split_seq_index: Optional[int] = None tbo_split_seq_index: Optional[int] = None
tbo_parent_token_range: Optional[Tuple[int, int]] = None tbo_parent_token_range: Optional[Tuple[int, int]] = None
tbo_children: Optional[List["ForwardBatch"]] = None tbo_children: Optional[List["ForwardBatch"]] = None
...@@ -265,12 +267,6 @@ class ForwardBatch: ...@@ -265,12 +267,6 @@ class ForwardBatch:
): ):
from sglang.srt.two_batch_overlap import TboForwardBatchPreparer from sglang.srt.two_batch_overlap import TboForwardBatchPreparer
device = model_runner.device
extend_input_logprob_token_ids_gpu = None
if batch.extend_input_logprob_token_ids is not None:
extend_input_logprob_token_ids_gpu = (
batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
)
ret = cls( ret = cls(
forward_mode=batch.forward_mode, forward_mode=batch.forward_mode,
batch_size=len(batch.seq_lens), batch_size=len(batch.seq_lens),
...@@ -284,6 +280,7 @@ class ForwardBatch: ...@@ -284,6 +280,7 @@ class ForwardBatch:
encoder_lens_cpu=batch.encoder_lens_cpu, encoder_lens_cpu=batch.encoder_lens_cpu,
encoder_out_cache_loc=batch.encoder_out_cache_loc, encoder_out_cache_loc=batch.encoder_out_cache_loc,
seq_lens_sum=batch.seq_lens_sum, seq_lens_sum=batch.seq_lens_sum,
seq_lens_cpu=batch.seq_lens_cpu,
return_logprob=batch.return_logprob, return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums, top_logprobs_nums=batch.top_logprobs_nums,
token_ids_logprobs=batch.token_ids_logprobs, token_ids_logprobs=batch.token_ids_logprobs,
...@@ -298,12 +295,19 @@ class ForwardBatch: ...@@ -298,12 +295,19 @@ class ForwardBatch:
spec_info=batch.spec_info, spec_info=batch.spec_info,
capture_hidden_mode=batch.capture_hidden_mode, capture_hidden_mode=batch.capture_hidden_mode,
input_embeds=batch.input_embeds, input_embeds=batch.input_embeds,
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
num_token_non_padded=torch.tensor(
len(batch.input_ids), dtype=torch.int32
).to(device, non_blocking=True),
tbo_split_seq_index=batch.tbo_split_seq_index, tbo_split_seq_index=batch.tbo_split_seq_index,
) )
device = model_runner.device
if batch.extend_input_logprob_token_ids is not None:
ret.extend_input_logprob_token_ids_gpu = (
batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
)
if enable_num_token_non_padded(model_runner.server_args):
ret.num_token_non_padded = torch.tensor(
len(batch.input_ids), dtype=torch.int32
).to(device, non_blocking=True)
# For DP attention # For DP attention
if batch.global_num_tokens is not None: if batch.global_num_tokens is not None:
...@@ -323,6 +327,7 @@ class ForwardBatch: ...@@ -323,6 +327,7 @@ class ForwardBatch:
dtype=model_runner.dtype, dtype=model_runner.dtype,
device=device, device=device,
) )
if ret.forward_mode.is_idle(): if ret.forward_mode.is_idle():
ret.positions = torch.empty((0,), device=device) ret.positions = torch.empty((0,), device=device)
TboForwardBatchPreparer.prepare(ret) TboForwardBatchPreparer.prepare(ret)
...@@ -335,10 +340,6 @@ class ForwardBatch: ...@@ -335,10 +340,6 @@ class ForwardBatch:
): ):
ret.positions = ret.spec_info.positions ret.positions = ret.spec_info.positions
# Get seq_lens_cpu if needed
if ret.seq_lens_cpu is None:
ret.seq_lens_cpu = batch.seq_lens_cpu
# Init position information # Init position information
if ret.forward_mode.is_decode(): if ret.forward_mode.is_decode():
if ret.positions is None: if ret.positions is None:
...@@ -605,6 +606,10 @@ class ForwardBatch: ...@@ -605,6 +606,10 @@ class ForwardBatch:
return self.tbo_split_seq_index is not None return self.tbo_split_seq_index is not None
def enable_num_token_non_padded(server_args):
    """Tell whether the non-padded-token count must be tracked for this run.

    Only the EP-MoE and DeepEP-MoE code paths consume ``num_token_non_padded``,
    so callers use this predicate to skip the related tensor setup otherwise.
    Short-circuits exactly like ``a or b``: the first truthy flag value is
    returned as-is.
    """
    ep_moe_flag = server_args.enable_ep_moe
    if ep_moe_flag:
        return ep_moe_flag
    return server_args.enable_deepep_moe
class PPProxyTensors: class PPProxyTensors:
# adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103 # adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
tensors: Dict[str, torch.Tensor] tensors: Dict[str, torch.Tensor]
......
...@@ -131,9 +131,6 @@ class TboCudaGraphRunnerPlugin: ...@@ -131,9 +131,6 @@ class TboCudaGraphRunnerPlugin:
def replay_prepare( def replay_prepare(
self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
): ):
if not global_server_args_dict["enable_two_batch_overlap"]:
return
tbo_split_seq_index, tbo_split_token_index = ( tbo_split_seq_index, tbo_split_token_index = (
compute_split_indices_for_cuda_graph_replay( compute_split_indices_for_cuda_graph_replay(
forward_mode=forward_mode, forward_mode=forward_mode,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment