"docs/source/vscode:/vscode.git/clone" did not exist on "ca265374bb0c8b01babf4d59c00ba5ad20adfb5a"
Unverified Commit c0fb25e9 authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

DP Enhancement (#8280)

parent 28d4d472
......@@ -545,6 +545,15 @@ class GroupCoordinator:
else:
torch.distributed.all_reduce(input_, group=self.device_group)
def reduce_scatter_tensor(
self,
output: torch.Tensor,
input: torch.Tensor,
) -> None:
# TODO(ch-wan): support other backends
torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group)
return output
def reduce_scatter(
self,
output: torch.Tensor,
......
......@@ -65,7 +65,9 @@ class AttentionBackend(ABC):
**kwargs,
):
"""Run forward on an attention layer."""
if forward_batch.forward_mode.is_decode():
if forward_batch.forward_mode.is_idle():
return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
elif forward_batch.forward_mode.is_decode():
return self.forward_decode(
q,
k,
......
......@@ -24,8 +24,8 @@ from sglang.srt.distributed import (
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather,
attn_tp_reduce_scatter,
attn_tp_all_gather_into_tensor,
attn_tp_reduce_scatter_tensor,
dp_gather_partial,
dp_scatter,
get_attention_dp_size,
......@@ -309,8 +309,8 @@ class CommunicateSimpleFn:
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
attn_tp_all_gather(
list(hidden_states.tensor_split(context.attn_tp_size)),
attn_tp_all_gather_into_tensor(
hidden_states,
local_hidden_states,
)
return hidden_states
......@@ -400,9 +400,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
].clone(),
residual,
)
attn_tp_all_gather(
list(residual.tensor_split(context.attn_tp_size)), local_residual
)
attn_tp_all_gather_into_tensor(residual, local_residual)
if context.attn_dp_size != 1:
if context.attn_tp_rank == 0:
hidden_states += residual
......@@ -442,9 +440,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
*,
residual_input_mode,
):
tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
hidden_states = tensor_list[context.attn_tp_rank]
attn_tp_reduce_scatter(hidden_states, tensor_list)
input_hidden_states = hidden_states
hidden_states = hidden_states.tensor_split(context.attn_tp_size)[
context.attn_tp_rank
]
attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states)
if residual_input_mode == ScatterMode.TP_ATTN_FULL:
residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]
if hidden_states.shape[0] != 0:
......@@ -547,8 +547,8 @@ class CommunicateSummableTensorPairFn:
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
attn_tp_all_gather(
list(hidden_states.tensor_split(context.attn_tp_size)),
attn_tp_all_gather_into_tensor(
hidden_states,
local_hidden_states,
)
return hidden_states, residual
......
......@@ -3,7 +3,8 @@ from __future__ import annotations
import functools
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, List
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Tuple
import torch
import triton
......@@ -30,6 +31,34 @@ _LOCAL_ATTN_DP_SIZE = None
_LOCAL_ATTN_DP_RANK = None
class DPPaddingMode(IntEnum):
# Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
MAX_LEN = auto()
# Padding tokens to sum length and then gather tokens using `all_reduce`
SUM_LEN = auto()
def is_max_len(self):
return self == DPPaddingMode.MAX_LEN
def is_sum_len(self):
return self == DPPaddingMode.SUM_LEN
@classmethod
def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
# we choose the mode that minimizes the communication cost
max_len = max(global_num_tokens)
sum_len = sum(global_num_tokens)
if sum_len * 2 > max_len * get_attention_dp_size():
return cls.MAX_LEN
else:
return cls.SUM_LEN
@classmethod
def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
return cls.MAX_LEN
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
if not enable_dp_attention:
return tp_rank, tp_size, 0
......@@ -162,7 +191,7 @@ def disable_dp_size():
_ATTN_DP_SIZE = old_dp_size
def get_dp_local_info(forward_batch: ForwardBatch):
def get_dp_local_info(forward_batch: ForwardBatch) -> Tuple[torch.Tensor, torch.Tensor]:
# `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
dp_rank = get_attention_dp_rank()
......@@ -221,7 +250,7 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
def _dp_gather(
def _dp_gather_via_all_reduce(
global_tokens: torch.Tensor,
local_tokens: torch.Tensor,
forward_batch: ForwardBatch,
......@@ -238,13 +267,6 @@ def _dp_gather(
local_tokens.untyped_storage() is not global_tokens.untyped_storage()
), "aliasing between global_tokens and local_tokens not allowed"
# NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
# But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
# actual size of the accepted tokens.
if forward_batch.forward_mode.is_draft_extend():
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
memcpy_triton(
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
)
......@@ -263,6 +285,38 @@ def _dp_gather(
global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
def _dp_gather_via_all_gather(
global_tokens: torch.Tensor,
local_tokens: torch.Tensor,
forward_batch: ForwardBatch,
is_partial: bool,
):
if not is_partial:
if get_attention_tp_rank() != 0:
local_tokens.fill_(0)
scattered_local_tokens = local_tokens.tensor_split(get_attention_tp_size())[
get_attention_tp_rank()
]
get_attention_tp_group().reduce_scatter_tensor(scattered_local_tokens, local_tokens)
get_tp_group().all_gather_into_tensor(global_tokens, scattered_local_tokens)
def _dp_gather(
global_tokens: torch.Tensor,
local_tokens: torch.Tensor,
forward_batch: ForwardBatch,
is_partial: bool,
):
if forward_batch.dp_padding_mode.is_max_len():
_dp_gather_via_all_gather(
global_tokens, local_tokens, forward_batch, is_partial
)
else:
_dp_gather_via_all_reduce(
global_tokens, local_tokens, forward_batch, is_partial
)
def dp_gather_partial(
global_tokens: torch.Tensor,
local_tokens: torch.Tensor,
......@@ -296,24 +350,18 @@ def dp_scatter(
local_tokens.untyped_storage() is not global_tokens.untyped_storage()
), "aliasing between local_tokens and global_tokens not allowed"
# NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
# But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
# actual size of the accepted tokens.
if forward_batch.forward_mode.is_draft_extend():
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
memcpy_triton(
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
)
def attn_tp_reduce_scatter(
output: torch.Tensor,
input_list: List[torch.Tensor],
):
return get_attention_tp_group().reduce_scatter(output, input_list)
def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
return get_attention_tp_group().reduce_scatter_tensor(output, input)
def attn_tp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor):
return get_attention_tp_group().all_gather_into_tensor(output, input)
def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list)
def attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor):
return get_attention_tp_group().all_gather(input, output_tensor_list=output_list)
......@@ -27,7 +27,9 @@ from sglang.srt.distributed import (
tensor_model_parallel_all_gather,
)
from sglang.srt.layers.dp_attention import (
DPPaddingMode,
attn_tp_all_gather,
attn_tp_all_gather_into_tensor,
dp_gather_replicate,
dp_scatter,
get_attention_dp_rank,
......@@ -111,7 +113,8 @@ class LogitsMetadata:
# Number of tokens to sample per DP rank
global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
# The gather mode for DP attention
dp_padding_mode: Optional[DPPaddingMode] = None
# for padding
padded_static_len: int = -1
......@@ -163,12 +166,12 @@ class LogitsMetadata:
forward_batch_gathered_buffer=forward_batch.gathered_buffer,
global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
dp_padding_mode=DPPaddingMode.SUM_LEN,
)
def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
if self.global_num_tokens_for_logprob_cpu is None:
# we are capturing cuda graph
return
def compute_dp_attention_metadata(self):
# TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend,
# we may use a smaller buffer in draft extend.
cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
dp_rank = get_attention_dp_rank()
......@@ -179,18 +182,9 @@ class LogitsMetadata:
else:
dp_local_start_pos = cumtokens[dp_rank - 1]
dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
gathered_buffer = torch.zeros(
(
sum(self.global_num_tokens_for_logprob_cpu),
hidden_states.shape[1],
),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
self.dp_local_start_pos = dp_local_start_pos
self.dp_local_num_tokens = dp_local_num_tokens
self.gathered_buffer = gathered_buffer
class LogitsProcessor(nn.Module):
......@@ -434,7 +428,7 @@ class LogitsProcessor(nn.Module):
guarantee the given hidden_states follow this constraint.
"""
if self.do_tensor_parallel_all_gather_dp_attn:
logits_metadata.compute_dp_attention_metadata(hidden_states)
logits_metadata.compute_dp_attention_metadata()
hidden_states, local_hidden_states = (
torch.empty_like(logits_metadata.gathered_buffer),
hidden_states,
......@@ -463,15 +457,31 @@ class LogitsProcessor(nn.Module):
if self.do_tensor_parallel_all_gather:
if self.use_attn_tp_group:
global_logits = torch.empty(
(self.config.vocab_size, logits.shape[0]),
device=logits.device,
dtype=logits.dtype,
)
global_logits = global_logits.T
attn_tp_all_gather(
list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits
)
if self.config.vocab_size % self.attn_tp_size == 0:
global_logits = torch.empty(
(
self.attn_tp_size,
logits.shape[0],
self.config.vocab_size // self.attn_tp_size,
),
device=logits.device,
dtype=logits.dtype,
)
attn_tp_all_gather_into_tensor(global_logits, logits)
global_logits = global_logits.permute(1, 0, 2).reshape(
logits.shape[0], self.config.vocab_size
)
else:
global_logits = torch.empty(
(self.config.vocab_size, logits.shape[0]),
device=logits.device,
dtype=logits.dtype,
)
global_logits = global_logits.T
attn_tp_all_gather(
list(global_logits.tensor_split(self.attn_tp_size, dim=-1)),
logits,
)
logits = global_logits
else:
logits = tensor_model_parallel_all_gather(logits)
......
......@@ -12,14 +12,16 @@
# limitations under the License.
# ==============================================================================
"""Radix attention."""
from __future__ import annotations
from enum import Enum
from typing import Optional
from typing import TYPE_CHECKING, Optional
from torch import nn
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING:
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class AttentionType(Enum):
......
......@@ -45,7 +45,6 @@ import triton
import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
......@@ -68,6 +67,7 @@ from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, support_triton
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
......@@ -1880,7 +1880,7 @@ class ModelWorkerBatch:
sampling_info: SamplingBatchInfo
# The input Embeds
input_embeds: Optional[torch.tensor] = None
input_embeds: Optional[torch.Tensor] = None
# For corss-encoder model
token_type_ids: Optional[torch.Tensor] = None
......@@ -1890,7 +1890,6 @@ class ModelWorkerBatch:
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None
spec_num_draft_tokens: Optional[int] = None
hicache_consumer_index: int = 0
# Overlap event
......
......@@ -29,9 +29,9 @@ from torch.profiler import ProfilerActivity, profile
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.torchao_utils import save_gemlite_cache
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
......@@ -167,8 +167,15 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
# is very small. We add more values here to make sure we capture the maximum bs.
capture_bs += [model_runner.req_to_token_pool.size]
mul_base = 1
if server_args.enable_two_batch_overlap:
capture_bs = [bs for bs in capture_bs if bs % 2 == 0]
mul_base *= 2
if require_gathered_buffer(server_args):
mul_base *= get_attention_tp_size()
capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]
if server_args.cuda_graph_max_bs:
capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
......@@ -306,20 +313,37 @@ class CudaGraphRunner:
self.encoder_lens = None
if self.require_gathered_buffer:
self.gathered_buffer = torch.zeros(
(
self.max_num_token,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
if self.require_mlp_tp_gather:
self.global_num_tokens_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.gathered_buffer = torch.zeros(
(
self.max_num_token * self.dp_size,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
else:
assert self.require_attn_tp_gather
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(1,), dtype=torch.int32
)
self.gathered_buffer = torch.zeros(
(
self.max_num_token,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
else:
self.global_num_tokens_gpu = None
self.global_num_tokens_for_logprob_gpu = None
self.gathered_buffer = None
self.custom_mask = torch.ones(
(
......@@ -342,9 +366,9 @@ class CudaGraphRunner:
def can_run(self, forward_batch: ForwardBatch):
if self.require_mlp_tp_gather:
cuda_graph_bs = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
else max(forward_batch.global_num_tokens_cpu)
)
else:
cuda_graph_bs = forward_batch.batch_size
......@@ -480,16 +504,19 @@ class CudaGraphRunner:
if self.require_mlp_tp_gather:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
for i in range(self.dp_size)
],
[num_tokens] * self.dp_size,
dtype=torch.int32,
device=input_ids.device,
)
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
self.global_num_tokens_for_logprob_gpu.copy_(
torch.tensor(
[num_tokens] * self.dp_size,
dtype=torch.int32,
device=input_ids.device,
)
)
gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
elif self.require_attn_tp_gather:
self.global_num_tokens_gpu.copy_(
torch.tensor(
......@@ -498,10 +525,15 @@ class CudaGraphRunner:
device=input_ids.device,
)
)
global_num_tokens = self.global_num_tokens_gpu
self.global_num_tokens_for_logprob_gpu.copy_(
torch.tensor(
[num_tokens],
dtype=torch.int32,
device=input_ids.device,
)
)
gathered_buffer = self.gathered_buffer[:num_tokens]
else:
global_num_tokens = None
gathered_buffer = None
spec_info = self.get_spec_info(num_tokens)
......@@ -531,7 +563,9 @@ class CudaGraphRunner:
encoder_lens=encoder_lens,
return_logprob=False,
positions=positions,
global_num_tokens_gpu=global_num_tokens,
global_num_tokens_gpu=self.global_num_tokens_gpu,
global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
gathered_buffer=gathered_buffer,
mrope_positions=mrope_positions,
spec_algorithm=self.model_runner.spec_algorithm,
......@@ -635,12 +669,13 @@ class CudaGraphRunner:
# Pad
if self.require_mlp_tp_gather:
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
max_num_tokens = max(forward_batch.global_num_tokens_cpu)
max_batch_size = (
max_num_tokens / self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
else max_num_tokens
)
index = bisect.bisect_left(self.capture_bs, total_batch_size)
index = bisect.bisect_left(self.capture_bs, max_batch_size)
else:
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
......@@ -670,7 +705,8 @@ class CudaGraphRunner:
if forward_batch.mrope_positions is not None:
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
if self.require_gathered_buffer:
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
if enable_num_token_non_padded(self.model_runner.server_args):
self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
if self.enable_two_batch_overlap:
......
......@@ -38,6 +38,11 @@ import torch
import triton
import triton.language as tl
from sglang.srt.layers.dp_attention import (
DPPaddingMode,
get_attention_dp_rank,
get_attention_tp_size,
)
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import (
flatten_nested_list,
......@@ -48,6 +53,7 @@ from sglang.srt.utils import (
if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner
......@@ -242,7 +248,7 @@ class ForwardBatch:
lora_paths: Optional[List[str]] = None
# For input embeddings
input_embeds: Optional[torch.tensor] = None
input_embeds: Optional[torch.Tensor] = None
# For cross-encoder model
token_type_ids: Optional[torch.Tensor] = None
......@@ -261,6 +267,8 @@ class ForwardBatch:
# Has to be None when cuda graph is captured.
global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
# The padding mode for DP attention
dp_padding_mode: Optional[DPPaddingMode] = None
# for extend, local start pos and num tokens is different in logits processor
# this will be computed in get_dp_local_info
# this will be recomputed in LogitsMetadata.from_forward_batch
......@@ -286,7 +294,7 @@ class ForwardBatch:
# For two-batch overlap
tbo_split_seq_index: Optional[int] = None
tbo_parent_token_range: Optional[Tuple[int, int]] = None
tbo_children: Optional[List["ForwardBatch"]] = None
tbo_children: Optional[List[ForwardBatch]] = None
@classmethod
def init_new(
......@@ -340,20 +348,38 @@ class ForwardBatch:
len(batch.input_ids), dtype=torch.int32
).to(device, non_blocking=True)
# For DP attention
# For MLP sync
if batch.global_num_tokens is not None:
spec_num_draft_tokens = (
batch.spec_num_draft_tokens
if batch.spec_num_draft_tokens is not None
else 1
from sglang.srt.speculative.eagle_utils import (
EagleDraftInput,
EagleVerifyInput,
)
global_num_tokens = [
x * spec_num_draft_tokens for x in batch.global_num_tokens
]
global_num_tokens_for_logprob = [
x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob
]
assert batch.global_num_tokens_for_logprob is not None
# process global_num_tokens and global_num_tokens_for_logprob
if batch.spec_info is not None:
if isinstance(batch.spec_info, EagleDraftInput):
global_num_tokens = [
x * batch.spec_info.num_tokens_per_batch
for x in batch.global_num_tokens
]
global_num_tokens_for_logprob = [
x * batch.spec_info.num_tokens_for_logprob_per_batch
for x in batch.global_num_tokens_for_logprob
]
else:
assert isinstance(batch.spec_info, EagleVerifyInput)
global_num_tokens = [
x * batch.spec_info.draft_token_num
for x in batch.global_num_tokens
]
global_num_tokens_for_logprob = [
x * batch.spec_info.draft_token_num
for x in batch.global_num_tokens_for_logprob
]
else:
global_num_tokens = batch.global_num_tokens
global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob
ret.global_num_tokens_cpu = global_num_tokens
ret.global_num_tokens_gpu = torch.tensor(
......@@ -365,15 +391,8 @@ class ForwardBatch:
global_num_tokens_for_logprob, dtype=torch.int64
).to(device, non_blocking=True)
sum_len = sum(global_num_tokens)
ret.gathered_buffer = torch.zeros(
(sum_len, model_runner.model_config.hidden_size),
dtype=model_runner.dtype,
device=device,
)
if ret.forward_mode.is_idle():
ret.positions = torch.empty((0,), device=device)
ret.positions = torch.empty((0,), dtype=torch.int64, device=device)
TboForwardBatchPreparer.prepare(
ret, is_draft_worker=model_runner.is_draft_worker
)
......@@ -573,6 +592,158 @@ class ForwardBatch:
)
self.prefix_chunk_kv_indices.append(chunk_kv_indices)
def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0):
if value == 0:
return torch.cat(
[tensor, tensor.new_zeros(size - tensor.shape[0], *tensor.shape[1:])],
dim=0,
)
else:
return torch.cat(
[
tensor,
tensor.new_full((size - tensor.shape[0], *tensor.shape[1:]), value),
],
dim=0,
)
def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
from sglang.srt.speculative.eagle_utils import EagleDraftInput
assert self.global_num_tokens_cpu is not None
assert self.global_num_tokens_for_logprob_cpu is not None
global_num_tokens = self.global_num_tokens_cpu
sync_group_size = len(global_num_tokens)
attn_tp_size = get_attention_tp_size()
for i in range(sync_group_size):
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
# there is no reduce-scatter in LM logprob, so we do not need to adjust the padded length for logprob
global_num_tokens[i] = (
(global_num_tokens[i] - 1) // attn_tp_size + 1
) * attn_tp_size
dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens)
self.dp_padding_mode = dp_padding_mode
if dp_padding_mode.is_max_len():
# when DP gather mode is all gather, we will use all_gather_into_tensor to gather hidden states,
# where transferred tokens should be padded to the same length.
max_num_tokens = max(global_num_tokens)
global_num_tokens = [max_num_tokens] * sync_group_size
buffer_len = max_num_tokens * sync_group_size
else:
buffer_len = sum(global_num_tokens)
self.gathered_buffer = torch.zeros(
(buffer_len, model_runner.model_config.hidden_size),
dtype=model_runner.dtype,
device=model_runner.device,
)
bs = self.batch_size
if len(global_num_tokens) > 1:
num_tokens = global_num_tokens[get_attention_dp_rank()]
else:
num_tokens = global_num_tokens[0]
# padding
self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
seq_len_fill_value = (
model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
)
self.seq_lens = self._pad_tensor_to_size(
self.seq_lens, bs, value=seq_len_fill_value
)
if self.seq_lens_cpu is not None:
self.seq_lens_cpu = self._pad_tensor_to_size(
self.seq_lens_cpu, bs, value=seq_len_fill_value
)
self.out_cache_loc = self._pad_tensor_to_size(self.out_cache_loc, num_tokens)
if self.encoder_lens is not None:
self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs)
self.positions = self._pad_tensor_to_size(self.positions, num_tokens)
self.global_num_tokens_cpu = global_num_tokens
self.global_num_tokens_gpu = self.global_num_tokens_gpu.new_tensor(
global_num_tokens
)
if self.mrope_positions is not None:
self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
if self.extend_seq_lens is not None:
self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput):
spec_info = self.spec_info
self.output_cache_loc_backup = self.out_cache_loc
self.hidden_states_backup = spec_info.hidden_states
if spec_info.topk_p is not None:
spec_info.topk_p = self._pad_tensor_to_size(spec_info.topk_p, bs)
if spec_info.topk_index is not None:
spec_info.topk_index = self._pad_tensor_to_size(
spec_info.topk_index, bs
)
if spec_info.accept_length is not None:
spec_info.accept_length = self._pad_tensor_to_size(
spec_info.accept_length, bs
)
spec_info.hidden_states = self._pad_tensor_to_size(
spec_info.hidden_states, num_tokens
)
def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
bs = self.batch_size
if self.spec_info is not None:
if self.forward_mode.is_decode(): # draft
num_tokens = self.hidden_states_backup.shape[0]
self.positions = self.positions[:num_tokens]
self.seq_lens = self.seq_lens[:bs]
self.req_pool_indices = self.req_pool_indices[:bs]
if self.seq_lens_cpu is not None:
self.seq_lens_cpu = self.seq_lens_cpu[:bs]
logits_output.next_token_logits = logits_output.next_token_logits[
:num_tokens
]
logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
elif self.forward_mode.is_target_verify(): # verify
num_tokens = bs * self.spec_info.draft_token_num
logits_output.next_token_logits = logits_output.next_token_logits[
:num_tokens
]
logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
elif self.forward_mode.is_draft_extend(): # draft extend
self.spec_info.accept_length = self.spec_info.accept_length[:bs]
logits_output.next_token_logits = logits_output.next_token_logits[:bs]
logits_output.hidden_states = logits_output.hidden_states[:bs]
elif self.forward_mode.is_extend() or self.forward_mode.is_idle():
logits_output.next_token_logits = logits_output.next_token_logits[:bs]
logits_output.hidden_states = logits_output.hidden_states[:bs]
if hasattr(self, "hidden_states_backup"):
self.spec_info.hidden_states = self.hidden_states_backup
if hasattr(self, "output_cache_loc_backup"):
self.out_cache_loc = self.output_cache_loc_backup
elif self.forward_mode.is_decode() or self.forward_mode.is_idle():
logits_output.next_token_logits = logits_output.next_token_logits[:bs]
if logits_output.hidden_states is not None:
logits_output.hidden_states = logits_output.hidden_states[:bs]
elif self.forward_mode.is_extend():
num_tokens = self.seq_lens_sum
logits_output.next_token_logits = logits_output.next_token_logits[
:num_tokens
]
if logits_output.hidden_states is not None:
logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
# Here we suppose the length of each chunk is equal
# For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256
# num_prefix_chunks = cdiv(1024, 256) = 4
......
......@@ -1464,9 +1464,13 @@ class ModelRunner:
tensor_parallel(self.model, device_mesh)
def forward_decode(
self, forward_batch: ForwardBatch, pp_proxy_tensors=None
self,
forward_batch: ForwardBatch,
skip_attn_backend_init: bool = False,
pp_proxy_tensors=None,
) -> LogitsProcessorOutput:
self.attn_backend.init_forward_metadata(forward_batch)
if not skip_attn_backend_init:
self.attn_backend.init_forward_metadata(forward_batch)
# FIXME: add pp_proxy_tensors arg to all models
kwargs = {}
if self.support_pp:
......@@ -1578,8 +1582,18 @@ class ModelRunner:
skip_attn_backend_init=skip_attn_backend_init,
pp_proxy_tensors=pp_proxy_tensors,
)
elif forward_batch.forward_mode.is_decode():
ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
return ret, can_run_cuda_graph
# For MLP sync
if forward_batch.global_num_tokens_cpu is not None:
forward_batch.prepare_mlp_sync_batch(self)
if forward_batch.forward_mode.is_decode():
ret = self.forward_decode(
forward_batch,
skip_attn_backend_init=skip_attn_backend_init,
pp_proxy_tensors=pp_proxy_tensors,
)
elif forward_batch.forward_mode.is_extend():
ret = self.forward_extend(
forward_batch,
......@@ -1597,6 +1611,9 @@ class ModelRunner:
else:
raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
if forward_batch.global_num_tokens_cpu is not None:
forward_batch.post_forward_mlp_sync_batch(ret)
return ret, can_run_cuda_graph
def _preprocess_logits(
......
......@@ -550,9 +550,8 @@ class DeepseekV2MoE(nn.Module):
def forward_deepep(
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> torch.Tensor:
forward_mode = forward_batch.forward_mode
shared_output = None
if is_non_idle_and_non_empty(forward_mode, hidden_states):
if hidden_states.shape[0] > 0:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
shared_output = self._forward_shared_experts(hidden_states)
......
......@@ -43,10 +43,6 @@ from sglang.srt.layers.communicator import (
ScatterMode,
)
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather,
attn_tp_reduce_scatter,
dp_gather_partial,
dp_scatter,
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
......
......@@ -38,10 +38,6 @@ from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather,
attn_tp_reduce_scatter,
dp_gather_partial,
dp_scatter,
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
......@@ -193,8 +189,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
def forward_deepep(
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> torch.Tensor:
forward_mode = forward_batch.forward_mode
if is_non_idle_and_non_empty(forward_mode, hidden_states):
if hidden_states.shape[0] > 0:
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
topk_weights, topk_idx, _ = self.topk(
......
......@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
import torch
from sglang.srt.layers.dp_attention import DPPaddingMode
from sglang.srt.model_executor.cuda_graph_runner import (
CUDA_GRAPH_CAPTURE_FAILED_MSG,
CudaGraphRunner,
......@@ -97,13 +98,6 @@ class EAGLEDraftCudaGraphRunner:
)
if self.require_gathered_buffer:
self.gathered_buffer = torch.zeros(
(
self.max_num_token,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
if self.require_mlp_tp_gather:
self.global_num_tokens_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
......@@ -111,12 +105,30 @@ class EAGLEDraftCudaGraphRunner:
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.gathered_buffer = torch.zeros(
(
self.max_num_token * self.dp_size,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
else:
assert self.require_attn_tp_gather
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(1,), dtype=torch.int32
)
self.gathered_buffer = torch.zeros(
(
self.max_num_token,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
else:
self.global_num_tokens_gpu = None
self.global_num_tokens_for_logprob_gpu = None
self.gathered_buffer = None
# Capture
try:
......@@ -130,9 +142,9 @@ class EAGLEDraftCudaGraphRunner:
def can_run(self, forward_batch: ForwardBatch):
if self.require_mlp_tp_gather:
cuda_graph_bs = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
else max(forward_batch.global_num_tokens_cpu)
)
else:
cuda_graph_bs = forward_batch.batch_size
......@@ -168,26 +180,20 @@ class EAGLEDraftCudaGraphRunner:
if self.require_mlp_tp_gather:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
for i in range(self.dp_size)
],
[num_tokens] * self.dp_size,
dtype=torch.int32,
device=self.input_ids.device,
)
)
self.global_num_tokens_for_logprob_gpu.copy_(
torch.tensor(
[
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
for i in range(self.dp_size)
],
[num_tokens] * self.dp_size,
dtype=torch.int32,
device=self.input_ids.device,
)
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
elif self.require_attn_tp_gather:
self.global_num_tokens_gpu.copy_(
......@@ -233,6 +239,7 @@ class EAGLEDraftCudaGraphRunner:
return_logprob=False,
positions=positions,
global_num_tokens_gpu=global_num_tokens,
dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
gathered_buffer=gathered_buffer,
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
......@@ -290,12 +297,13 @@ class EAGLEDraftCudaGraphRunner:
# Pad
if self.require_mlp_tp_gather:
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
max_num_tokens = max(forward_batch.global_num_tokens_cpu)
max_batch_size = (
max_num_tokens // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
else max_num_tokens
)
index = bisect.bisect_left(self.capture_bs, total_batch_size)
index = bisect.bisect_left(self.capture_bs, max_batch_size)
else:
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
......@@ -316,12 +324,10 @@ class EAGLEDraftCudaGraphRunner:
self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
# TODO(ch-wan): support num_token_non_padded
if self.require_gathered_buffer:
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
self.global_num_tokens_for_logprob_gpu.copy_(
forward_batch.global_num_tokens_for_logprob_gpu
)
forward_batch.gathered_buffer = self.gathered_buffer
self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
# Attention backend
if bs != raw_bs:
......
......@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
import torch
from sglang.srt.layers.dp_attention import DPPaddingMode
from sglang.srt.model_executor.cuda_graph_runner import (
CUDA_GRAPH_CAPTURE_FAILED_MSG,
CudaGraphRunner,
......@@ -109,13 +110,6 @@ class EAGLEDraftExtendCudaGraphRunner:
)
if self.require_gathered_buffer:
self.gathered_buffer = torch.zeros(
(
self.max_num_token,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
if self.require_mlp_tp_gather:
self.global_num_tokens_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
......@@ -123,12 +117,31 @@ class EAGLEDraftExtendCudaGraphRunner:
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.gathered_buffer = torch.zeros(
(
self.max_num_token * self.dp_size,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
else:
assert self.require_attn_tp_gather
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(1,), dtype=torch.int32
)
self.gathered_buffer = torch.zeros(
(
self.max_num_token,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
else:
self.global_num_tokens_gpu = None
self.global_num_tokens_for_logprob_gpu = None
self.gathered_buffer = None
# Capture
try:
with model_capture_mode():
......@@ -141,9 +154,9 @@ class EAGLEDraftExtendCudaGraphRunner:
def can_run(self, forward_batch: ForwardBatch):
if self.require_mlp_tp_gather:
cuda_graph_bs = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
else max(forward_batch.global_num_tokens_cpu)
)
else:
cuda_graph_bs = forward_batch.seq_lens.numel()
......@@ -180,27 +193,19 @@ class EAGLEDraftExtendCudaGraphRunner:
if self.require_mlp_tp_gather:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
for i in range(self.dp_size)
],
[num_tokens] * self.dp_size,
dtype=torch.int32,
device=self.input_ids.device,
)
)
self.global_num_tokens_for_logprob_gpu.copy_(
torch.tensor(
[
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
for i in range(self.dp_size)
],
[bs] * self.dp_size,
dtype=torch.int32,
device=self.input_ids.device,
)
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
elif self.require_attn_tp_gather:
self.global_num_tokens_gpu.copy_(
torch.tensor(
......@@ -211,18 +216,14 @@ class EAGLEDraftExtendCudaGraphRunner:
)
self.global_num_tokens_for_logprob_gpu.copy_(
torch.tensor(
[num_tokens],
[bs],
dtype=torch.int32,
device=self.input_ids.device,
)
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
else:
global_num_tokens = None
gathered_buffer = None
global_num_tokens_for_logprob = None
spec_info = EagleDraftInput(
hidden_states=hidden_states,
......@@ -243,8 +244,9 @@ class EAGLEDraftExtendCudaGraphRunner:
seq_lens_sum=seq_lens.sum().item(),
return_logprob=False,
positions=positions,
global_num_tokens_gpu=global_num_tokens,
global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
global_num_tokens_gpu=self.global_num_tokens_gpu,
global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
gathered_buffer=gathered_buffer,
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
......@@ -306,12 +308,13 @@ class EAGLEDraftExtendCudaGraphRunner:
raw_bs = forward_batch.batch_size
num_tokens = forward_batch.input_ids.shape[0]
if self.require_mlp_tp_gather:
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
max_num_tokens = max(forward_batch.global_num_tokens_cpu)
max_batch_size = (
max_num_tokens // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
else max_num_tokens
)
index = bisect.bisect_left(self.capture_bs, total_batch_size)
index = bisect.bisect_left(self.capture_bs, max_batch_size)
else:
index = bisect.bisect_left(self.capture_bs, raw_bs)
......@@ -334,12 +337,10 @@ class EAGLEDraftExtendCudaGraphRunner:
self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
# TODO(ch-wan): support num_token_non_padded
if self.require_gathered_buffer:
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
self.global_num_tokens_for_logprob_gpu.copy_(
forward_batch.global_num_tokens_for_logprob_gpu
)
forward_batch.gathered_buffer = self.gathered_buffer
self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
self.global_num_tokens_for_logprob_gpu.fill_(bs)
if forward_batch.seq_lens_cpu is not None:
if bs != raw_bs:
......
......@@ -71,9 +71,20 @@ class EagleDraftInput:
kv_indptr: torch.Tensor = None
kv_indices: torch.Tensor = None
# Shape info for padding
num_tokens_per_batch: int = -1
num_tokens_for_logprob_per_batch: int = -1
# Inputs for draft extend
# shape: (b,)
seq_lens_for_draft_extend: torch.Tensor = None
req_pool_indices_for_draft_extend: torch.Tensor = None
def prepare_for_extend(self, batch: ScheduleBatch):
if batch.forward_mode.is_idle():
return
# Prefill only generate 1 token.
assert len(self.verified_id) == len(batch.seq_lens)
......@@ -95,7 +106,7 @@ class EagleDraftInput:
capture_hidden_mode: CaptureHiddenMode,
):
return cls(
verified_id=None,
verified_id=torch.empty((0,), device=device, dtype=torch.int32),
hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
......@@ -109,7 +120,10 @@ class EagleDraftInput:
batch: ScheduleBatch,
speculative_num_steps: int,
):
batch.forward_mode = ForwardMode.DRAFT_EXTEND
if batch.forward_mode.is_idle():
return
batch.input_ids = self.verified_id
batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
batch.extend_num_tokens = sum(batch.extend_lens)
......@@ -316,7 +330,7 @@ class EagleVerifyInput:
def verify(
self,
batch: ScheduleBatch,
logits_output: torch.Tensor,
logits_output: LogitsProcessorOutput,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
page_size: int,
vocab_mask: Optional[torch.Tensor] = None, # For grammar
......@@ -599,13 +613,14 @@ class EagleVerifyInput:
batch.out_cache_loc = tgt_cache_loc
batch.seq_lens.add_(accept_length + 1)
draft_input = EagleDraftInput()
draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
draft_input.verified_id = verified_id
draft_input.accept_length = accept_length
draft_input.accept_length_cpu = accept_length.tolist()
draft_input.seq_lens_for_draft_extend = batch.seq_lens
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
draft_input = EagleDraftInput(
hidden_states=batch.spec_info.hidden_states[accept_index],
verified_id=verified_id,
accept_length=accept_length,
accept_length_cpu=accept_length.tolist(),
seq_lens_for_draft_extend=batch.seq_lens,
req_pool_indices_for_draft_extend=batch.req_pool_indices,
)
return EagleVerifyOutput(
draft_input=draft_input,
......@@ -628,7 +643,6 @@ class EagleVerifyInput:
batch.seq_lens.add_(accept_length + 1)
accept_length_cpu = accept_length.tolist()
draft_input = EagleDraftInput()
if len(unfinished_accept_index) > 0:
unfinished_accept_index = torch.cat(unfinished_accept_index)
unfinished_index_device = torch.tensor(
......@@ -659,18 +673,26 @@ class EagleVerifyInput:
next_power_of_2(self.draft_token_num),
)
draft_input.hidden_states = batch.spec_info.hidden_states[
unfinished_accept_index
]
draft_input.verified_id = predict[unfinished_accept_index]
draft_input.accept_length_cpu = draft_input_accept_length_cpu
draft_input.accept_length = accept_length[unfinished_index_device]
draft_input.seq_lens_for_draft_extend = batch.seq_lens[
unfinished_index_device
]
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
unfinished_index_device
]
draft_input = EagleDraftInput(
hidden_states=batch.spec_info.hidden_states[
unfinished_accept_index
],
verified_id=predict[unfinished_accept_index],
accept_length_cpu=draft_input_accept_length_cpu,
accept_length=accept_length[unfinished_index_device],
seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
req_pool_indices_for_draft_extend=batch.req_pool_indices[
unfinished_index_device
],
)
else:
draft_input = EagleDraftInput.create_idle_input(
device=batch.device,
hidden_size=batch.model_config.hidden_size,
dtype=batch.model_config.dtype,
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
return EagleVerifyOutput(
draft_input=draft_input,
......
......@@ -297,7 +297,7 @@ class EAGLEWorker(TpModelWorker):
def forward_batch_speculative_generation(
self, batch: ScheduleBatch
) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
"""Run speculative decoding forward.
NOTE: Many states of batch is modified as you go through. It is not guaranteed that
......@@ -325,11 +325,16 @@ class EAGLEWorker(TpModelWorker):
self.verify(batch, spec_info)
)
if self.check_forward_draft_extend_after_decode(batch):
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.forward_draft_extend_after_decode(
batch,
)
with self.draft_tp_context(self.draft_model_runner.tp_group):
# NOTE: We should use `check_forward_draft_extend_after_decode`
# when DP attention is enabled, but it is slow. Skip it for now.
if (
self.server_args.enable_dp_attention
or batch.spec_info.verified_id.shape[0] > 0
):
# decode is not finished
self.forward_draft_extend_after_decode(batch)
return (
logits_output,
verify_output.verified_id,
......@@ -339,10 +344,7 @@ class EAGLEWorker(TpModelWorker):
)
def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
local_need_forward = (
batch.spec_info.verified_id is not None
and batch.spec_info.verified_id.shape[0] > 0
)
local_need_forward = batch.spec_info.verified_id.shape[0] > 0
if not self.server_args.enable_dp_attention:
return local_need_forward
......@@ -361,7 +363,7 @@ class EAGLEWorker(TpModelWorker):
def forward_target_extend(
self, batch: ScheduleBatch
) -> Tuple[LogitsProcessorOutput, List[int], int]:
) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, Optional[torch.Tensor]]:
"""Run the target extend.
Args:
......@@ -376,7 +378,6 @@ class EAGLEWorker(TpModelWorker):
# We need the full hidden states to prefill the KV cache of the draft model.
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
model_worker_batch.spec_num_draft_tokens = 1
logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
model_worker_batch
)
......@@ -508,13 +509,15 @@ class EAGLEWorker(TpModelWorker):
self._draft_preprocess_decode(batch)
spec_info = batch.spec_info
assert isinstance(spec_info, EagleDraftInput)
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
spec_info.num_tokens_per_batch = self.topk
spec_info.num_tokens_for_logprob_per_batch = self.topk
batch.return_hidden_states = False
# Get forward batch
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.spec_num_draft_tokens = self.topk
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
......@@ -527,6 +530,7 @@ class EAGLEWorker(TpModelWorker):
forward_batch
)
else:
forward_batch.can_run_dp_cuda_graph = False
if not forward_batch.forward_mode.is_idle():
# Initialize attention backend
self.draft_attn_backend.init_forward_metadata(forward_batch)
......@@ -578,6 +582,7 @@ class EAGLEWorker(TpModelWorker):
def draft_forward(self, forward_batch: ForwardBatch):
# Parse args
spec_info = forward_batch.spec_info
assert isinstance(spec_info, EagleDraftInput)
out_cache_loc = forward_batch.out_cache_loc
topk_p, topk_index, hidden_states = (
spec_info.topk_p,
......@@ -621,8 +626,8 @@ class EAGLEWorker(TpModelWorker):
spec_info.hidden_states = hidden_states
# Run forward
logits_output = self.draft_model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
logits_output, _ = self.draft_model_runner.forward(
forward_batch, skip_attn_backend_init=True
)
self._detect_nan_if_needed(logits_output)
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
......@@ -642,10 +647,10 @@ class EAGLEWorker(TpModelWorker):
else ForwardMode.IDLE
)
batch.spec_info = spec_info
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=spec_info.seq_lens_cpu
)
model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
if batch.has_grammar:
......@@ -782,8 +787,8 @@ class EAGLEWorker(TpModelWorker):
self,
batch: ScheduleBatch,
hidden_states: torch.Tensor,
next_token_ids: List[int],
seq_lens_cpu: torch.Tensor,
next_token_ids: torch.Tensor,
seq_lens_cpu: Optional[torch.Tensor],
):
"""Run draft model extend. This API modifies the states of the batch.
......@@ -795,6 +800,8 @@ class EAGLEWorker(TpModelWorker):
batch.spec_info = EagleDraftInput(
hidden_states=hidden_states,
verified_id=next_token_ids,
num_tokens_per_batch=1,
num_tokens_for_logprob_per_batch=1,
)
batch.return_hidden_states = False
batch.spec_info.prepare_for_extend(batch)
......@@ -802,7 +809,6 @@ class EAGLEWorker(TpModelWorker):
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=seq_lens_cpu
)
model_worker_batch.spec_num_draft_tokens = 1
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
......@@ -814,37 +820,45 @@ class EAGLEWorker(TpModelWorker):
self.capture_for_decode(logits_output, forward_batch.spec_info)
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
assert isinstance(batch.spec_info, EagleDraftInput)
# Backup fields that will be modified in-place
seq_lens_backup = batch.seq_lens.clone()
req_pool_indices_backup = batch.req_pool_indices
accept_length_backup = batch.spec_info.accept_length
return_logprob_backup = batch.return_logprob
input_is_idle = batch.forward_mode.is_idle()
if not input_is_idle:
# Prepare metadata
if batch.spec_info.verified_id is not None:
batch.spec_info.prepare_extend_after_decode(
batch,
self.speculative_num_steps,
)
else:
batch = batch.copy()
batch.prepare_for_idle()
hidden_size = (
self.model_config.hidden_size * 3
if self.speculative_algorithm.is_eagle3()
else self.model_config.hidden_size
)
batch.spec_info = EagleDraftInput.create_idle_input(
device=self.device,
hidden_size=hidden_size,
dtype=self.model_config.dtype,
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
if not input_is_idle and batch.spec_info.verified_id.numel() == 0:
batch = batch.copy()
batch.prepare_for_idle()
hidden_size = (
self.model_config.hidden_size * 3
if self.speculative_algorithm.is_eagle3()
else self.model_config.hidden_size
)
batch.spec_info = EagleDraftInput.create_idle_input(
device=self.device,
hidden_size=hidden_size,
dtype=self.model_config.dtype,
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
batch.spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
batch.spec_info.num_tokens_for_logprob_per_batch = 1
batch.spec_info.prepare_extend_after_decode(
batch,
self.speculative_num_steps,
)
batch.forward_mode = (
ForwardMode.DRAFT_EXTEND
if not batch.forward_mode.is_idle()
else ForwardMode.IDLE
)
batch.return_hidden_states = False
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
......@@ -869,12 +883,13 @@ class EAGLEWorker(TpModelWorker):
)
forward_batch.spec_info.hidden_states = logits_output.hidden_states
else:
forward_batch.can_run_dp_cuda_graph = False
if not forward_batch.forward_mode.is_idle():
self.draft_model_runner.attn_backend.init_forward_metadata(
forward_batch
)
logits_output = self.draft_model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
logits_output, _ = self.draft_model_runner.forward(
forward_batch, skip_attn_backend_init=True
)
self.capture_for_decode(logits_output, forward_batch.spec_info)
......
......@@ -545,6 +545,7 @@ class TboForwardBatchPreparer:
tbo_children=None,
global_num_tokens_gpu=None,
global_num_tokens_cpu=None,
dp_padding_mode=None,
gathered_buffer=gathered_buffer,
global_num_tokens_for_logprob_gpu=None,
global_num_tokens_for_logprob_cpu=None,
......
......@@ -35,7 +35,7 @@ class TestPureDP(CustomTestCase):
"--cuda-graph-max-bs",
"128",
"--max-running-requests",
"128",
"512",
"--mem-fraction-static",
"0.5",
],
......@@ -81,7 +81,7 @@ class TestHybridDPTP(CustomTestCase):
"--cuda-graph-max-bs",
"128",
"--max-running-requests",
"128",
"256",
],
)
......@@ -170,7 +170,7 @@ class TestNoGatherdBuffer(CustomTestCase):
"--cuda-graph-max-bs",
"32",
"--max-running-requests",
"128",
"512",
],
)
......@@ -217,7 +217,7 @@ class TestTBO(CustomTestCase):
"--cuda-graph-max-bs",
"128",
"--max-running-requests",
"128",
"512",
],
)
......@@ -273,7 +273,7 @@ class TestMTP(CustomTestCase):
"--cuda-graph-max-bs",
"32",
"--max-running-requests",
"32",
"64",
],
)
......@@ -343,7 +343,7 @@ class TestMTPWithTBO(CustomTestCase):
"--cuda-graph-max-bs",
"32",
"--max-running-requests",
"32",
"128",
],
)
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment