Unverified commit dc0705a5, authored by Lianmin Zheng, committed by GitHub

Simplify prepare_extend_after_decode (#6987)

parent a968c888
@@ -1636,7 +1636,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.spec_info:
             self.spec_info.merge_batch(other.spec_info)
 
-    def get_model_worker_batch(self) -> ModelWorkerBatch:
+    def get_model_worker_batch(
+        self, seq_lens_cpu_cache: Optional[torch.Tensor] = None
+    ) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
@@ -1646,16 +1648,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # Create seq_lens_cpu when needed
         if (
-            (
+            global_server_args_dict["attention_backend"] == "fa3"
+            or (
                 global_server_args_dict["use_mla_backend"]
                 and global_server_args_dict["attention_backend"] == "flashinfer"
             )
             or global_server_args_dict["attention_backend"] == "flashmla"
-            or global_server_args_dict["attention_backend"] == "fa3"
             or global_server_args_dict["attention_backend"] == "cutlass_mla"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
-            seq_lens_cpu = self.seq_lens.cpu()
+            seq_lens_cpu = (
+                seq_lens_cpu_cache
+                if seq_lens_cpu_cache is not None
+                else self.seq_lens.cpu()
+            )
         else:
             seq_lens_cpu = None
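The new `seq_lens_cpu_cache` argument lets a caller that already holds the sequence lengths on the host (for example, the EAGLE worker further down) skip the blocking `self.seq_lens.cpu()` copy. A minimal, self-contained sketch of the pattern, with a hypothetical helper name:

```python
from typing import Optional

import torch


def resolve_seq_lens_cpu(
    seq_lens_gpu: torch.Tensor,
    seq_lens_cpu_cache: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Hypothetical helper: prefer a cached host copy of seq_lens,
    fall back to a blocking device-to-host transfer."""
    if seq_lens_cpu_cache is not None:
        return seq_lens_cpu_cache
    return seq_lens_gpu.cpu()


device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens = torch.tensor([5, 9, 3], device=device)
cached = seq_lens.cpu()
assert torch.equal(resolve_seq_lens_cpu(seq_lens, cached), cached)
```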
@@ -1575,10 +1575,9 @@ class Scheduler(
                 num_accepted_tokens,
                 can_run_cuda_graph,
             ) = self.draft_worker.forward_batch_speculative_generation(batch)
-            self.spec_num_total_accepted_tokens += (
-                num_accepted_tokens + batch.batch_size()
-            )
-            self.spec_num_total_forward_ct += batch.batch_size()
+            bs = batch.batch_size()
+            self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
+            self.spec_num_total_forward_ct += bs
             self.num_generated_tokens += num_accepted_tokens
             if self.pp_group.is_last_rank:
@@ -56,6 +56,16 @@ def get_is_capture_mode():
     return is_capture_mode
 
 
+@contextmanager
+def model_capture_mode():
+    global is_capture_mode
+    is_capture_mode = True
+    yield
+    is_capture_mode = False
+
+
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
@@ -291,22 +301,13 @@ class CudaGraphRunner:
         # Capture
         try:
-            with self.model_capture_mode():
+            with model_capture_mode():
                 self.capture()
         except RuntimeError as e:
             raise Exception(
                 f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
             )
 
-    @contextmanager
-    def model_capture_mode(self):
-        global is_capture_mode
-        is_capture_mode = True
-        yield
-        is_capture_mode = False
-
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention or self.enable_sp_layernorm:
             total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
@@ -650,6 +651,8 @@ class CudaGraphRunner:
                 topk=self.model_runner.server_args.speculative_eagle_topk,
                 draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
                 capture_hidden_mode=CaptureHiddenMode.FULL,
+                seq_lens_sum=None,
+                seq_lens_cpu=None,
             )
         return spec_info
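Promoting `model_capture_mode` from a `CudaGraphRunner` method to a module-level context manager lets the EAGLE draft graph runners below reuse the same global flag. A small sketch of the pattern; the `try/finally` here is a defensive addition, not part of the patch:

```python
from contextlib import contextmanager

is_capture_mode = False  # module-level flag shared by every graph runner in the module


@contextmanager
def model_capture_mode():
    # Toggle the shared flag for the duration of a CUDA graph capture.
    global is_capture_mode
    is_capture_mode = True
    try:
        yield
    finally:
        is_capture_mode = False


with model_capture_mode():
    assert is_capture_mode   # capture-time code sees the flag set
assert not is_capture_mode   # flag is restored afterwards
```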
@@ -1013,13 +1013,13 @@ class ServerArgs:
             type=str,
             choices=[
                 "aiter",
-                "flashinfer",
-                "triton",
-                "torch_native",
+                "cutlass_mla",
                 "fa3",
+                "flashinfer",
                 "flashmla",
-                "cutlass_mla",
                 "intel_amx",
+                "torch_native",
+                "triton",
             ],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
@@ -10,6 +10,7 @@ from sglang.srt.model_executor.cuda_graph_runner import (
     CudaGraphRunner,
     get_batch_sizes_to_capture,
     get_global_graph_memory_pool,
+    model_capture_mode,
     set_global_graph_memory_pool,
     set_torch_compile_config,
 )
@@ -80,7 +81,8 @@ class EAGLEDraftCudaGraphRunner:
         # Capture
         try:
-            self.capture()
+            with model_capture_mode():
+                self.capture()
         except RuntimeError as e:
             raise Exception(
                 f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
@@ -11,6 +11,7 @@ from sglang.srt.model_executor.cuda_graph_runner import (
     LogitsProcessorOutput,
     get_batch_sizes_to_capture,
     get_global_graph_memory_pool,
+    model_capture_mode,
     set_global_graph_memory_pool,
     set_torch_compile_config,
 )
@@ -19,7 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.speculative.eagle_utils import EagleDraftInput
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
 
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -37,6 +38,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.tp_size = self.model_runner.tp_size
         self.dp_size = model_runner.server_args.dp_size
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
+        self.topk = model_runner.server_args.speculative_eagle_topk
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.padded_static_len = -1
@@ -87,7 +89,8 @@ class EAGLEDraftExtendCudaGraphRunner:
         # Capture
         try:
-            self.capture()
+            with model_capture_mode():
+                self.capture()
         except RuntimeError as e:
             raise Exception(
                 f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
@@ -170,6 +173,8 @@ class EAGLEDraftExtendCudaGraphRunner:
                 forward_batch.positions,
                 forward_batch,
             )
+            probs = torch.softmax(ret.next_token_logits, dim=-1)
+            ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1)
             forward_batch.out_cache_loc = output_cache_loc_backup
             forward_batch.spec_info.hidden_states = hidden_states_backup
@@ -198,7 +203,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
-        if bs != raw_bs:
+        if bs * self.num_tokens_per_bs != num_tokens:
             self.seq_lens.fill_(1)
             self.accept_length.fill_(1)
             self.out_cache_loc.zero_()
@@ -238,8 +243,11 @@ class EAGLEDraftExtendCudaGraphRunner:
         out = self.output_buffers[bs]
         if bs != raw_bs:
             forward_batch.spec_info.accept_length = self.accept_length[:raw_bs]
+            out_copy = out
             out = LogitsProcessorOutput(
                 next_token_logits=out.next_token_logits[:raw_bs],
                 hidden_states=out.hidden_states[:raw_bs],
             )
+            out.topk_p = out_copy.topk_p[:raw_bs]
+            out.topk_index = out_copy.topk_index[:raw_bs]
         return out
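With the top-k extraction folded into the captured region, a replayed graph already returns `topk_p`/`topk_index` for the next draft step. A rough standalone equivalent of that post-processing, using `torch.topk` in place of SGLang's `fast_topk`:

```python
import torch


def topk_from_logits(next_token_logits: torch.Tensor, topk: int):
    # Softmax over the vocabulary, then keep the top-k probabilities and
    # their token indices, mirroring the lines added inside the capture.
    probs = torch.softmax(next_token_logits, dim=-1)
    topk_p, topk_index = torch.topk(probs, topk, dim=-1)
    return topk_p, topk_index


logits = torch.randn(4, 32000)          # (batch, vocab) dummy logits
topk_p, topk_index = topk_from_logits(logits, topk=4)
print(topk_p.shape, topk_index.shape)   # torch.Size([4, 4]) torch.Size([4, 4])
```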
@@ -22,8 +22,7 @@ from sglang.srt.managers.schedule_batch import (
     global_server_args_dict,
 )
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
-from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
-from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
 
 if is_cuda():
@@ -86,78 +85,29 @@ class EagleDraftInput:
         self,
         batch: ScheduleBatch,
         speculative_num_steps: int,
-        context_length: int,
-        pad_input: bool = False,
     ):
-        accept_length_cpu = batch.spec_info.accept_length_cpu
-        batch.extend_lens = [x + 1 for x in accept_length_cpu]
+        batch.forward_mode = ForwardMode.DRAFT_EXTEND
+        batch.input_ids = self.verified_id
+        batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
         batch.extend_num_tokens = sum(batch.extend_lens)
         batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
         batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
-        seq_lens_cpu = batch.seq_lens.tolist()
+        batch.return_logprob = False
 
-        self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
-        new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
+        self.capture_hidden_mode = CaptureHiddenMode.LAST
         self.accept_length.add_(1)
+        self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
+        self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
 
-        create_extend_spec_info[(self.accept_length.numel(),)](
-            self.verified_id,
+        create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
+            batch.input_ids,
             batch.seq_lens,
             self.accept_length,
-            torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
             self.positions,
-            new_verified_id,
-            next_power_of_2(speculative_num_steps + 1),
+            self.verified_id,
+            next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
         )
 
-        batch.seq_lens_sum = sum(seq_lens_cpu)
-        batch.input_ids = self.verified_id
-        self.verified_id = new_verified_id
-
-        if not pad_input:
-            return
-
-        batch_size = sum(not req.finished() for req in batch.reqs)
-        # Total constant input length after padding
-        static_len = speculative_num_steps + 1
-        # Total size after padding
-        padded_input_size = batch_size * static_len
-        padded_len = padded_input_size - batch.input_ids.shape[0]
-
-        if padded_len > 0:
-            new_input_ids = torch.nn.functional.pad(
-                batch.input_ids, (0, padded_len), value=0
-            )
-            position_padding = torch.arange(padded_len, device=self.positions.device)
-            new_positions = torch.cat([self.positions, position_padding])
-
-            # need dummy hidden states for the padded positions
-            hidden_states_dim = self.hidden_states.shape[-1]
-            new_hidden_states = torch.cat(
-                [
-                    self.hidden_states,
-                    torch.zeros(
-                        (padded_len, hidden_states_dim),
-                        dtype=self.hidden_states.dtype,
-                        device=self.hidden_states.device,
-                    ),
-                ],
-                dim=0,
-            )
-
-            # allocate KV cache location for the padded tokens
-            padded_cache_loc = torch.zeros(
-                padded_len,
-                dtype=batch.out_cache_loc.dtype,
-                device=batch.out_cache_loc.device,
-            )
-            new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc])
-
-            batch.input_ids = new_input_ids
-            self.hidden_states = new_hidden_states
-            self.positions = new_positions
-            batch.out_cache_loc = new_out_cache_loc
-
     def generate_attn_arg_prefill(
         self,
         req_pool_indices: torch.Tensor,
@@ -173,8 +123,9 @@ class EagleDraftInput:
         cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
 
-        # TODO: replace cum_kv_seq_len[-1] with paged_kernel_lens_sum to avoid the device sync.
-        kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+        )
 
         create_flashinfer_kv_indices_triton[(bs,)](
             req_to_token,
@@ -238,54 +189,10 @@ class EagleVerifyInput:
     topk: int
     draft_token_num: int
     capture_hidden_mode: CaptureHiddenMode
+    seq_lens_sum: int
+    seq_lens_cpu: torch.Tensor
     grammar: BaseGrammarObject = None
 
-    @classmethod
-    def create(
-        cls,
-        verified_id: torch.Tensor,
-        score_list: List[torch.Tensor],
-        token_list: List[torch.Tensor],
-        parents_list: List[torch.Tensor],
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        topk: int,
-        spec_steps: int,
-        num_verify_tokens: int,
-    ):
-        (
-            tree_mask,
-            position,
-            retrive_index,
-            retrive_next_token,
-            retrive_next_sibling,
-            draft_tokens,
-        ) = build_tree_kernel_efficient(
-            verified_id,
-            score_list,
-            token_list,
-            parents_list,
-            seq_lens,
-            seq_lens_sum,
-            topk,
-            spec_steps,
-            num_verify_tokens,
-        )
-        return cls(
-            draft_token=draft_tokens,
-            custom_mask=tree_mask,
-            positions=position,
-            retrive_index=retrive_index,
-            retrive_next_token=retrive_next_token,
-            retrive_next_sibling=retrive_next_sibling,
-            retrive_cum_len=None,
-            spec_steps=spec_steps,
-            topk=topk,
-            draft_token_num=num_verify_tokens,
-            capture_hidden_mode=CaptureHiddenMode.FULL,
-        )
-
     def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
         batch.input_ids = self.draft_token
@@ -614,26 +521,28 @@ class EagleVerifyInput:
 
 @triton.jit
-def create_extend_spec_info(
+def create_extend_after_decode_spec_info(
     verified_id,
-    seq_len,
-    accept_len,
-    accept_len_cum,
+    seq_lens,
+    accept_lens,
     positions,
     new_verified_id,
-    accept_len_upper: tl.constexpr,
+    bs_upper: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)
-    offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1)
-    seq_length = tl.load(seq_len + pid)
-    accept_length = tl.load(accept_len + pid)
-    positions_ptr = positions + offset
-    data = tl.arange(0, accept_len_upper)
-    mask = data < accept_length
-    tl.store(positions_ptr + data, seq_length - accept_length + data, mask)
+    offsets = tl.arange(0, bs_upper)
+    seq_length = tl.load(seq_lens + pid)
+    accept_length = tl.load(accept_lens + pid)
 
-    offset = tl.load(accept_len_cum + pid) - 1
-    verified_id_data = tl.load(verified_id + offset)
+    accept_len_cumsum = tl.sum(
+        tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
+    )
+    positions_ptr = positions + accept_len_cumsum
+    mask = offsets < accept_length
+    tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)
+
+    accept_len_cumsum += accept_length - 1
+    verified_id_data = tl.load(verified_id + accept_len_cumsum)
     tl.store(new_verified_id + pid, verified_id_data)
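The rewritten kernel no longer takes a precomputed `accept_len_cum` tensor; each program instance recomputes the exclusive prefix sum of `accept_lens` with a masked load. A plain-PyTorch reference of what the kernel writes, useful for checking its semantics (illustration only, not from the repository):

```python
import torch


def extend_after_decode_reference(verified_id, seq_lens, accept_lens):
    """Per request: write the positions of the accepted tokens and pick the
    last verified token id, matching create_extend_after_decode_spec_info."""
    positions = torch.empty(int(accept_lens.sum()), dtype=torch.long)
    new_verified_id = torch.empty(len(seq_lens), dtype=torch.int32)
    offset = 0
    for i in range(len(seq_lens)):
        acc, seq = int(accept_lens[i]), int(seq_lens[i])
        positions[offset : offset + acc] = torch.arange(seq - acc, seq)
        new_verified_id[i] = verified_id[offset + acc - 1]
        offset += acc
    return positions, new_verified_id


accept_lens = torch.tensor([2, 3])      # accept_length after the in-place +1
seq_lens = torch.tensor([7, 9])
verified_id = torch.arange(5, dtype=torch.int32)
pos, last_ids = extend_after_decode_reference(verified_id, seq_lens, accept_lens)
print(pos.tolist(), last_ids.tolist())  # [5, 6, 6, 7, 8] [1, 4]
```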
@@ -654,8 +563,8 @@ def assign_req_to_token_pool(
     token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
 
     length_offset = tl.arange(0, bs_upper)
-    start = tl.load(start_offset + length_offset, mask=length_offset < pid)
-    end = tl.load(end_offset + length_offset, mask=length_offset < pid)
+    start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
+    end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
     out_offset = tl.sum(end - start, axis=0)
 
     out_cache_ptr = out_cache_loc + out_offset
@@ -736,7 +645,7 @@ def generate_draft_decode_kv_indices(
         iters += 1
 
     load_offset = tl.arange(0, bs_upper)
-    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid)
+    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
     seq_len = tl.load(paged_kernel_lens + bid)
     cum_seq_len = tl.sum(seq_lens)
@@ -765,7 +674,7 @@ def generate_draft_decode_kv_indices(
     zid = bid * topk + topk_id
     if zid == 0:
         zid = num_seqs * topk
-    positions = tl.load(positions + bs_offset, mask=bs_offset < zid)
+    positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
     base = tl.sum(positions)
     tl.store(kv_indptr + zid, base + zid * iters)
@@ -783,7 +692,9 @@ def align_evict_mask_to_page_size(
     bid = tl.program_id(axis=0)
     seq_len = tl.load(seq_lens + bid)
     io_mask = t_range < num_draft_tokens
-    mask_row = tl.load(evict_mask + bid * num_draft_tokens + t_range, mask=io_mask)
+    mask_row = tl.load(
+        evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
+    )
     num_trues = tl.sum(mask_row)
     num_false = num_draft_tokens - num_trues
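The remaining kernel tweaks add an explicit `other=0` to masked `tl.load` calls. Lanes that fail the mask otherwise hold unspecified values, which corrupts reductions such as `tl.sum`. A tiny illustrative kernel showing the pattern (not taken from this file; requires a CUDA device):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def masked_sum_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    # Sum only the first `n` elements; out-of-range lanes must read 0,
    # otherwise tl.sum would fold in garbage values.
    offsets = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offsets, mask=offsets < n, other=0)
    tl.store(out_ptr, tl.sum(x))


if torch.cuda.is_available():
    x = torch.arange(1, 6, device="cuda", dtype=torch.int32)  # 1..5
    out = torch.zeros(1, device="cuda", dtype=torch.int32)
    masked_sum_kernel[(1,)](x, out, 3, BLOCK=8)
    print(out.item())  # 6 == 1 + 2 + 3
```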
@@ -23,6 +23,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
 from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
     EAGLEDraftCudaGraphRunner,
 )
@@ -69,7 +70,6 @@ class EAGLEWorker(TpModelWorker):
         self.server_args = server_args
         self.topk = server_args.speculative_eagle_topk
         self.speculative_num_steps = server_args.speculative_num_steps
-        self.padded_static_len = self.speculative_num_steps + 1
         self.enable_nan_detection = server_args.enable_nan_detection
         self.gpu_id = gpu_id
         self.device = server_args.device
@@ -78,6 +78,7 @@ class EAGLEWorker(TpModelWorker):
         self.speculative_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
+        self.padded_static_len = -1
 
         # Override context length with target model's context length
         server_args.context_length = target_worker.model_runner.model_config.context_len
@@ -184,7 +185,6 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = True
         elif self.server_args.attention_backend == "triton":
             from sglang.srt.layers.attention.triton_backend import (
@@ -201,7 +201,6 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
@@ -218,7 +217,6 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import (
@@ -231,7 +229,6 @@ class EAGLEWorker(TpModelWorker):
                 self.speculative_num_steps,
             )
             self.draft_extend_attn_backend = None
-            self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
@@ -319,10 +316,12 @@ class EAGLEWorker(TpModelWorker):
             return logits_output, next_token_ids, model_worker_batch.bid, 0, False
         else:
-            logits_output, next_token_ids, bid = self.forward_target_extend(batch)
+            logits_output, next_token_ids, bid, seq_lens_cpu = (
+                self.forward_target_extend(batch)
+            )
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 self.forward_draft_extend(
-                    batch, logits_output.hidden_states, next_token_ids
+                    batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
                 )
             return logits_output, next_token_ids, bid, 0, False
@@ -346,7 +345,12 @@ class EAGLEWorker(TpModelWorker):
         logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
-        return logits_output, next_token_ids, model_worker_batch.bid
+        return (
+            logits_output,
+            next_token_ids,
+            model_worker_batch.bid,
+            model_worker_batch.seq_lens_cpu,
+        )
 
     def draft(self, batch: ScheduleBatch):
         # Parse args
@@ -452,7 +456,14 @@ class EAGLEWorker(TpModelWorker):
         self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
 
-        ret = EagleVerifyInput.create(
+        (
+            tree_mask,
+            position,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            draft_tokens,
+        ) = build_tree_kernel_efficient(
             spec_info.verified_id,
             score_list,
             token_list,
@@ -463,7 +474,22 @@ class EAGLEWorker(TpModelWorker):
             self.speculative_num_steps,
             self.server_args.speculative_num_draft_tokens,
         )
-        return ret
+
+        return EagleVerifyInput(
+            draft_token=draft_tokens,
+            custom_mask=tree_mask,
+            positions=position,
+            retrive_index=retrive_index,
+            retrive_next_token=retrive_next_token,
+            retrive_next_sibling=retrive_next_sibling,
+            retrive_cum_len=None,
+            spec_steps=self.speculative_num_steps,
+            topk=self.topk,
+            draft_token_num=self.server_args.speculative_num_draft_tokens,
+            capture_hidden_mode=CaptureHiddenMode.FULL,
+            seq_lens_sum=forward_batch.seq_lens_sum,
+            seq_lens_cpu=forward_batch.seq_lens_cpu,
+        )
 
     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
@@ -523,7 +549,9 @@ class EAGLEWorker(TpModelWorker):
         spec_info.prepare_for_verify(batch, self.page_size)
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
-        model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch = batch.get_model_worker_batch(
+            seq_lens_cpu_cache=spec_info.seq_lens_cpu
+        )
 
         if batch.has_grammar:
             retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
@@ -650,6 +678,7 @@ class EAGLEWorker(TpModelWorker):
         batch: ScheduleBatch,
         hidden_states: torch.Tensor,
         next_token_ids: List[int],
+        seq_lens_cpu: torch.Tensor,
     ):
         """Run draft model extend. This API modifies the states of the batch.
@@ -664,7 +693,9 @@ class EAGLEWorker(TpModelWorker):
         )
         batch.spec_info.prepare_for_extend(batch)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
-        model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch = batch.get_model_worker_batch(
+            seq_lens_cpu_cache=seq_lens_cpu
+        )
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -683,19 +714,18 @@ class EAGLEWorker(TpModelWorker):
         return_logprob_backup = batch.return_logprob
 
         # Prepare metadata
-        batch.forward_mode = ForwardMode.DRAFT_EXTEND
         batch.spec_info.prepare_extend_after_decode(
             batch,
             self.speculative_num_steps,
-            self.server_args.context_length,
-            pad_input=self.cuda_graph_runner_for_draft_extend is not None,
         )
-        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
-        batch.return_logprob = False
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
+        if forward_batch.seq_lens_cpu is not None:
+            forward_batch.seq_lens_sum = forward_batch.seq_lens_cpu.sum().item()
+        else:
+            forward_batch.seq_lens_sum = batch.seq_lens.sum().item()
 
         # Run
         can_cuda_graph = (
@@ -706,14 +736,19 @@ class EAGLEWorker(TpModelWorker):
             logits_output = self.cuda_graph_runner_for_draft_extend.replay(
                 forward_batch
             )
+            forward_batch.spec_info.topk_p, forward_batch.spec_info.topk_index = (
+                logits_output.topk_p,
+                logits_output.topk_index,
+            )
+            forward_batch.spec_info.hidden_states = logits_output.hidden_states
         else:
             self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
             logits_output = self.draft_model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
+            self.capture_for_decode(logits_output, forward_batch.spec_info)
 
         self._detect_nan_if_needed(logits_output)
-        self.capture_for_decode(logits_output, forward_batch.spec_info)
 
         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
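Since `prepare_extend_after_decode` no longer sets `batch.seq_lens_sum`, the worker now derives it on the host whenever a CPU copy of the lengths is available, falling back to a GPU reduction otherwise. A condensed sketch of that choice (hypothetical helper, not the actual call chain):

```python
import torch


def compute_seq_lens_sum(seq_lens_gpu: torch.Tensor, seq_lens_cpu=None) -> int:
    # Prefer the cached host copy: .item() on a CUDA tensor forces a device
    # synchronization, while reducing the CPU tensor does not touch the GPU.
    if seq_lens_cpu is not None:
        return int(seq_lens_cpu.sum().item())
    return int(seq_lens_gpu.sum().item())


device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens = torch.tensor([4, 7, 2], device=device)
print(compute_seq_lens_sum(seq_lens, seq_lens.cpu()))  # 13
```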
@@ -87,7 +87,7 @@ class TestDeepseekV3MTP(CustomTestCase):
             "--speculative-num-steps",
             "3",
             "--speculative-eagle-topk",
-            "2",
+            "1",
             "--speculative-num-draft-tokens",
             "4",
         ]
@@ -155,7 +155,7 @@ class TestDeepseekV3MTP(CustomTestCase):
         if is_in_amd_ci():
             self.assertGreater(speed, 15)
         else:
-            self.assertGreater(speed, 105)
+            self.assertGreater(speed, 130)
 
 if __name__ == "__main__":