Unverified commit 2fc12995, authored by fzyzcjy, committed by GitHub

Remove unnecessary kernels of num_token_non_padded (#6965)

parent 20d3ad3b
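
Summary of the change: the `num_token_non_padded` scalar was previously refreshed by small copy kernels on every CUDA graph replay and in every `ForwardBatch.init_new`, regardless of whether anything consumed it. It is now built and updated only when the new `enable_num_token_non_padded(server_args)` helper returns true (EP-MoE / DeepEP-MoE), and two-batch-overlap (TBO) replay preparation runs only when `enable_two_batch_overlap` is set, instead of early-returning inside the plugin. The snippet below is a minimal, self-contained sketch of that gating pattern; `ServerArgsSketch` and `ReplaySketch` are simplified stand-ins for illustration, not the real sglang classes.

```python
import torch


class ServerArgsSketch:
    """Illustrative stand-in for sglang's ServerArgs (only the relevant flags)."""

    def __init__(self, enable_ep_moe=False, enable_deepep_moe=False):
        self.enable_ep_moe = enable_ep_moe
        self.enable_deepep_moe = enable_deepep_moe


def enable_num_token_non_padded(server_args) -> bool:
    # Same condition as the helper added in this commit: only the EP-MoE /
    # DeepEP-MoE paths read num_token_non_padded.
    return server_args.enable_ep_moe or server_args.enable_deepep_moe


class ReplaySketch:
    """Illustrative stand-in for the CUDA graph runner's replay-time update."""

    def __init__(self, server_args):
        self.server_args = server_args
        self.num_token_non_padded = torch.zeros((), dtype=torch.int32)

    def replay_prepare(self, input_ids):
        # Previously the scalar was refreshed on every replay; now the copy
        # kernel is launched only when something will actually read it.
        if enable_num_token_non_padded(self.server_args):
            self.num_token_non_padded.copy_(
                torch.tensor(len(input_ids), dtype=torch.int32)
            )


runner = ReplaySketch(ServerArgsSketch(enable_ep_moe=True))
runner.replay_prepare(list(range(5)))
print(int(runner.num_token_non_padded))  # 5
```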
@@ -35,6 +35,7 @@ from sglang.srt.model_executor.forward_batch_info import (
    ForwardBatch,
    ForwardMode,
    PPProxyTensors,
    enable_num_token_non_padded,
)
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
@@ -190,6 +191,9 @@ class CudaGraphRunner:
        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
        self.enable_two_batch_overlap = (
            model_runner.server_args.enable_two_batch_overlap
        )
        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
        self.tp_size = model_runner.server_args.tp_size
        self.dp_size = model_runner.server_args.dp_size
@@ -327,9 +331,7 @@ class CudaGraphRunner:
        )
        is_tbo_supported = (
            forward_batch.can_run_tbo
            if self.model_runner.server_args.enable_two_batch_overlap
            else True
            forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
        )
        return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
@@ -549,13 +551,7 @@ class CudaGraphRunner:
        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
        self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
        self.positions[:raw_num_token].copy_(forward_batch.positions)
        num_token_non_padded = len(forward_batch.input_ids)
        self.num_token_non_padded[...] = num_token_non_padded
        self.tbo_plugin.replay_prepare(
            forward_mode=forward_batch.forward_mode,
            bs=bs,
            num_token_non_padded=num_token_non_padded,
        )
        if forward_batch.seq_lens_cpu is not None:
            if bs != raw_bs:
                self.seq_lens_cpu.fill_(1)
@@ -572,6 +568,14 @@ class CudaGraphRunner:
            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
        if self.enable_dp_attention or self.enable_sp_layernorm:
            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
        if enable_num_token_non_padded(self.model_runner.server_args):
            self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
        if self.enable_two_batch_overlap:
            self.tbo_plugin.replay_prepare(
                forward_mode=forward_batch.forward_mode,
                bs=bs,
                num_token_non_padded=len(forward_batch.input_ids),
            )
        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
......
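
The hunks above also hoist the `enable_two_batch_overlap` server argument into an attribute set once in `__init__`, so `can_run` and the replay path test a cached flag instead of walking `self.model_runner.server_args` on every call. A toy sketch of that pattern (the names are illustrative stand-ins, not the real `CudaGraphRunner` API):

```python
class CanRunSketch:
    """Toy illustration of caching a server flag at construction time."""

    def __init__(self, enable_two_batch_overlap: bool):
        # Read the server argument once instead of dereferencing
        # model_runner.server_args on every can_run() call.
        self.enable_two_batch_overlap = enable_two_batch_overlap

    def can_run(self, can_run_tbo: bool) -> bool:
        # Mirrors the one-line conditional form in the diff above.
        return can_run_tbo if self.enable_two_batch_overlap else True


assert CanRunSketch(False).can_run(can_run_tbo=False) is True
assert CanRunSketch(True).can_run(can_run_tbo=False) is False
```

The same cached attribute later guards the `tbo_plugin.replay_prepare` call in the replay path.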
@@ -118,6 +118,7 @@ class ForwardMode(IntEnum):
class CaptureHiddenMode(IntEnum):
    # Do not capture anything.
    NULL = auto()
    # Capture hidden states of all tokens.
    FULL = auto()
@@ -253,6 +254,7 @@ class ForwardBatch:
    # For Qwen2-VL
    mrope_positions: torch.Tensor = None
    # For two-batch overlap
    tbo_split_seq_index: Optional[int] = None
    tbo_parent_token_range: Optional[Tuple[int, int]] = None
    tbo_children: Optional[List["ForwardBatch"]] = None
@@ -265,12 +267,6 @@ class ForwardBatch:
    ):
        from sglang.srt.two_batch_overlap import TboForwardBatchPreparer
        device = model_runner.device
        extend_input_logprob_token_ids_gpu = None
        if batch.extend_input_logprob_token_ids is not None:
            extend_input_logprob_token_ids_gpu = (
                batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
            )
        ret = cls(
            forward_mode=batch.forward_mode,
            batch_size=len(batch.seq_lens),
@@ -284,6 +280,7 @@ class ForwardBatch:
            encoder_lens_cpu=batch.encoder_lens_cpu,
            encoder_out_cache_loc=batch.encoder_out_cache_loc,
            seq_lens_sum=batch.seq_lens_sum,
            seq_lens_cpu=batch.seq_lens_cpu,
            return_logprob=batch.return_logprob,
            top_logprobs_nums=batch.top_logprobs_nums,
            token_ids_logprobs=batch.token_ids_logprobs,
@@ -298,12 +295,19 @@ class ForwardBatch:
            spec_info=batch.spec_info,
            capture_hidden_mode=batch.capture_hidden_mode,
            input_embeds=batch.input_embeds,
            extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
            num_token_non_padded=torch.tensor(
                len(batch.input_ids), dtype=torch.int32
            ).to(device, non_blocking=True),
            tbo_split_seq_index=batch.tbo_split_seq_index,
        )
        device = model_runner.device
        if batch.extend_input_logprob_token_ids is not None:
            ret.extend_input_logprob_token_ids_gpu = (
                batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
            )
        if enable_num_token_non_padded(model_runner.server_args):
            ret.num_token_non_padded = torch.tensor(
                len(batch.input_ids), dtype=torch.int32
            ).to(device, non_blocking=True)
        # For DP attention
        if batch.global_num_tokens is not None:
@@ -323,6 +327,7 @@ class ForwardBatch:
                dtype=model_runner.dtype,
                device=device,
            )
        if ret.forward_mode.is_idle():
            ret.positions = torch.empty((0,), device=device)
            TboForwardBatchPreparer.prepare(ret)
@@ -335,10 +340,6 @@ class ForwardBatch:
        ):
            ret.positions = ret.spec_info.positions
        # Get seq_lens_cpu if needed
        if ret.seq_lens_cpu is None:
            ret.seq_lens_cpu = batch.seq_lens_cpu
        # Init position information
        if ret.forward_mode.is_decode():
            if ret.positions is None:
@@ -605,6 +606,10 @@ class ForwardBatch:
        return self.tbo_split_seq_index is not None


def enable_num_token_non_padded(server_args):
    return server_args.enable_ep_moe or server_args.enable_deepep_moe


class PPProxyTensors:
    # adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
    tensors: Dict[str, torch.Tensor]
......
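
In `ForwardBatch.init_new` above, `seq_lens_cpu` is now passed straight into `cls(...)` (replacing the later `if ret.seq_lens_cpu is None` fallback), while `extend_input_logprob_token_ids_gpu` and `num_token_non_padded` are assigned on `ret` after construction; the latter is built only when `enable_num_token_non_padded(model_runner.server_args)` holds, so the int32 host-to-device upload is skipped entirely otherwise. Below is a standalone sketch of that conditional upload; the `build_num_token_non_padded` function and `_Args` class are illustrative only, and the device falls back to CPU so the snippet runs anywhere.

```python
import torch


class _Args:
    """Hypothetical stand-in carrying only the two flags the helper checks."""

    enable_ep_moe = True
    enable_deepep_moe = False


def build_num_token_non_padded(input_ids, server_args, device):
    """Illustrative helper: return None (no tensor, no copy kernel) unless
    enable_num_token_non_padded(server_args) would be true."""
    if not (server_args.enable_ep_moe or server_args.enable_deepep_moe):
        return None
    # Mirrors the diff: an int32 scalar uploaded with a non-blocking copy.
    return torch.tensor(len(input_ids), dtype=torch.int32).to(
        device, non_blocking=True
    )


device = "cuda" if torch.cuda.is_available() else "cpu"
print(build_num_token_non_padded([7, 8, 9], _Args(), device))  # tensor(3, ...)
```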
@@ -131,9 +131,6 @@ class TboCudaGraphRunnerPlugin:
    def replay_prepare(
        self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
    ):
        if not global_server_args_dict["enable_two_batch_overlap"]:
            return
        tbo_split_seq_index, tbo_split_token_index = (
            compute_split_indices_for_cuda_graph_replay(
                forward_mode=forward_mode,
......
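
Finally, `TboCudaGraphRunnerPlugin.replay_prepare` drops its internal `global_server_args_dict["enable_two_batch_overlap"]` early return; the caller (the CUDA graph runner shown earlier) now decides whether to invoke it at all. A minimal caller-side sketch of that division of responsibility, using simplified stand-in classes rather than the real sglang API:

```python
class TboPluginSketch:
    """Stand-in for TboCudaGraphRunnerPlugin: the body now assumes TBO is
    enabled whenever it is invoked (no internal early return)."""

    def replay_prepare(self, forward_mode, bs, num_token_non_padded):
        print(f"prepare TBO replay: bs={bs}, tokens={num_token_non_padded}")


class GraphRunnerSketch:
    """Stand-in caller: the flag check lives at the call site."""

    def __init__(self, enable_two_batch_overlap: bool):
        self.enable_two_batch_overlap = enable_two_batch_overlap
        self.tbo_plugin = TboPluginSketch()

    def replay_prepare(self, forward_mode, bs, input_ids):
        if self.enable_two_batch_overlap:
            self.tbo_plugin.replay_prepare(
                forward_mode=forward_mode,
                bs=bs,
                num_token_non_padded=len(input_ids),
            )


GraphRunnerSketch(False).replay_prepare("decode", 8, [0] * 8)  # prints nothing
GraphRunnerSketch(True).replay_prepare("decode", 8, [0] * 8)   # prints once
```

Keeping the check at the call site removes a global dict lookup from the hot replay path and makes the plugin's precondition explicit.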