Unverified Commit 10d60cd4 authored by u4lr451's avatar u4lr451 Committed by GitHub
Browse files

feat: mtp support dp-attention (#6081)


Co-authored-by: default avataraustindeng <austindeng@tencent.com>
Co-authored-by: default avatartianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: default avatarQiaolin Yu <liin1211@outlook.com>
Co-authored-by: default avatarch-wan <cwan39@gatech.edu>
parent 8a10c4c3
......@@ -324,7 +324,10 @@ class AiterAttnBackend(AttentionBackend):
)
def init_cuda_graph_state(
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
self,
max_bs: int,
max_num_tokens: int,
kv_indices_buf: Optional[torch.Tensor] = None,
):
self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
if kv_indices_buf is None:
......@@ -338,7 +341,7 @@ class AiterAttnBackend(AttentionBackend):
if not self.skip_prefill:
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
(max_num_tokens * self.max_context_len),
dtype=torch.uint8,
device=self.device,
)
......
......@@ -19,7 +19,7 @@ class AttentionBackend(ABC):
"""Init the metadata for a forward pass."""
raise NotImplementedError()
def init_cuda_graph_state(self, max_bs: int):
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
"""Init the global shared states for cuda graph."""
raise NotImplementedError()
......
......@@ -122,6 +122,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
def init_cuda_graph_state(
self,
max_bs: int,
max_num_tokens: int,
block_kv_indices: Optional[torch.Tensor] = None,
):
if block_kv_indices is None:
......
......@@ -1120,7 +1120,7 @@ class FlashAttentionBackend(AttentionBackend):
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
def init_cuda_graph_state(self, max_bs: int):
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
"""Initialize CUDA graph state for the attention backend.
Args:
......@@ -1999,9 +1999,9 @@ class FlashAttentionMultiStepBackend:
for i in range(self.speculative_num_steps - 1):
self.attn_backends[i].init_forward_metadata(forward_batch)
def init_cuda_graph_state(self, max_bs: int):
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(max_bs)
self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
def init_forward_metadata_capture_cuda_graph(
self,
......
......@@ -262,11 +262,14 @@ class FlashInferAttnBackend(AttentionBackend):
)
def init_cuda_graph_state(
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
self,
max_bs: int,
max_num_tokens: int,
kv_indices_buf: Optional[torch.Tensor] = None,
):
if kv_indices_buf is None:
cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len,),
(max_num_tokens * self.max_context_len,),
dtype=torch.int32,
device="cuda",
)
......@@ -285,7 +288,7 @@ class FlashInferAttnBackend(AttentionBackend):
if not self.skip_prefill:
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
(max_num_tokens * self.max_context_len),
dtype=torch.uint8,
device="cuda",
)
......@@ -1096,7 +1099,7 @@ class FlashInferMultiStepDraftBackend:
self.common_template(forward_batch, kv_indices, call_fn)
def init_cuda_graph_state(self, max_bs: int):
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
self.cuda_graph_kv_indices = torch.zeros(
(self.speculative_num_steps, max_bs * self.max_context_len),
dtype=torch.int32,
......@@ -1105,7 +1108,7 @@ class FlashInferMultiStepDraftBackend:
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
......
......@@ -199,7 +199,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
)
def init_cuda_graph_state(
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
self,
max_bs: int,
max_num_tokens: int,
kv_indices_buf: Optional[torch.Tensor] = None,
):
if kv_indices_buf is None:
cuda_graph_kv_indices = torch.zeros(
......@@ -852,7 +855,7 @@ class FlashInferMLAMultiStepDraftBackend:
self.common_template(forward_batch, kv_indices, call_fn)
def init_cuda_graph_state(self, max_bs: int):
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
self.cuda_graph_kv_indices = torch.zeros(
(self.speculative_num_steps, max_bs * self.max_context_len),
dtype=torch.int32,
......@@ -861,7 +864,7 @@ class FlashInferMLAMultiStepDraftBackend:
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
......
......@@ -148,6 +148,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
def init_cuda_graph_state(
self,
max_bs: int,
max_num_tokens: int,
block_kv_indices: Optional[torch.Tensor] = None,
):
if block_kv_indices is None:
......@@ -502,9 +503,11 @@ class FlashMLAMultiStepDraftBackend:
self.common_template(forward_batch, call_fn)
def init_cuda_graph_state(self, max_bs: int):
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(max_bs, block_kv_indices=None)
self.attn_backends[i].init_cuda_graph_state(
max_bs, max_num_tokens, block_kv_indices=None
)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
......
......@@ -32,11 +32,11 @@ class TboAttnBackend(AttentionBackend):
if forward_batch_child.batch_size > 0:
child.init_forward_metadata(forward_batch=forward_batch_child)
def init_cuda_graph_state(self, max_bs: int):
self.primary.init_cuda_graph_state(max_bs=max_bs)
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
self.primary.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
for item in self.children:
# TODO for children, maybe can provide *smaller* max_bs to optimize
item.init_cuda_graph_state(max_bs=max_bs)
item.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
def init_forward_metadata_capture_cuda_graph(
self,
......
......@@ -261,6 +261,7 @@ class TritonAttnBackend(AttentionBackend):
num_kv_splits = None
attn_logits = None
attn_lse = None
elif forward_batch.forward_mode.is_draft_extend():
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
......@@ -335,24 +336,27 @@ class TritonAttnBackend(AttentionBackend):
)
def init_cuda_graph_state(
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
self,
max_bs: int,
max_num_tokens: int,
kv_indices_buf: Optional[torch.Tensor] = None,
):
self.cuda_graph_attn_logits = torch.zeros(
(max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
(max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
dtype=torch.float32,
device=self.device,
)
self.cuda_graph_attn_lse = torch.zeros(
(max_bs, self.num_head, self.max_kv_splits),
(max_num_tokens, self.num_head, self.max_kv_splits),
dtype=torch.float32,
device=self.device,
)
self.cuda_graph_num_kv_splits = torch.full(
(max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
(max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
)
if kv_indices_buf is None:
self.cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len),
(max_num_tokens * self.max_context_len),
dtype=torch.int32,
device=self.device,
)
......@@ -361,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):
if not self.skip_prefill:
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
(max_num_tokens * self.max_context_len),
dtype=torch.uint8,
device=self.device,
)
......@@ -369,7 +373,7 @@ class TritonAttnBackend(AttentionBackend):
if self.sliding_window_size is not None and self.sliding_window_size > 0:
if kv_indices_buf is None:
self.cuda_graph_window_kv_indices = torch.zeros(
(max_bs * self.sliding_window_size),
(max_num_tokens * self.sliding_window_size),
dtype=torch.int32,
device=self.device,
)
......@@ -377,7 +381,10 @@ class TritonAttnBackend(AttentionBackend):
self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)
self.cuda_graph_window_num_kv_splits = torch.full(
(max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
(max_num_tokens,),
self.max_kv_splits,
dtype=torch.int32,
device=self.device,
)
def init_forward_metadata_capture_cuda_graph(
......@@ -458,6 +465,7 @@ class TritonAttnBackend(AttentionBackend):
)
custom_mask = self.cuda_graph_custom_mask
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
mask_indptr = self.mask_indptr[: bs + 1]
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
......@@ -821,15 +829,15 @@ class TritonMultiStepDraftBackend:
self.common_template(forward_batch, kv_indices, call_fn)
def init_cuda_graph_state(self, max_bs: int):
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
self.cuda_graph_kv_indices = torch.zeros(
(self.speculative_num_steps, max_bs * self.max_context_len),
(self.speculative_num_steps, max_num_tokens * self.max_context_len),
dtype=torch.int32,
device=self.device,
)
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
......
......@@ -238,6 +238,10 @@ def _dp_gather(
assert (
local_tokens.untyped_storage() is not global_tokens.untyped_storage()
), "aliasing between global_tokens and local_tokens not allowed"
if forward_batch.forward_mode.is_draft_extend():
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
memcpy_triton(
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
)
......@@ -288,6 +292,10 @@ def dp_scatter(
assert (
local_tokens.untyped_storage() is not global_tokens.untyped_storage()
), "aliasing between local_tokens and global_tokens not allowed"
if forward_batch.forward_mode.is_draft_extend():
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
memcpy_triton(
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
)
......
......@@ -862,6 +862,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
global_num_tokens: Optional[List[int]] = None
global_num_tokens_for_logprob: Optional[List[int]] = None
can_run_dp_cuda_graph: bool = False
is_extend_in_batch: bool = False
tbo_split_seq_index: Optional[int] = None
global_forward_mode: Optional[ForwardMode] = None
......@@ -1760,11 +1761,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
decoding_reqs=self.decoding_reqs,
spec_algorithm=self.spec_algorithm,
enable_custom_logit_processor=self.enable_custom_logit_processor,
global_num_tokens=self.global_num_tokens,
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
is_extend_in_batch=self.is_extend_in_batch,
)
def __str__(self):
return (
f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
f"#req={(len(self.reqs))})"
)
......@@ -1833,6 +1838,7 @@ class ModelWorkerBatch:
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None
spec_num_draft_tokens: Optional[int] = None
# Overlap event
launch_done: Optional[threading.Event] = None
......
......@@ -1350,6 +1350,29 @@ class Scheduler(
self.metrics_collector.log_stats(self.stats)
self._publish_kv_events()
def coordinate_spec_dp_attn_batch(self, new_batch: Optional[ScheduleBatch]):
"""Coordinate the DP attention batch."""
local_info = torch.tensor(
[
(new_batch is not None),
],
dtype=torch.int64,
)
global_info = torch.empty(
(self.server_args.dp_size, self.attn_tp_size, 1),
dtype=torch.int64,
)
torch.distributed.all_gather_into_tensor(
global_info.flatten(),
local_info,
group=self.tp_cpu_group,
)
any_new_batch = any(
global_info[:, 0, 0].tolist()
) # Any DP worker has forward batch
return any_new_batch
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
# Merge the prefill batch into the running batch
chunked_req_to_exclude = set()
......@@ -1383,7 +1406,14 @@ class Scheduler(
self.running_batch.merge_batch(self.last_batch)
new_batch = self.get_new_batch_prefill()
if new_batch is not None:
# TODO(ch-wan): minor refactor is needed here to improve readability
any_new_batch = (
self.server_args.enable_dp_attention
and not self.spec_algorithm.is_none()
and self.coordinate_spec_dp_attn_batch(new_batch)
)
if new_batch is not None or any_new_batch:
# Run prefill first if possible
ret = new_batch
else:
......@@ -1732,8 +1762,6 @@ class Scheduler(
num_tokens_for_logprob = 0
elif local_batch.forward_mode.is_decode():
num_tokens = local_batch.batch_size()
if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
num_tokens = num_tokens * speculative_num_draft_tokens
num_tokens_for_logprob = num_tokens
else:
num_tokens = local_batch.extend_num_tokens
......@@ -1809,6 +1837,7 @@ class Scheduler(
local_batch.global_num_tokens_for_logprob = (
global_num_tokens_for_logprob
)
local_batch.is_extend_in_batch = any(is_extend_in_batch)
local_batch.tbo_split_seq_index = tbo_split_seq_index
local_batch.global_forward_mode = global_forward_mode
......@@ -1816,6 +1845,7 @@ class Scheduler(
if not disable_cuda_graph:
local_batch.can_run_dp_cuda_graph = can_cuda_graph
# TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
return local_batch, any(is_extend_in_batch)
def get_idle_batch(self):
......
......@@ -242,13 +242,13 @@ class CudaGraphRunner:
# Attention backend
self.max_bs = max(self.capture_bs)
self.max_num_token = self.max_bs * self.num_tokens_per_bs
if global_server_args_dict["attention_backend"] == "flashmla":
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
else:
self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
self.model_runner.attn_backend.init_cuda_graph_state(
self.max_bs, self.max_num_token
)
self.seq_len_fill_value = (
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
)
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
self.encoder_len_fill_value = 0
self.seq_lens_cpu = torch.full(
......@@ -323,12 +323,15 @@ class CudaGraphRunner:
def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention or self.enable_sp_layernorm:
total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
)
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
total_global_tokens in self.graphs
total_batch_size in self.graphs
if self.disable_padding
else total_global_tokens <= self.max_bs
else total_batch_size <= self.max_bs
)
else:
is_bs_supported = (
......@@ -460,7 +463,7 @@ class CudaGraphRunner:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[
num_tokens // self.dp_size + (i < bs % self.dp_size)
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
for i in range(self.dp_size)
],
dtype=torch.int32,
......@@ -605,9 +608,12 @@ class CudaGraphRunner:
# Pad
if self.enable_dp_attention or self.enable_sp_layernorm:
index = bisect.bisect_left(
self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
)
index = bisect.bisect_left(self.capture_bs, total_batch_size)
else:
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
......@@ -650,13 +656,13 @@ class CudaGraphRunner:
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
bs,
self.req_pool_indices,
self.seq_lens,
self.req_pool_indices[:bs],
self.seq_lens[:bs],
forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
self.encoder_lens,
self.encoder_lens[:bs] if self.is_encoder_decoder else None,
forward_batch.forward_mode,
forward_batch.spec_info,
seq_lens_cpu=self.seq_lens_cpu,
seq_lens_cpu=self.seq_lens_cpu[:bs],
)
# Store fields
......
......@@ -320,17 +320,30 @@ class ForwardBatch:
# For DP attention
if batch.global_num_tokens is not None:
ret.global_num_tokens_cpu = batch.global_num_tokens
spec_num_draft_tokens = (
batch.spec_num_draft_tokens
if batch.spec_num_draft_tokens is not None
else 1
)
global_num_tokens = [
x * spec_num_draft_tokens for x in batch.global_num_tokens
]
global_num_tokens_for_logprob = [
x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob
]
ret.global_num_tokens_cpu = global_num_tokens
ret.global_num_tokens_gpu = torch.tensor(
batch.global_num_tokens, dtype=torch.int64
global_num_tokens, dtype=torch.int64
).to(device, non_blocking=True)
ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
ret.global_num_tokens_for_logprob_gpu = torch.tensor(
batch.global_num_tokens_for_logprob, dtype=torch.int64
global_num_tokens_for_logprob, dtype=torch.int64
).to(device, non_blocking=True)
sum_len = sum(batch.global_num_tokens)
sum_len = sum(global_num_tokens)
ret.gathered_buffer = torch.zeros(
(sum_len, model_runner.model_config.hidden_size),
dtype=model_runner.dtype,
......
......@@ -163,6 +163,7 @@ class ModelRunner:
logger.addFilter(RankZeroFilter(tp_rank == 0))
self.tp_rank = tp_rank
self.tp_size = tp_size
self.dp_size = server_args.dp_size
self.pp_rank = pp_rank
self.pp_size = pp_size
self.dist_port = nccl_port
......@@ -196,6 +197,7 @@ class ModelRunner:
| {
# TODO it is indeed not a "server args"
"use_mla_backend": self.use_mla_backend,
"speculative_algorithm": self.spec_algorithm,
}
)
......
......@@ -22,7 +22,6 @@ from transformers import PretrainedConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
......@@ -77,6 +76,7 @@ class DeepseekModelNextN(nn.Module):
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
zero_allocator = BumpAllocator(
buffer_size=2,
dtype=torch.float32,
......@@ -90,6 +90,7 @@ class DeepseekModelNextN(nn.Module):
else:
hidden_states = input_embeds
if hidden_states.shape[0] > 0:
hidden_states = self.eh_proj(
torch.cat(
(
......@@ -127,21 +128,12 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
self.model = DeepseekModelNextN(
config, quant_config, prefix=add_prefix("model", prefix)
)
if global_server_args_dict["enable_dp_attention"]:
self.lm_head = ReplicatedLinear(
config.hidden_size,
config.vocab_size,
bias=False,
prefix=add_prefix("model.shared_head.head", prefix),
)
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
......
......@@ -1399,7 +1399,9 @@ class DeepseekV2DecoderLayer(nn.Module):
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
self.layer_id = layer_id
self.is_nextn = is_nextn
self.self_attn = DeepseekV2AttentionMLA(
config=config,
hidden_size=self.hidden_size,
......@@ -1500,6 +1502,11 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states, residual, forward_batch
)
if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
# NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
# See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
hidden_states = hidden_states.clone()
return hidden_states, residual
def op_comm_prepare_attn(
......
......@@ -38,6 +38,10 @@ class EAGLEDraftCudaGraphRunner:
self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
self.dp_size = self.model_runner.dp_size
self.tp_size = self.model_runner.tp_size
self.topk = model_runner.server_args.speculative_eagle_topk
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
......@@ -53,7 +57,9 @@ class EAGLEDraftCudaGraphRunner:
# Attention backend
self.max_bs = max(self.capture_bs)
self.max_num_token = self.max_bs * self.num_tokens_per_bs
self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token)
self.model_runner.draft_attn_backend.init_cuda_graph_state(
self.max_bs, self.max_num_token
)
self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
0
].get_cuda_graph_seq_len_fill_value()
......@@ -78,10 +84,26 @@ class EAGLEDraftCudaGraphRunner:
self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
self.hidden_states = torch.zeros(
(self.max_num_token, self.model_runner.model_config.hidden_size),
(self.max_bs, self.model_runner.model_config.hidden_size),
dtype=self.model_runner.dtype,
)
if self.enable_dp_attention or self.enable_sp_layernorm:
# TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
self.gathered_buffer = torch.zeros(
(
self.max_num_token,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
self.global_num_tokens_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
# Capture
try:
with model_capture_mode():
......@@ -92,6 +114,21 @@ class EAGLEDraftCudaGraphRunner:
)
def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention:
# TODO(ch-wan): check --moe-dense-tp-size and --enable-dp-lm-head
if not forward_batch.can_run_dp_cuda_graph:
return False
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
)
is_bs_supported = (
total_batch_size in self.graphs
if self.disable_padding
else total_batch_size <= self.max_bs
)
else:
is_bs_supported = (
forward_batch.batch_size in self.graphs
if self.disable_padding
......@@ -116,8 +153,40 @@ class EAGLEDraftCudaGraphRunner:
topk_index = self.topk_index[:num_seqs]
hidden_states = self.hidden_states[:num_seqs]
if self.enable_dp_attention or self.enable_sp_layernorm:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
for i in range(self.dp_size)
],
dtype=torch.int32,
device=self.input_ids.device,
)
)
self.global_num_tokens_for_logprob_gpu.copy_(
torch.tensor(
[
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
for i in range(self.dp_size)
],
dtype=torch.int32,
device=self.input_ids.device,
)
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
else:
global_num_tokens = None
gathered_buffer = None
global_num_tokens_for_logprob = None
spec_info = EagleDraftInput(
topk_p=topk_p, topk_index=topk_index, hidden_states=hidden_states
topk_p=topk_p,
topk_index=topk_index,
hidden_states=hidden_states,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
# Forward batch
......@@ -133,11 +202,14 @@ class EAGLEDraftCudaGraphRunner:
seq_lens_sum=seq_lens.sum().item(),
return_logprob=False,
positions=positions,
global_num_tokens_gpu=global_num_tokens,
gathered_buffer=gathered_buffer,
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
capture_hidden_mode=(
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
),
global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
)
# Attention backend
......@@ -147,6 +219,9 @@ class EAGLEDraftCudaGraphRunner:
# Run and capture
def run_once():
# Clean intermediate result cache for DP attention
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
# Backup two fields, which will be modified in-place in `draft_forward`.
output_cache_loc_backup = forward_batch.out_cache_loc
hidden_states_backup = forward_batch.spec_info.hidden_states
......@@ -184,6 +259,14 @@ class EAGLEDraftCudaGraphRunner:
raw_num_token = raw_bs * self.num_tokens_per_bs
# Pad
if self.enable_dp_attention or self.enable_sp_layernorm:
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
)
index = bisect.bisect_left(self.capture_bs, total_batch_size)
else:
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs != raw_bs:
......@@ -203,6 +286,13 @@ class EAGLEDraftCudaGraphRunner:
self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
if self.enable_dp_attention or self.enable_sp_layernorm:
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
self.global_num_tokens_for_logprob_gpu.copy_(
forward_batch.global_num_tokens_for_logprob_gpu
)
forward_batch.gathered_buffer = self.gathered_buffer
# Attention backend
if bs != raw_bs:
forward_batch.batch_size = bs
......@@ -210,7 +300,9 @@ class EAGLEDraftCudaGraphRunner:
forward_batch.req_pool_indices = self.req_pool_indices[:bs]
forward_batch.positions = self.positions[:num_tokens]
if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
# Special handle for seq_len_cpu used when flashinfer mla is used
if forward_batch.seq_lens_cpu is not None:
if bs != raw_bs:
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
......
......@@ -35,6 +35,8 @@ class EAGLEDraftExtendCudaGraphRunner:
self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
self.tp_size = self.model_runner.tp_size
self.dp_size = model_runner.server_args.dp_size
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
......@@ -51,7 +53,7 @@ class EAGLEDraftExtendCudaGraphRunner:
self.max_num_token = self.max_bs * self.num_tokens_per_bs
self.eagle_worker.draft_extend_attn_backend.init_cuda_graph_state(
self.max_num_token
self.max_bs, self.max_num_token
)
self.seq_len_fill_value = (
self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value()
......@@ -90,6 +92,21 @@ class EAGLEDraftExtendCudaGraphRunner:
(self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
)
if self.enable_dp_attention or self.enable_sp_layernorm:
self.gathered_buffer = torch.zeros(
(
self.max_num_token,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
self.global_num_tokens_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
# Capture
try:
with model_capture_mode():
......@@ -100,6 +117,21 @@ class EAGLEDraftExtendCudaGraphRunner:
)
def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention or self.enable_sp_layernorm:
if not forward_batch.can_run_dp_cuda_graph:
return False
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
)
is_bs_supported = (
total_batch_size in self.graphs
if self.disable_padding
else total_batch_size <= self.max_bs
)
return is_bs_supported
else:
batch_size = forward_batch.seq_lens.numel()
is_bs_supported = (
......@@ -128,6 +160,35 @@ class EAGLEDraftExtendCudaGraphRunner:
positions = self.positions[:num_tokens]
hidden_states = self.hidden_states[:num_tokens]
if self.enable_dp_attention or self.enable_sp_layernorm:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
for i in range(self.dp_size)
],
dtype=torch.int32,
device=self.input_ids.device,
)
)
self.global_num_tokens_for_logprob_gpu.copy_(
torch.tensor(
[
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
for i in range(self.dp_size)
],
dtype=torch.int32,
device=self.input_ids.device,
)
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
else:
global_num_tokens = None
gathered_buffer = None
global_num_tokens_for_logprob = None
spec_info = EagleDraftInput(
hidden_states=hidden_states,
accept_length=accept_length,
......@@ -147,6 +208,9 @@ class EAGLEDraftExtendCudaGraphRunner:
seq_lens_sum=seq_lens.sum().item(),
return_logprob=False,
positions=positions,
global_num_tokens_gpu=global_num_tokens,
global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
gathered_buffer=gathered_buffer,
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
capture_hidden_mode=CaptureHiddenMode.LAST,
......@@ -167,6 +231,9 @@ class EAGLEDraftExtendCudaGraphRunner:
# Run and capture
def run_once():
# Clean intermediate result cache for DP attention
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
# Backup two fields, which will be modified in-place in `draft_forward`.
output_cache_loc_backup = forward_batch.out_cache_loc
hidden_states_backup = forward_batch.spec_info.hidden_states
......@@ -203,24 +270,42 @@ class EAGLEDraftExtendCudaGraphRunner:
# in the batch, which will not be counted as num_seqs
raw_bs = forward_batch.batch_size
num_tokens = forward_batch.input_ids.shape[0]
if self.enable_dp_attention or self.enable_sp_layernorm:
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
)
index = bisect.bisect_left(self.capture_bs, total_batch_size)
else:
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs * self.num_tokens_per_bs != num_tokens:
self.seq_lens.fill_(self.seq_len_fill_value)
self.out_cache_loc.zero_()
self.accept_length.fill_(1)
self.extend_seq_lens.fill_(1)
# Common inputs
self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
if forward_batch.extend_seq_lens is not None:
self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
self.positions[:num_tokens].copy_(forward_batch.positions)
self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
if forward_batch.spec_info.accept_length is not None:
self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
if self.enable_dp_attention or self.enable_sp_layernorm:
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
self.global_num_tokens_for_logprob_gpu.copy_(
forward_batch.global_num_tokens_for_logprob_gpu
)
forward_batch.gathered_buffer = self.gathered_buffer
if forward_batch.seq_lens_cpu is not None:
if bs != raw_bs:
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
......
......@@ -25,6 +25,8 @@ from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
logger = logging.getLogger(__name__)
if is_cuda():
from sgl_kernel import (
fast_topk,
......@@ -69,6 +71,8 @@ class EagleDraftInput:
kv_indices: torch.Tensor = None
def prepare_for_extend(self, batch: ScheduleBatch):
if batch.forward_mode.is_idle():
return
# Prefill only generate 1 token.
assert len(self.verified_id) == len(batch.seq_lens)
......@@ -80,6 +84,24 @@ class EagleDraftInput:
)
pt += extend_len
@classmethod
def create_idle_input(
cls,
device: torch.device,
hidden_size: int,
topk: int,
capture_hidden_mode: CaptureHiddenMode,
):
return cls(
verified_id=None,
hidden_states=torch.empty(
(0, hidden_size), device=device, dtype=torch.float32
),
topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
capture_hidden_mode=capture_hidden_mode,
)
def prepare_extend_after_decode(
self,
batch: ScheduleBatch,
......@@ -193,7 +215,35 @@ class EagleVerifyInput:
seq_lens_cpu: torch.Tensor
grammar: BaseGrammarObject = None
@classmethod
def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
return cls(
draft_token=torch.empty((0,), dtype=torch.long, device="cuda"),
custom_mask=torch.full((0,), True, dtype=torch.bool, device="cuda"),
positions=torch.empty((0,), dtype=torch.int64, device="cuda"),
retrive_index=torch.full(
(0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
),
retrive_next_token=torch.full(
(0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
),
retrive_next_sibling=torch.full(
(0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
),
retrive_cum_len=None,
topk=topk,
draft_token_num=num_verify_tokens,
spec_steps=spec_steps,
capture_hidden_mode=CaptureHiddenMode.FULL,
seq_lens_sum=0,
seq_lens_cpu=torch.empty((0,), dtype=torch.int32),
)
def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
if batch.forward_mode.is_idle():
return
batch.input_ids = self.draft_token
if page_size == 1:
......@@ -279,6 +329,25 @@ class EagleVerifyInput:
tokens. I.e., logits_output.next_token_logits only contains
accepted token logits.
"""
if batch.forward_mode.is_idle():
return EagleVerifyOutput(
draft_input=EagleDraftInput.create_idle_input(
device=batch.device,
hidden_size=batch.model_config.hidden_size,
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
),
logits_output=logits_output,
verified_id=torch.empty(0, dtype=torch.long, device=batch.device),
accept_length_per_req_cpu=[],
accepted_indices=torch.full(
(0, self.spec_steps + 1),
-1,
dtype=torch.int32,
device=batch.device,
),
)
bs = self.retrive_index.shape[0]
candidates = self.draft_token.reshape(bs, self.draft_token_num)
sampling_info = batch.sampling_info
......@@ -992,6 +1061,7 @@ def select_top_k_tokens(
topk_index = topk_index.reshape(-1, topk**2)
input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
if hidden_states.shape[0] > 0:
selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
0, hidden_states.shape[0], step=topk, device="cuda"
).repeat_interleave(topk)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment