Unverified commit c0fb25e9, authored by Cheng Wan, committed by GitHub

DP Enhancement (#8280)

parent 28d4d472
@@ -545,6 +545,15 @@ class GroupCoordinator:
        else:
            torch.distributed.all_reduce(input_, group=self.device_group)

+    def reduce_scatter_tensor(
+        self,
+        output: torch.Tensor,
+        input: torch.Tensor,
+    ) -> torch.Tensor:
+        # TODO(ch-wan): support other backends
+        torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group)
+        return output
+
    def reduce_scatter(
        self,
        output: torch.Tensor,
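Editor's note: `torch.distributed.reduce_scatter_tensor` sums the input across the group and leaves each rank with its own contiguous slice, so `output` should have `input.shape[0] // world_size` rows. A minimal shape sketch using only the public torch.distributed API; the process group and tensors here are illustrative, not part of the patch:

    import torch
    import torch.distributed as dist

    def reduce_scatter_sketch(input_: torch.Tensor, group=None) -> torch.Tensor:
        # Every rank contributes `input_`; afterwards rank r holds the element-wise
        # sum of everyone's r-th slice along dim 0.
        world_size = dist.get_world_size(group)
        assert input_.shape[0] % world_size == 0
        output = input_.new_empty(input_.shape[0] // world_size, *input_.shape[1:])
        dist.reduce_scatter_tensor(output, input_, group=group)
        return output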
......
@@ -65,7 +65,9 @@ class AttentionBackend(ABC):
        **kwargs,
    ):
        """Run forward on an attention layer."""
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_idle():
+            return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+        elif forward_batch.forward_mode.is_decode():
            return self.forward_decode(
                q,
                k,
......
@@ -24,8 +24,8 @@ from sglang.srt.distributed import (
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.dp_attention import (
-    attn_tp_all_gather,
-    attn_tp_reduce_scatter,
+    attn_tp_all_gather_into_tensor,
+    attn_tp_reduce_scatter_tensor,
    dp_gather_partial,
    dp_scatter,
    get_attention_dp_size,
@@ -309,8 +309,8 @@ class CommunicateSimpleFn:
            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
            hidden_states,
        )
-        attn_tp_all_gather(
-            list(hidden_states.tensor_split(context.attn_tp_size)),
+        attn_tp_all_gather_into_tensor(
+            hidden_states,
            local_hidden_states,
        )
        return hidden_states
@@ -400,9 +400,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
                ].clone(),
                residual,
            )
-            attn_tp_all_gather(
-                list(residual.tensor_split(context.attn_tp_size)), local_residual
-            )
+            attn_tp_all_gather_into_tensor(residual, local_residual)
        if context.attn_dp_size != 1:
            if context.attn_tp_rank == 0:
                hidden_states += residual
@@ -442,9 +440,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
        *,
        residual_input_mode,
    ):
-        tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
-        hidden_states = tensor_list[context.attn_tp_rank]
-        attn_tp_reduce_scatter(hidden_states, tensor_list)
+        input_hidden_states = hidden_states
+        hidden_states = hidden_states.tensor_split(context.attn_tp_size)[
+            context.attn_tp_rank
+        ]
+        attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states)
        if residual_input_mode == ScatterMode.TP_ATTN_FULL:
            residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]
        if hidden_states.shape[0] != 0:
@@ -547,8 +547,8 @@ class CommunicateSummableTensorPairFn:
            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
            hidden_states,
        )
-        attn_tp_all_gather(
-            list(hidden_states.tensor_split(context.attn_tp_size)),
+        attn_tp_all_gather_into_tensor(
+            hidden_states,
            local_hidden_states,
        )
        return hidden_states, residual
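Editor's note: the `*_into_tensor` helpers adopted above gather into one flat tensor instead of a Python list of chunk views. A rough standalone sketch of the same operation with plain torch.distributed calls; the function name and `attn_tp_group` argument are illustrative, not the project's API:

    import torch
    import torch.distributed as dist

    def all_gather_hidden(local_hidden: torch.Tensor, attn_tp_group=None) -> torch.Tensor:
        # Equivalent in spirit to attn_tp_all_gather_into_tensor(full, local_hidden):
        # one fused collective writes every rank's rows into a single buffer.
        tp = dist.get_world_size(attn_tp_group)
        full = local_hidden.new_empty(tp * local_hidden.shape[0], *local_hidden.shape[1:])
        dist.all_gather_into_tensor(full, local_hidden, group=attn_tp_group)
        return full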
......
@@ -3,7 +3,8 @@ from __future__ import annotations
import functools
import logging
from contextlib import contextmanager
-from typing import TYPE_CHECKING, List
+from enum import IntEnum, auto
+from typing import TYPE_CHECKING, List, Tuple

import torch
import triton
@@ -30,6 +31,34 @@ _LOCAL_ATTN_DP_SIZE = None
_LOCAL_ATTN_DP_RANK = None

+class DPPaddingMode(IntEnum):
+    # Pad tokens to the max length, then gather them with `all_gather_into_tensor`
+    MAX_LEN = auto()
+    # Pad tokens to the sum length, then gather them with `all_reduce`
+    SUM_LEN = auto()
+
+    def is_max_len(self):
+        return self == DPPaddingMode.MAX_LEN
+
+    def is_sum_len(self):
+        return self == DPPaddingMode.SUM_LEN
+
+    @classmethod
+    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
+        # We choose the mode that minimizes the communication cost.
+        max_len = max(global_num_tokens)
+        sum_len = sum(global_num_tokens)
+        if sum_len * 2 > max_len * get_attention_dp_size():
+            return cls.MAX_LEN
+        else:
+            return cls.SUM_LEN
+
+    @classmethod
+    def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
+        return cls.MAX_LEN
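Editor's note: a quick worked example of `get_dp_padding_mode` under assumed token counts. With `global_num_tokens = [30, 10, 10, 10]` on 4 attention-DP ranks, `sum_len * 2 = 120` is not greater than `max_len * dp_size = 120`, so SUM_LEN (all-reduce over the concatenated buffer) is chosen; with a balanced `[30, 30, 30, 30]`, `240 > 120` and MAX_LEN (pad to the max and all-gather) is cheaper:

    # Standalone reproduction of the rule above; dp_size is passed explicitly here.
    def pick_mode(global_num_tokens, dp_size):
        return "MAX_LEN" if sum(global_num_tokens) * 2 > max(global_num_tokens) * dp_size else "SUM_LEN"

    assert pick_mode([30, 10, 10, 10], 4) == "SUM_LEN"
    assert pick_mode([30, 30, 30, 30], 4) == "MAX_LEN"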
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
    if not enable_dp_attention:
        return tp_rank, tp_size, 0

@@ -162,7 +191,7 @@ def disable_dp_size():
    _ATTN_DP_SIZE = old_dp_size

-def get_dp_local_info(forward_batch: ForwardBatch):
+def get_dp_local_info(forward_batch: ForwardBatch) -> Tuple[torch.Tensor, torch.Tensor]:
    # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
    dp_rank = get_attention_dp_rank()

@@ -221,7 +250,7 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
    memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)

-def _dp_gather(
+def _dp_gather_via_all_reduce(
    global_tokens: torch.Tensor,
    local_tokens: torch.Tensor,
    forward_batch: ForwardBatch,

@@ -238,13 +267,6 @@ def _dp_gather(
        local_tokens.untyped_storage() is not global_tokens.untyped_storage()
    ), "aliasing between global_tokens and local_tokens not allowed"

-    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
-    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
-    # actual size of the accepted tokens.
-    if forward_batch.forward_mode.is_draft_extend():
-        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
-        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
-
    memcpy_triton(
        global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
    )

@@ -263,6 +285,38 @@ def _dp_gather(
        global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
+def _dp_gather_via_all_gather(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+    is_partial: bool,
+):
+    if not is_partial:
+        if get_attention_tp_rank() != 0:
+            local_tokens.fill_(0)
+
+    scattered_local_tokens = local_tokens.tensor_split(get_attention_tp_size())[
+        get_attention_tp_rank()
+    ]
+    get_attention_tp_group().reduce_scatter_tensor(scattered_local_tokens, local_tokens)
+    get_tp_group().all_gather_into_tensor(global_tokens, scattered_local_tokens)
+
+
+def _dp_gather(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+    is_partial: bool,
+):
+    if forward_batch.dp_padding_mode.is_max_len():
+        _dp_gather_via_all_gather(
+            global_tokens, local_tokens, forward_batch, is_partial
+        )
+    else:
+        _dp_gather_via_all_reduce(
+            global_tokens, local_tokens, forward_batch, is_partial
+        )
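Editor's note: in the MAX_LEN path, ranks inside one attention-TP group hold duplicated local tokens, so the gather is done as a reduce-scatter within the attention-TP group followed by an all-gather over the full TP group. A shape-only sketch under assumed sizes (attn_tp_size = 2, dp_size = 2, tp_size = 4, max_len = 8):

    # local_tokens:            (8, hidden)   same data on both ranks of an attn-TP pair
    # scattered_local_tokens:  (4, hidden)   this rank's distinct slice after reduce_scatter_tensor
    # global_tokens:           (16, hidden)  = max_len * dp_size after all_gather_into_tensor over tp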
def dp_gather_partial(
    global_tokens: torch.Tensor,
    local_tokens: torch.Tensor,

@@ -296,24 +350,18 @@ def dp_scatter(
        local_tokens.untyped_storage() is not global_tokens.untyped_storage()
    ), "aliasing between local_tokens and global_tokens not allowed"

-    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
-    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
-    # actual size of the accepted tokens.
-    if forward_batch.forward_mode.is_draft_extend():
-        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
-        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
-
    memcpy_triton(
        local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
    )

-def attn_tp_reduce_scatter(
-    output: torch.Tensor,
-    input_list: List[torch.Tensor],
-):
-    return get_attention_tp_group().reduce_scatter(output, input_list)
+def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
+    return get_attention_tp_group().reduce_scatter_tensor(output, input)
+
+
+def attn_tp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor):
+    return get_attention_tp_group().all_gather_into_tensor(output, input)


-def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
-    return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list)
+def attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor):
+    return get_attention_tp_group().all_gather(input, output_tensor_list=output_list)
@@ -27,7 +27,9 @@ from sglang.srt.distributed import (
    tensor_model_parallel_all_gather,
)
from sglang.srt.layers.dp_attention import (
+    DPPaddingMode,
    attn_tp_all_gather,
+    attn_tp_all_gather_into_tensor,
    dp_gather_replicate,
    dp_scatter,
    get_attention_dp_rank,
@@ -111,7 +113,8 @@ class LogitsMetadata:
    # Number of tokens to sample per DP rank
    global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
    global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
+    # The gather mode for DP attention
+    dp_padding_mode: Optional[DPPaddingMode] = None
    # for padding
    padded_static_len: int = -1
@@ -163,12 +166,12 @@ class LogitsMetadata:
            forward_batch_gathered_buffer=forward_batch.gathered_buffer,
            global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
            global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
+            dp_padding_mode=DPPaddingMode.SUM_LEN,
        )

-    def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
-        if self.global_num_tokens_for_logprob_cpu is None:
-            # we are capturing cuda graph
-            return
+    def compute_dp_attention_metadata(self):
+        # TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend,
+        # we may use a smaller buffer in draft extend.
        cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
        dp_rank = get_attention_dp_rank()
@@ -179,18 +182,9 @@ class LogitsMetadata:
        else:
            dp_local_start_pos = cumtokens[dp_rank - 1]
        dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
-        gathered_buffer = torch.zeros(
-            (
-                sum(self.global_num_tokens_for_logprob_cpu),
-                hidden_states.shape[1],
-            ),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )

        self.dp_local_start_pos = dp_local_start_pos
        self.dp_local_num_tokens = dp_local_num_tokens
-        self.gathered_buffer = gathered_buffer


class LogitsProcessor(nn.Module):
@@ -434,7 +428,7 @@ class LogitsProcessor(nn.Module):
        guarantee the given hidden_states follow this constraint.
        """
        if self.do_tensor_parallel_all_gather_dp_attn:
-            logits_metadata.compute_dp_attention_metadata(hidden_states)
+            logits_metadata.compute_dp_attention_metadata()
            hidden_states, local_hidden_states = (
                torch.empty_like(logits_metadata.gathered_buffer),
                hidden_states,
@@ -463,6 +457,21 @@ class LogitsProcessor(nn.Module):
        if self.do_tensor_parallel_all_gather:
            if self.use_attn_tp_group:
+                if self.config.vocab_size % self.attn_tp_size == 0:
+                    global_logits = torch.empty(
+                        (
+                            self.attn_tp_size,
+                            logits.shape[0],
+                            self.config.vocab_size // self.attn_tp_size,
+                        ),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    attn_tp_all_gather_into_tensor(global_logits, logits)
+                    global_logits = global_logits.permute(1, 0, 2).reshape(
+                        logits.shape[0], self.config.vocab_size
+                    )
+                else:
                    global_logits = torch.empty(
                        (self.config.vocab_size, logits.shape[0]),
                        device=logits.device,

@@ -470,7 +479,8 @@ class LogitsProcessor(nn.Module):
                    )
                    global_logits = global_logits.T
                    attn_tp_all_gather(
-                        list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits
+                        list(global_logits.tensor_split(self.attn_tp_size, dim=-1)),
+                        logits,
                    )
                logits = global_logits
            else:
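Editor's note: the new even-divisibility branch gathers per-rank shards of shape (tokens, vocab // attn_tp) into an (attn_tp, tokens, vocab // attn_tp) buffer; the permute-and-reshape restores the original vocabulary order because the vocab-parallel head shards the vocabulary contiguously by rank. A single-process sketch with a fake gather (values are illustrative):

    import torch

    tokens, attn_tp, vocab = 3, 4, 32
    full = torch.arange(vocab).repeat(tokens, 1)        # (tokens, vocab), each row 0..31
    shards = full.tensor_split(attn_tp, dim=-1)         # what each attn-TP rank would hold
    gathered = torch.stack(shards, dim=0)               # stand-in for attn_tp_all_gather_into_tensor
    logits = gathered.permute(1, 0, 2).reshape(tokens, vocab)
    assert torch.equal(logits, full)                    # contiguous vocab order is restored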
......
@@ -12,14 +12,16 @@
# limitations under the License.
# ==============================================================================
"""Radix attention."""
+from __future__ import annotations

from enum import Enum
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

from torch import nn

-from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization.base_config import QuantizationConfig
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class AttentionType(Enum):
......
@@ -45,7 +45,6 @@ import triton
import triton.language as tl

from sglang.global_config import global_config
-from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (

@@ -68,6 +67,7 @@ from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, support_triton

if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

@@ -1880,7 +1880,7 @@ class ModelWorkerBatch:
    sampling_info: SamplingBatchInfo

    # The input Embeds
-    input_embeds: Optional[torch.tensor] = None
+    input_embeds: Optional[torch.Tensor] = None

    # For cross-encoder model
    token_type_ids: Optional[torch.Tensor] = None

@@ -1890,7 +1890,6 @@ class ModelWorkerBatch:
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None

    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None
-    spec_num_draft_tokens: Optional[int] = None

    hicache_consumer_index: int = 0

    # Overlap event
......
@@ -29,9 +29,9 @@ from torch.profiler import ProfilerActivity, profile
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
+from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.torchao_utils import save_gemlite_cache
-from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
@@ -167,8 +167,15 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
        # is very small. We add more values here to make sure we capture the maximum bs.
        capture_bs += [model_runner.req_to_token_pool.size]

+    mul_base = 1
+
    if server_args.enable_two_batch_overlap:
-        capture_bs = [bs for bs in capture_bs if bs % 2 == 0]
+        mul_base *= 2
+
+    if require_gathered_buffer(server_args):
+        mul_base *= get_attention_tp_size()
+
+    capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]

    if server_args.cuda_graph_max_bs:
        capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
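Editor's note: the rewritten filter keeps only batch sizes divisible by every alignment requirement at once. For example, with two-batch overlap and an attention-TP size of 4 under a gathered buffer, `mul_base = 2 * 4 = 8`, so a candidate list `[1, 2, 4, 8, 16, 24]` shrinks to `[8, 16, 24]`. The values below are illustrative only:

    capture_bs = [1, 2, 4, 8, 16, 24]
    mul_base = 1
    if True:            # server_args.enable_two_batch_overlap
        mul_base *= 2
    if True:            # require_gathered_buffer(server_args)
        mul_base *= 4   # get_attention_tp_size()
    assert [bs for bs in capture_bs if bs % mul_base == 0] == [8, 16, 24]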
@@ -306,20 +313,37 @@ class CudaGraphRunner:
            self.encoder_lens = None

        if self.require_gathered_buffer:
+            if self.require_mlp_tp_gather:
+                self.global_num_tokens_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
                self.gathered_buffer = torch.zeros(
                    (
-                        self.max_num_token,
+                        self.max_num_token * self.dp_size,
                        self.model_runner.model_config.hidden_size,
                    ),
                    dtype=self.model_runner.dtype,
                )
-            if self.require_mlp_tp_gather:
-                self.global_num_tokens_gpu = torch.zeros(
-                    (self.dp_size,), dtype=torch.int32
-                )
            else:
                assert self.require_attn_tp_gather
                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (1,), dtype=torch.int32
+                )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+        else:
+            self.global_num_tokens_gpu = None
+            self.global_num_tokens_for_logprob_gpu = None
+            self.gathered_buffer = None

        self.custom_mask = torch.ones(
            (
@@ -342,9 +366,9 @@ class CudaGraphRunner:
    def can_run(self, forward_batch: ForwardBatch):
        if self.require_mlp_tp_gather:
            cuda_graph_bs = (
-                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max(forward_batch.global_num_tokens_cpu)
            )
        else:
            cuda_graph_bs = forward_batch.batch_size
@@ -480,16 +504,19 @@ class CudaGraphRunner:
        if self.require_mlp_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
-            global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens] * self.dp_size,
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
        elif self.require_attn_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
@@ -498,10 +525,15 @@ class CudaGraphRunner:
                    device=input_ids.device,
                )
            )
-            global_num_tokens = self.global_num_tokens_gpu
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
            gathered_buffer = self.gathered_buffer[:num_tokens]
        else:
-            global_num_tokens = None
            gathered_buffer = None

        spec_info = self.get_spec_info(num_tokens)
@@ -531,7 +563,9 @@ class CudaGraphRunner:
            encoder_lens=encoder_lens,
            return_logprob=False,
            positions=positions,
-            global_num_tokens_gpu=global_num_tokens,
+            global_num_tokens_gpu=self.global_num_tokens_gpu,
+            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
+            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
            gathered_buffer=gathered_buffer,
            mrope_positions=mrope_positions,
            spec_algorithm=self.model_runner.spec_algorithm,
@@ -635,12 +669,13 @@ class CudaGraphRunner:
        # Pad
        if self.require_mlp_tp_gather:
-            total_batch_size = (
-                sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
+            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
+            max_batch_size = (
+                max_num_tokens / self.num_tokens_per_bs
                if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max_num_tokens
            )
-            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+            index = bisect.bisect_left(self.capture_bs, max_batch_size)
        else:
            index = bisect.bisect_left(self.capture_bs, raw_bs)
        bs = self.capture_bs[index]
@@ -670,7 +705,8 @@ class CudaGraphRunner:
        if forward_batch.mrope_positions is not None:
            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
        if self.require_gathered_buffer:
-            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
+            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+            self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
        if enable_num_token_non_padded(self.model_runner.server_args):
            self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
        if self.enable_two_batch_overlap:
......
@@ -38,6 +38,11 @@ import torch
import triton
import triton.language as tl

+from sglang.srt.layers.dp_attention import (
+    DPPaddingMode,
+    get_attention_dp_rank,
+    get_attention_tp_size,
+)
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import (
    flatten_nested_list,

@@ -48,6 +53,7 @@ from sglang.srt.utils import (
if TYPE_CHECKING:
    from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+    from sglang.srt.layers.logits_processor import LogitsProcessorOutput
    from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs
    from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
    from sglang.srt.model_executor.model_runner import ModelRunner
@@ -242,7 +248,7 @@ class ForwardBatch:
    lora_paths: Optional[List[str]] = None

    # For input embeddings
-    input_embeds: Optional[torch.tensor] = None
+    input_embeds: Optional[torch.Tensor] = None

    # For cross-encoder model
    token_type_ids: Optional[torch.Tensor] = None

@@ -261,6 +267,8 @@ class ForwardBatch:
    # Has to be None when cuda graph is captured.
    global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
    global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
+    # The padding mode for DP attention
+    dp_padding_mode: Optional[DPPaddingMode] = None

    # for extend, local start pos and num tokens is different in logits processor
    # this will be computed in get_dp_local_info
    # this will be recomputed in LogitsMetadata.from_forward_batch

@@ -286,7 +294,7 @@ class ForwardBatch:
    # For two-batch overlap
    tbo_split_seq_index: Optional[int] = None
    tbo_parent_token_range: Optional[Tuple[int, int]] = None
-    tbo_children: Optional[List["ForwardBatch"]] = None
+    tbo_children: Optional[List[ForwardBatch]] = None

    @classmethod
    def init_new(
@@ -340,20 +348,38 @@ class ForwardBatch:
            len(batch.input_ids), dtype=torch.int32
        ).to(device, non_blocking=True)

-        # For DP attention
+        # For MLP sync
        if batch.global_num_tokens is not None:
-            spec_num_draft_tokens = (
-                batch.spec_num_draft_tokens
-                if batch.spec_num_draft_tokens is not None
-                else 1
-            )
-            global_num_tokens = [
-                x * spec_num_draft_tokens for x in batch.global_num_tokens
-            ]
-            global_num_tokens_for_logprob = [
-                x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob
-            ]
+            from sglang.srt.speculative.eagle_utils import (
+                EagleDraftInput,
+                EagleVerifyInput,
+            )
+
+            assert batch.global_num_tokens_for_logprob is not None
+            # process global_num_tokens and global_num_tokens_for_logprob
+            if batch.spec_info is not None:
+                if isinstance(batch.spec_info, EagleDraftInput):
+                    global_num_tokens = [
+                        x * batch.spec_info.num_tokens_per_batch
+                        for x in batch.global_num_tokens
+                    ]
+                    global_num_tokens_for_logprob = [
+                        x * batch.spec_info.num_tokens_for_logprob_per_batch
+                        for x in batch.global_num_tokens_for_logprob
+                    ]
+                else:
+                    assert isinstance(batch.spec_info, EagleVerifyInput)
+                    global_num_tokens = [
+                        x * batch.spec_info.draft_token_num
+                        for x in batch.global_num_tokens
+                    ]
+                    global_num_tokens_for_logprob = [
+                        x * batch.spec_info.draft_token_num
+                        for x in batch.global_num_tokens_for_logprob
+                    ]
+            else:
+                global_num_tokens = batch.global_num_tokens
+                global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob

            ret.global_num_tokens_cpu = global_num_tokens
            ret.global_num_tokens_gpu = torch.tensor(
@@ -365,15 +391,8 @@ class ForwardBatch:
                global_num_tokens_for_logprob, dtype=torch.int64
            ).to(device, non_blocking=True)

-            sum_len = sum(global_num_tokens)
-            ret.gathered_buffer = torch.zeros(
-                (sum_len, model_runner.model_config.hidden_size),
-                dtype=model_runner.dtype,
-                device=device,
-            )
-
        if ret.forward_mode.is_idle():
-            ret.positions = torch.empty((0,), device=device)
+            ret.positions = torch.empty((0,), dtype=torch.int64, device=device)
            TboForwardBatchPreparer.prepare(
                ret, is_draft_worker=model_runner.is_draft_worker
            )
@@ -573,6 +592,158 @@ class ForwardBatch:
            )
            self.prefix_chunk_kv_indices.append(chunk_kv_indices)
+    def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0):
+        if value == 0:
+            return torch.cat(
+                [tensor, tensor.new_zeros(size - tensor.shape[0], *tensor.shape[1:])],
+                dim=0,
+            )
+        else:
+            return torch.cat(
+                [
+                    tensor,
+                    tensor.new_full((size - tensor.shape[0], *tensor.shape[1:]), value),
+                ],
+                dim=0,
+            )
+    def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
+        from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
+        assert self.global_num_tokens_cpu is not None
+        assert self.global_num_tokens_for_logprob_cpu is not None
+
+        global_num_tokens = self.global_num_tokens_cpu
+        sync_group_size = len(global_num_tokens)
+        attn_tp_size = get_attention_tp_size()
+
+        for i in range(sync_group_size):
+            # Make sure the padded length is divisible by attn_tp_size because we may need
+            # a reduce-scatter across the attn_tp dim. There is no reduce-scatter in the LM
+            # logprob path, so we do not need to adjust the padded length for logprob.
+            global_num_tokens[i] = (
+                (global_num_tokens[i] - 1) // attn_tp_size + 1
+            ) * attn_tp_size
+
+        dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens)
+        self.dp_padding_mode = dp_padding_mode
+
+        if dp_padding_mode.is_max_len():
+            # When the DP gather mode is all-gather, we use all_gather_into_tensor to gather
+            # hidden states, so the transferred tokens should be padded to the same length.
+            max_num_tokens = max(global_num_tokens)
+            global_num_tokens = [max_num_tokens] * sync_group_size
+            buffer_len = max_num_tokens * sync_group_size
+        else:
+            buffer_len = sum(global_num_tokens)
+
+        self.gathered_buffer = torch.zeros(
+            (buffer_len, model_runner.model_config.hidden_size),
+            dtype=model_runner.dtype,
+            device=model_runner.device,
+        )
+
+        bs = self.batch_size
+        if len(global_num_tokens) > 1:
+            num_tokens = global_num_tokens[get_attention_dp_rank()]
+        else:
+            num_tokens = global_num_tokens[0]
+
+        # padding
+        self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
+        self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
+
+        seq_len_fill_value = (
+            model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
+        )
+        self.seq_lens = self._pad_tensor_to_size(
+            self.seq_lens, bs, value=seq_len_fill_value
+        )
+        if self.seq_lens_cpu is not None:
+            self.seq_lens_cpu = self._pad_tensor_to_size(
+                self.seq_lens_cpu, bs, value=seq_len_fill_value
+            )
+
+        self.out_cache_loc = self._pad_tensor_to_size(self.out_cache_loc, num_tokens)
+        if self.encoder_lens is not None:
+            self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs)
+        self.positions = self._pad_tensor_to_size(self.positions, num_tokens)
+        self.global_num_tokens_cpu = global_num_tokens
+        self.global_num_tokens_gpu = self.global_num_tokens_gpu.new_tensor(
+            global_num_tokens
+        )
+
+        if self.mrope_positions is not None:
+            self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
+
+        if self.extend_seq_lens is not None:
+            self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
+
+        if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput):
+            spec_info = self.spec_info
+            self.output_cache_loc_backup = self.out_cache_loc
+            self.hidden_states_backup = spec_info.hidden_states
+            if spec_info.topk_p is not None:
+                spec_info.topk_p = self._pad_tensor_to_size(spec_info.topk_p, bs)
+            if spec_info.topk_index is not None:
+                spec_info.topk_index = self._pad_tensor_to_size(
+                    spec_info.topk_index, bs
+                )
+            if spec_info.accept_length is not None:
+                spec_info.accept_length = self._pad_tensor_to_size(
+                    spec_info.accept_length, bs
+                )
+            spec_info.hidden_states = self._pad_tensor_to_size(
+                spec_info.hidden_states, num_tokens
+            )
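Editor's note: a numeric check of the sizing above, with assumed values. For `global_num_tokens = [5, 9]` and `attn_tp_size = 4`, each entry is rounded up to a multiple of 4, giving `[8, 12]`; `sum * 2 = 40 > max * group_size = 24` selects MAX_LEN, every rank pads to 12 tokens, and `buffer_len = 12 * 2 = 24`:

    attn_tp_size, global_num_tokens = 4, [5, 9]
    padded = [((n - 1) // attn_tp_size + 1) * attn_tp_size for n in global_num_tokens]  # [8, 12]
    use_max_len = sum(padded) * 2 > max(padded) * len(padded)                           # True
    buffer_len = max(padded) * len(padded) if use_max_len else sum(padded)              # 24
    assert (padded, use_max_len, buffer_len) == ([8, 12], True, 24)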
+    def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
+        bs = self.batch_size
+
+        if self.spec_info is not None:
+            if self.forward_mode.is_decode():  # draft
+                num_tokens = self.hidden_states_backup.shape[0]
+                self.positions = self.positions[:num_tokens]
+                self.seq_lens = self.seq_lens[:bs]
+                self.req_pool_indices = self.req_pool_indices[:bs]
+                if self.seq_lens_cpu is not None:
+                    self.seq_lens_cpu = self.seq_lens_cpu[:bs]
+                logits_output.next_token_logits = logits_output.next_token_logits[
+                    :num_tokens
+                ]
+                logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
+            elif self.forward_mode.is_target_verify():  # verify
+                num_tokens = bs * self.spec_info.draft_token_num
+                logits_output.next_token_logits = logits_output.next_token_logits[
+                    :num_tokens
+                ]
+                logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
+            elif self.forward_mode.is_draft_extend():  # draft extend
+                self.spec_info.accept_length = self.spec_info.accept_length[:bs]
+                logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+                logits_output.hidden_states = logits_output.hidden_states[:bs]
+            elif self.forward_mode.is_extend() or self.forward_mode.is_idle():
+                logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+                logits_output.hidden_states = logits_output.hidden_states[:bs]
+
+            if hasattr(self, "hidden_states_backup"):
+                self.spec_info.hidden_states = self.hidden_states_backup
+            if hasattr(self, "output_cache_loc_backup"):
+                self.out_cache_loc = self.output_cache_loc_backup
+
+        elif self.forward_mode.is_decode() or self.forward_mode.is_idle():
+            logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+            if logits_output.hidden_states is not None:
+                logits_output.hidden_states = logits_output.hidden_states[:bs]
+        elif self.forward_mode.is_extend():
+            num_tokens = self.seq_lens_sum
+            logits_output.next_token_logits = logits_output.next_token_logits[
+                :num_tokens
+            ]
+            if logits_output.hidden_states is not None:
+                logits_output.hidden_states = logits_output.hidden_states[:num_tokens]

        # Here we suppose the length of each chunk is equal
        # For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256
        # num_prefix_chunks = cdiv(1024, 256) = 4
......
@@ -1464,8 +1464,12 @@ class ModelRunner:
        tensor_parallel(self.model, device_mesh)

    def forward_decode(
-        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
+        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)
        # FIXME: add pp_proxy_tensors arg to all models
        kwargs = {}
@@ -1578,8 +1582,18 @@ class ModelRunner:
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
-        elif forward_batch.forward_mode.is_decode():
-            ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+            return ret, can_run_cuda_graph
+
+        # For MLP sync
+        if forward_batch.global_num_tokens_cpu is not None:
+            forward_batch.prepare_mlp_sync_batch(self)
+
+        if forward_batch.forward_mode.is_decode():
+            ret = self.forward_decode(
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
        elif forward_batch.forward_mode.is_extend():
            ret = self.forward_extend(
                forward_batch,

@@ -1597,6 +1611,9 @@ class ModelRunner:
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

+        if forward_batch.global_num_tokens_cpu is not None:
+            forward_batch.post_forward_mlp_sync_batch(ret)
+
        return ret, can_run_cuda_graph

    def _preprocess_logits(
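Editor's note: a control-flow sketch of how the non-CUDA-graph path is bracketed after this change: inputs are padded before dispatch and outputs are trimmed afterwards. `dispatch_by_forward_mode` is a placeholder for the existing decode/extend/idle branching, not a real helper in the codebase:

    def forward_with_mlp_sync(model_runner, forward_batch):
        # Sketch only; the cuda-graph replay path returns before reaching this point.
        if forward_batch.global_num_tokens_cpu is not None:      # DP attention needs MLP sync
            forward_batch.prepare_mlp_sync_batch(model_runner)   # pad inputs, pick DPPaddingMode
        ret = dispatch_by_forward_mode(model_runner, forward_batch)  # decode / extend / idle ...
        if forward_batch.global_num_tokens_cpu is not None:
            forward_batch.post_forward_mlp_sync_batch(ret)       # trim padded rows from outputs
        return ret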
......
@@ -550,9 +550,8 @@ class DeepseekV2MoE(nn.Module):
    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
-        forward_mode = forward_batch.forward_mode
        shared_output = None
-        if is_non_idle_and_non_empty(forward_mode, hidden_states):
+        if hidden_states.shape[0] > 0:
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            shared_output = self._forward_shared_experts(hidden_states)
......
@@ -43,10 +43,6 @@ from sglang.srt.layers.communicator import (
    ScatterMode,
)
from sglang.srt.layers.dp_attention import (
-    attn_tp_all_gather,
-    attn_tp_reduce_scatter,
-    dp_gather_partial,
-    dp_scatter,
    get_attention_tp_rank,
    get_attention_tp_size,
    get_local_attention_dp_size,
......
@@ -38,10 +38,6 @@ from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
-    attn_tp_all_gather,
-    attn_tp_reduce_scatter,
-    dp_gather_partial,
-    dp_scatter,
    get_attention_tp_rank,
    get_attention_tp_size,
    get_local_attention_dp_size,

@@ -193,8 +189,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
-        forward_mode = forward_batch.forward_mode
-        if is_non_idle_and_non_empty(forward_mode, hidden_states):
+        if hidden_states.shape[0] > 0:
            # router_logits: (num_tokens, n_experts)
            router_logits, _ = self.gate(hidden_states)
            topk_weights, topk_idx, _ = self.topk(
......
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
import torch

+from sglang.srt.layers.dp_attention import DPPaddingMode
from sglang.srt.model_executor.cuda_graph_runner import (
    CUDA_GRAPH_CAPTURE_FAILED_MSG,
    CudaGraphRunner,
@@ -97,13 +98,6 @@ class EAGLEDraftCudaGraphRunner:
        )

        if self.require_gathered_buffer:
-            self.gathered_buffer = torch.zeros(
-                (
-                    self.max_num_token,
-                    self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
            if self.require_mlp_tp_gather:
                self.global_num_tokens_gpu = torch.zeros(
                    (self.dp_size,), dtype=torch.int32
@@ -111,12 +105,30 @@ class EAGLEDraftCudaGraphRunner:
                self.global_num_tokens_for_logprob_gpu = torch.zeros(
                    (self.dp_size,), dtype=torch.int32
                )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token * self.dp_size,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
            else:
                assert self.require_attn_tp_gather
                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                self.global_num_tokens_for_logprob_gpu = torch.zeros(
                    (1,), dtype=torch.int32
                )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+        else:
+            self.global_num_tokens_gpu = None
+            self.global_num_tokens_for_logprob_gpu = None
+            self.gathered_buffer = None

        # Capture
        try:
@@ -130,9 +142,9 @@ class EAGLEDraftCudaGraphRunner:
    def can_run(self, forward_batch: ForwardBatch):
        if self.require_mlp_tp_gather:
            cuda_graph_bs = (
-                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max(forward_batch.global_num_tokens_cpu)
            )
        else:
            cuda_graph_bs = forward_batch.batch_size
@@ -168,26 +180,20 @@ class EAGLEDraftCudaGraphRunner:
        if self.require_mlp_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                    dtype=torch.int32,
                    device=self.input_ids.device,
                )
            )
            self.global_num_tokens_for_logprob_gpu.copy_(
                torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                    dtype=torch.int32,
                    device=self.input_ids.device,
                )
            )
            global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
        elif self.require_attn_tp_gather:
            self.global_num_tokens_gpu.copy_(
@@ -233,6 +239,7 @@ class EAGLEDraftCudaGraphRunner:
            return_logprob=False,
            positions=positions,
            global_num_tokens_gpu=global_num_tokens,
+            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
            gathered_buffer=gathered_buffer,
            spec_algorithm=self.model_runner.spec_algorithm,
            spec_info=spec_info,
@@ -290,12 +297,13 @@ class EAGLEDraftCudaGraphRunner:
        # Pad
        if self.require_mlp_tp_gather:
-            total_batch_size = (
-                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
+            max_batch_size = (
+                max_num_tokens // self.num_tokens_per_bs
                if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max_num_tokens
            )
-            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+            index = bisect.bisect_left(self.capture_bs, max_batch_size)
        else:
            index = bisect.bisect_left(self.capture_bs, raw_bs)
        bs = self.capture_bs[index]
@@ -316,12 +324,10 @@ class EAGLEDraftCudaGraphRunner:
        self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
        self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)

+        # TODO(ch-wan): support num_token_non_padded
        if self.require_gathered_buffer:
-            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
-            self.global_num_tokens_for_logprob_gpu.copy_(
-                forward_batch.global_num_tokens_for_logprob_gpu
-            )
-            forward_batch.gathered_buffer = self.gathered_buffer
+            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+            self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)

        # Attention backend
        if bs != raw_bs:
......
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
import torch

+from sglang.srt.layers.dp_attention import DPPaddingMode
from sglang.srt.model_executor.cuda_graph_runner import (
    CUDA_GRAPH_CAPTURE_FAILED_MSG,
    CudaGraphRunner,
@@ -109,13 +110,6 @@ class EAGLEDraftExtendCudaGraphRunner:
        )

        if self.require_gathered_buffer:
-            self.gathered_buffer = torch.zeros(
-                (
-                    self.max_num_token,
-                    self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
            if self.require_mlp_tp_gather:
                self.global_num_tokens_gpu = torch.zeros(
                    (self.dp_size,), dtype=torch.int32
@@ -123,12 +117,31 @@ class EAGLEDraftExtendCudaGraphRunner:
                self.global_num_tokens_for_logprob_gpu = torch.zeros(
                    (self.dp_size,), dtype=torch.int32
                )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token * self.dp_size,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
            else:
                assert self.require_attn_tp_gather
                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                self.global_num_tokens_for_logprob_gpu = torch.zeros(
                    (1,), dtype=torch.int32
                )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+        else:
+            self.global_num_tokens_gpu = None
+            self.global_num_tokens_for_logprob_gpu = None
+            self.gathered_buffer = None

        # Capture
        try:
            with model_capture_mode():
@@ -141,9 +154,9 @@ class EAGLEDraftExtendCudaGraphRunner:
    def can_run(self, forward_batch: ForwardBatch):
        if self.require_mlp_tp_gather:
            cuda_graph_bs = (
-                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max(forward_batch.global_num_tokens_cpu)
            )
        else:
            cuda_graph_bs = forward_batch.seq_lens.numel()
@@ -180,27 +193,19 @@ class EAGLEDraftExtendCudaGraphRunner:
        if self.require_mlp_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                    dtype=torch.int32,
                    device=self.input_ids.device,
                )
            )
            self.global_num_tokens_for_logprob_gpu.copy_(
                torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [bs] * self.dp_size,
                    dtype=torch.int32,
                    device=self.input_ids.device,
                )
            )
-            global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[:num_tokens]
-            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
        elif self.require_attn_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
@@ -211,18 +216,14 @@ class EAGLEDraftExtendCudaGraphRunner:
            )
            self.global_num_tokens_for_logprob_gpu.copy_(
                torch.tensor(
-                    [num_tokens],
+                    [bs],
                    dtype=torch.int32,
                    device=self.input_ids.device,
                )
            )
-            global_num_tokens = self.global_num_tokens_gpu
            gathered_buffer = self.gathered_buffer[:num_tokens]
-            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
        else:
-            global_num_tokens = None
            gathered_buffer = None
-            global_num_tokens_for_logprob = None

        spec_info = EagleDraftInput(
            hidden_states=hidden_states,
@@ -243,8 +244,9 @@ class EAGLEDraftExtendCudaGraphRunner:
            seq_lens_sum=seq_lens.sum().item(),
            return_logprob=False,
            positions=positions,
-            global_num_tokens_gpu=global_num_tokens,
-            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
+            global_num_tokens_gpu=self.global_num_tokens_gpu,
+            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
+            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
            gathered_buffer=gathered_buffer,
            spec_algorithm=self.model_runner.spec_algorithm,
            spec_info=spec_info,
@@ -306,12 +308,13 @@ class EAGLEDraftExtendCudaGraphRunner:
        raw_bs = forward_batch.batch_size
        num_tokens = forward_batch.input_ids.shape[0]

        if self.require_mlp_tp_gather:
-            total_batch_size = (
-                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
+            max_batch_size = (
+                max_num_tokens // self.num_tokens_per_bs
                if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max_num_tokens
            )
-            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+            index = bisect.bisect_left(self.capture_bs, max_batch_size)
        else:
            index = bisect.bisect_left(self.capture_bs, raw_bs)
@@ -334,12 +337,10 @@ class EAGLEDraftExtendCudaGraphRunner:
        self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)

+        # TODO(ch-wan): support num_token_non_padded
        if self.require_gathered_buffer:
-            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
-            self.global_num_tokens_for_logprob_gpu.copy_(
-                forward_batch.global_num_tokens_for_logprob_gpu
-            )
-            forward_batch.gathered_buffer = self.gathered_buffer
+            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+            self.global_num_tokens_for_logprob_gpu.fill_(bs)

        if forward_batch.seq_lens_cpu is not None:
            if bs != raw_bs:
......
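Side note (illustrative only; CPU tensors stand in for the runner's pre-captured GPU buffers, and the sizes are made up): since replay always pretends to run the padded batch, a scalar `fill_` is enough where a per-rank `copy_` used to be needed.

```python
import torch

# Made-up sizes for illustration only.
dp_size, bs, num_tokens_per_bs = 4, 8, 2
global_num_tokens_gpu = torch.zeros(dp_size, dtype=torch.int32)
global_num_tokens_for_logprob_gpu = torch.zeros(dp_size, dtype=torch.int32)

# Replay: every rank reports the same padded size.
global_num_tokens_gpu.fill_(bs * num_tokens_per_bs)
global_num_tokens_for_logprob_gpu.fill_(bs)
print(global_num_tokens_gpu.tolist())              # [16, 16, 16, 16]
print(global_num_tokens_for_logprob_gpu.tolist())  # [8, 8, 8, 8]
```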
@@ -71,9 +71,20 @@ class EagleDraftInput:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None

+    # Shape info for padding
+    num_tokens_per_batch: int = -1
+    num_tokens_for_logprob_per_batch: int = -1
+
+    # Inputs for draft extend
+    # shape: (b,)
+    seq_lens_for_draft_extend: torch.Tensor = None
+    req_pool_indices_for_draft_extend: torch.Tensor = None

     def prepare_for_extend(self, batch: ScheduleBatch):
         if batch.forward_mode.is_idle():
             return

         # Prefill only generate 1 token.
         assert len(self.verified_id) == len(batch.seq_lens)
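To make the new fields concrete, here is a stripped-down sketch of the dataclass shape; it keeps only the fields this hunk adds (defaults follow the diff, everything else is omitted):

```python
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class DraftInputSketch:
    # Shape info for padding: per-(padded)-batch-element token counts used
    # when sizing DP gather/scatter buffers.
    num_tokens_per_batch: int = -1
    num_tokens_for_logprob_per_batch: int = -1

    # Inputs for draft extend, one entry per request (shape: (b,)).
    seq_lens_for_draft_extend: Optional[torch.Tensor] = None
    req_pool_indices_for_draft_extend: Optional[torch.Tensor] = None
```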
@@ -95,7 +106,7 @@ class EagleDraftInput:
         capture_hidden_mode: CaptureHiddenMode,
     ):
         return cls(
-            verified_id=None,
+            verified_id=torch.empty((0,), device=device, dtype=torch.int32),
             hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
             topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
             topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
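The point of switching `verified_id` from `None` to a zero-length tensor: idle batches flow through the same code path, and callers can use `.shape[0]` / `.numel()` without `None` checks. A hedged sketch (helper name and dict layout are illustrative, not the library API):

```python
import torch

def make_idle_draft_input(hidden_size: int, topk: int, device="cpu", dtype=torch.float16):
    # Zero-length tensors that still carry the right trailing dimensions.
    return {
        "verified_id": torch.empty((0,), device=device, dtype=torch.int32),
        "hidden_states": torch.empty((0, hidden_size), device=device, dtype=dtype),
        "topk_p": torch.empty((0, topk), device=device, dtype=torch.float32),
        "topk_index": torch.empty((0, topk), device=device, dtype=torch.int64),
    }

idle = make_idle_draft_input(hidden_size=16, topk=4)
print(idle["verified_id"].shape[0])  # 0, instead of raising on None
```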
@@ -109,7 +120,10 @@ class EagleDraftInput:
         batch: ScheduleBatch,
         speculative_num_steps: int,
     ):
-        batch.forward_mode = ForwardMode.DRAFT_EXTEND
+        if batch.forward_mode.is_idle():
+            return
+
         batch.input_ids = self.verified_id
         batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
         batch.extend_num_tokens = sum(batch.extend_lens)
@@ -316,7 +330,7 @@ class EagleVerifyInput:
     def verify(
         self,
         batch: ScheduleBatch,
-        logits_output: torch.Tensor,
+        logits_output: LogitsProcessorOutput,
         token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         page_size: int,
         vocab_mask: Optional[torch.Tensor] = None,  # For grammar
@@ -599,13 +613,14 @@ class EagleVerifyInput:
             batch.out_cache_loc = tgt_cache_loc
             batch.seq_lens.add_(accept_length + 1)
-            draft_input = EagleDraftInput()
-            draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
-            draft_input.verified_id = verified_id
-            draft_input.accept_length = accept_length
-            draft_input.accept_length_cpu = accept_length.tolist()
-            draft_input.seq_lens_for_draft_extend = batch.seq_lens
-            draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
+            draft_input = EagleDraftInput(
+                hidden_states=batch.spec_info.hidden_states[accept_index],
+                verified_id=verified_id,
+                accept_length=accept_length,
+                accept_length_cpu=accept_length.tolist(),
+                seq_lens_for_draft_extend=batch.seq_lens,
+                req_pool_indices_for_draft_extend=batch.req_pool_indices,
+            )
             return EagleVerifyOutput(
                 draft_input=draft_input,
@@ -628,7 +643,6 @@ class EagleVerifyInput:
             batch.seq_lens.add_(accept_length + 1)
             accept_length_cpu = accept_length.tolist()
-            draft_input = EagleDraftInput()
             if len(unfinished_accept_index) > 0:
                 unfinished_accept_index = torch.cat(unfinished_accept_index)
                 unfinished_index_device = torch.tensor(
@@ -659,18 +673,26 @@ class EagleVerifyInput:
                     next_power_of_2(self.draft_token_num),
                 )
-                draft_input.hidden_states = batch.spec_info.hidden_states[
-                    unfinished_accept_index
-                ]
-                draft_input.verified_id = predict[unfinished_accept_index]
-                draft_input.accept_length_cpu = draft_input_accept_length_cpu
-                draft_input.accept_length = accept_length[unfinished_index_device]
-                draft_input.seq_lens_for_draft_extend = batch.seq_lens[
-                    unfinished_index_device
-                ]
-                draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
-                    unfinished_index_device
-                ]
+                draft_input = EagleDraftInput(
+                    hidden_states=batch.spec_info.hidden_states[
+                        unfinished_accept_index
+                    ],
+                    verified_id=predict[unfinished_accept_index],
+                    accept_length_cpu=draft_input_accept_length_cpu,
+                    accept_length=accept_length[unfinished_index_device],
+                    seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
+                    req_pool_indices_for_draft_extend=batch.req_pool_indices[
+                        unfinished_index_device
+                    ],
+                )
+            else:
+                draft_input = EagleDraftInput.create_idle_input(
+                    device=batch.device,
+                    hidden_size=batch.model_config.hidden_size,
+                    dtype=batch.model_config.dtype,
+                    topk=self.topk,
+                    capture_hidden_mode=CaptureHiddenMode.LAST,
+                )
             return EagleVerifyOutput(
                 draft_input=draft_input,
......
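The new `else` branch matters under DP attention: a rank whose requests all finished must still join the draft-extend pass, so it builds an idle (zero-length) draft input instead of returning nothing. A schematic of that control flow with placeholder factories (all names below are illustrative):

```python
def build_draft_input(unfinished_accept_index, make_real, make_idle):
    # Never return None: every DP rank produces some draft input so the
    # collective draft-extend forward stays in lockstep.
    if len(unfinished_accept_index) > 0:
        return make_real(unfinished_accept_index)
    return make_idle()

print(build_draft_input([3, 5], make_real=lambda idx: f"real{idx}", make_idle=lambda: "idle"))
print(build_draft_input([], make_real=lambda idx: f"real{idx}", make_idle=lambda: "idle"))
```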
@@ -297,7 +297,7 @@ class EAGLEWorker(TpModelWorker):
     def forward_batch_speculative_generation(
         self, batch: ScheduleBatch
-    ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
+    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
         """Run speculative decoding forward.

         NOTE: Many states of batch is modified as you go through. It is not guaranteed that
@@ -325,11 +325,16 @@ class EAGLEWorker(TpModelWorker):
                 self.verify(batch, spec_info)
             )
-            if self.check_forward_draft_extend_after_decode(batch):
             with self.draft_tp_context(self.draft_model_runner.tp_group):
-                self.forward_draft_extend_after_decode(
-                    batch,
-                )
+                # NOTE: We should use `check_forward_draft_extend_after_decode`
+                # when DP attention is enabled, but it is slow. Skip it for now.
+                if (
+                    self.server_args.enable_dp_attention
+                    or batch.spec_info.verified_id.shape[0] > 0
+                ):
+                    # decode is not finished
+                    self.forward_draft_extend_after_decode(batch)
             return (
                 logits_output,
                 verify_output.verified_id,
@@ -339,10 +344,7 @@
             )
     def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
-        local_need_forward = (
-            batch.spec_info.verified_id is not None
-            and batch.spec_info.verified_id.shape[0] > 0
-        )
+        local_need_forward = batch.spec_info.verified_id.shape[0] > 0
         if not self.server_args.enable_dp_attention:
             return local_need_forward
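Background on the surrounding method (a sketch under the assumption that a default process group has been initialized elsewhere; it falls back to the local flag when run single-process): with DP attention, whether any rank still needs the draft-extend forward is decided collectively, for example by OR-reducing a per-rank flag:

```python
import torch
import torch.distributed as dist

def any_rank_needs_forward(local_need_forward: bool) -> bool:
    # SUM over 0/1 flags is equivalent to a logical OR across ranks.
    if not (dist.is_available() and dist.is_initialized()):
        return local_need_forward  # single-process fallback
    flag = torch.tensor([int(local_need_forward)], dtype=torch.int64)
    dist.all_reduce(flag, op=dist.ReduceOp.SUM)
    return bool(flag.item())

if __name__ == "__main__":
    print(any_rank_needs_forward(False))  # False without a process group
```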
@@ -361,7 +363,7 @@
     def forward_target_extend(
         self, batch: ScheduleBatch
-    ) -> Tuple[LogitsProcessorOutput, List[int], int]:
+    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, Optional[torch.Tensor]]:
         """Run the target extend.

         Args:
@@ -376,7 +378,6 @@
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        model_worker_batch.spec_num_draft_tokens = 1
         logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
@@ -508,13 +509,15 @@
         self._draft_preprocess_decode(batch)

         spec_info = batch.spec_info
+        assert isinstance(spec_info, EagleDraftInput)
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        spec_info.num_tokens_per_batch = self.topk
+        spec_info.num_tokens_for_logprob_per_batch = self.topk
         batch.return_hidden_states = False

         # Get forward batch
         model_worker_batch = batch.get_model_worker_batch()
-        model_worker_batch.spec_num_draft_tokens = self.topk
         assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
@@ -527,6 +530,7 @@
                 forward_batch
             )
         else:
+            forward_batch.can_run_dp_cuda_graph = False
             if not forward_batch.forward_mode.is_idle():
                 # Initialize attention backend
                 self.draft_attn_backend.init_forward_metadata(forward_batch)
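For orientation (simplified stand-ins, not the worker's real objects): off the CUDA-graph path the batch may be pure padding, so DP graph replay is disabled and attention metadata is only initialized when there are real tokens:

```python
from types import SimpleNamespace

class DummyAttnBackend:
    def init_forward_metadata(self, fb):
        print("init metadata for mode:", fb.forward_mode)

def prepare_non_graph_forward(forward_batch, attn_backend):
    # Illustrative control flow only.
    forward_batch.can_run_dp_cuda_graph = False
    if forward_batch.forward_mode != "idle":
        attn_backend.init_forward_metadata(forward_batch)

prepare_non_graph_forward(SimpleNamespace(forward_mode="decode"), DummyAttnBackend())  # prints
prepare_non_graph_forward(SimpleNamespace(forward_mode="idle"), DummyAttnBackend())    # skipped
```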
@@ -578,6 +582,7 @@
     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
         spec_info = forward_batch.spec_info
+        assert isinstance(spec_info, EagleDraftInput)
         out_cache_loc = forward_batch.out_cache_loc
         topk_p, topk_index, hidden_states = (
             spec_info.topk_p,
@@ -621,8 +626,8 @@
         spec_info.hidden_states = hidden_states

         # Run forward
-        logits_output = self.draft_model_runner.model.forward(
-            forward_batch.input_ids, forward_batch.positions, forward_batch
+        logits_output, _ = self.draft_model_runner.forward(
+            forward_batch, skip_attn_backend_init=True
         )
         self._detect_nan_if_needed(logits_output)
         probs = torch.softmax(logits_output.next_token_logits, dim=-1)
@@ -642,10 +647,10 @@
             else ForwardMode.IDLE
         )
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
         )
-        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
         assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode

         if batch.has_grammar:
@@ -782,8 +787,8 @@
         self,
         batch: ScheduleBatch,
         hidden_states: torch.Tensor,
-        next_token_ids: List[int],
-        seq_lens_cpu: torch.Tensor,
+        next_token_ids: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Run draft model extend. This API modifies the states of the batch.
@@ -795,6 +800,8 @@
         batch.spec_info = EagleDraftInput(
             hidden_states=hidden_states,
             verified_id=next_token_ids,
+            num_tokens_per_batch=1,
+            num_tokens_for_logprob_per_batch=1,
         )
         batch.return_hidden_states = False
         batch.spec_info.prepare_for_extend(batch)
@@ -802,7 +809,6 @@
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=seq_lens_cpu
         )
-        model_worker_batch.spec_num_draft_tokens = 1
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -814,20 +820,16 @@
         self.capture_for_decode(logits_output, forward_batch.spec_info)
     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        assert isinstance(batch.spec_info, EagleDraftInput)
         # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob

         input_is_idle = batch.forward_mode.is_idle()
-        if not input_is_idle:
-            # Prepare metadata
-            if batch.spec_info.verified_id is not None:
-                batch.spec_info.prepare_extend_after_decode(
-                    batch,
-                    self.speculative_num_steps,
-                )
-        else:
+        if not input_is_idle and batch.spec_info.verified_id.numel() == 0:
             batch = batch.copy()
             batch.prepare_for_idle()
             hidden_size = (
@@ -842,9 +844,21 @@
                 topk=self.topk,
                 capture_hidden_mode=CaptureHiddenMode.LAST,
             )
+        batch.spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
+        batch.spec_info.num_tokens_for_logprob_per_batch = 1
+        batch.spec_info.prepare_extend_after_decode(
+            batch,
+            self.speculative_num_steps,
+        )
+        batch.forward_mode = (
+            ForwardMode.DRAFT_EXTEND
+            if not batch.forward_mode.is_idle()
+            else ForwardMode.IDLE
+        )
         batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
-        model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
         assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
@@ -869,12 +883,13 @@
             )
             forward_batch.spec_info.hidden_states = logits_output.hidden_states
         else:
+            forward_batch.can_run_dp_cuda_graph = False
             if not forward_batch.forward_mode.is_idle():
                 self.draft_model_runner.attn_backend.init_forward_metadata(
                     forward_batch
                 )
-            logits_output = self.draft_model_runner.model.forward(
-                forward_batch.input_ids, forward_batch.positions, forward_batch
+            logits_output, _ = self.draft_model_runner.forward(
+                forward_batch, skip_attn_backend_init=True
             )
             self.capture_for_decode(logits_output, forward_batch.spec_info)
......
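One pattern worth calling out (illustrative, not the exact worker code): forward_draft_extend_after_decode mutates the batch in place, so the fields it touches are backed up first and restored after the draft forward. A minimal sketch of that backup/restore pattern:

```python
import torch
from types import SimpleNamespace

def run_with_batch_backup(batch, fn):
    # Back up fields that fn mutates in place; restore them afterwards.
    seq_lens_backup = batch.seq_lens.clone()
    req_pool_indices_backup = batch.req_pool_indices
    return_logprob_backup = batch.return_logprob
    try:
        return fn(batch)
    finally:
        batch.seq_lens = seq_lens_backup
        batch.req_pool_indices = req_pool_indices_backup
        batch.return_logprob = return_logprob_backup

batch = SimpleNamespace(
    seq_lens=torch.tensor([3, 5]),
    req_pool_indices=torch.tensor([0, 1]),
    return_logprob=False,
)
run_with_batch_backup(batch, lambda b: b.seq_lens.add_(1))
print(batch.seq_lens.tolist())  # [3, 5]: restored after the in-place add
```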
@@ -545,6 +545,7 @@ class TboForwardBatchPreparer:
             tbo_children=None,
             global_num_tokens_gpu=None,
             global_num_tokens_cpu=None,
+            dp_padding_mode=None,
             gathered_buffer=gathered_buffer,
             global_num_tokens_for_logprob_gpu=None,
             global_num_tokens_for_logprob_cpu=None,
......
@@ -35,7 +35,7 @@ class TestPureDP(CustomTestCase):
                 "--cuda-graph-max-bs",
                 "128",
                 "--max-running-requests",
-                "128",
+                "512",
                 "--mem-fraction-static",
                 "0.5",
             ],
@@ -81,7 +81,7 @@ class TestHybridDPTP(CustomTestCase):
                 "--cuda-graph-max-bs",
                 "128",
                 "--max-running-requests",
-                "128",
+                "256",
             ],
         )
@@ -170,7 +170,7 @@ class TestNoGatherdBuffer(CustomTestCase):
                 "--cuda-graph-max-bs",
                 "32",
                 "--max-running-requests",
-                "128",
+                "512",
             ],
         )
@@ -217,7 +217,7 @@ class TestTBO(CustomTestCase):
                 "--cuda-graph-max-bs",
                 "128",
                 "--max-running-requests",
-                "128",
+                "512",
             ],
         )
@@ -273,7 +273,7 @@ class TestMTP(CustomTestCase):
                 "--cuda-graph-max-bs",
                 "32",
                 "--max-running-requests",
-                "32",
+                "64",
             ],
         )
@@ -343,7 +343,7 @@ class TestMTPWithTBO(CustomTestCase):
                 "--cuda-graph-max-bs",
                 "32",
                 "--max-running-requests",
-                "32",
+                "128",
             ],
         )
......