Unverified commit fff10809 authored by Lianmin Zheng, committed by GitHub

Revert "[EAGLE] Refactor code for page size > 1 & more simplifications" (#7210)

parent 5f1ab327
@@ -1049,13 +1049,14 @@ class FlashInferMultiStepDraftBackend:
                 kv_indices_buffer,
                 self.kv_indptr,
                 forward_batch.positions,
+                num_seqs,
+                self.topk,
                 self.pool_len,
                 kv_indices_buffer.shape[1],
                 self.kv_indptr.shape[1],
                 next_power_of_2(num_seqs),
                 next_power_of_2(self.speculative_num_steps),
                 next_power_of_2(bs),
-                self.page_size,
             )
         assert forward_batch.spec_info is not None
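Note: the trailing `next_power_of_2(...)` arguments exist because Triton block shapes passed as `tl.constexpr` must be compile-time powers of two, with masked loads covering the padding. A minimal sketch of what such a helper does (illustrative only; sglang ships its own version in `sglang.srt.utils`):

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n, so masked Triton loops can use a fixed,
    # compile-time block shape even when bs or num_seqs is not a power of two.
    # Sketch only; not the library's implementation.
    return 1 if n == 0 else 1 << (n - 1).bit_length()

assert [next_power_of_2(x) for x in (1, 3, 8, 9)] == [1, 4, 8, 16]
```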
@@ -789,7 +789,6 @@ class FlashInferMLAMultiStepDraftBackend:
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
-        self.page_size = model_runner.server_args.page_size

     def common_template(
         self,
@@ -810,13 +809,14 @@ class FlashInferMLAMultiStepDraftBackend:
                 kv_indices_buffer,
                 self.kv_indptr,
                 forward_batch.positions,
+                num_seqs,
+                self.topk,
                 self.pool_len,
                 kv_indices_buffer.shape[1],
                 self.kv_indptr.shape[1],
                 next_power_of_2(num_seqs),
                 next_power_of_2(self.speculative_num_steps),
                 next_power_of_2(bs),
-                self.page_size,
             )
         assert forward_batch.spec_info is not None
@@ -784,13 +784,14 @@ class TritonMultiStepDraftBackend:
                 kv_indices_buffer,
                 self.kv_indptr,
                 forward_batch.positions,
+                num_seqs,
+                self.topk,
                 self.pool_len,
                 kv_indices_buffer.shape[1],
                 self.kv_indptr.shape[1],
                 next_power_of_2(num_seqs),
                 next_power_of_2(self.speculative_num_steps),
                 next_power_of_2(bs),
-                self.page_size,
             )
         for i in range(self.speculative_num_steps):
@@ -294,19 +294,6 @@ class MHATokenToKVPool(KVCache):
             for _ in range(self.layer_num)
         ]

-        self.data_ptrs = torch.tensor(
-            [x.data_ptr() for x in self.k_buffer + self.v_buffer],
-            dtype=torch.uint64,
-            device=self.device,
-        )
-        self.data_strides = torch.tensor(
-            [
-                np.prod(x.shape[1:]) * x.dtype.itemsize
-                for x in self.k_buffer + self.v_buffer
-            ],
-            device=self.device,
-        )

     def _clear_buffers(self):
         del self.k_buffer
         del self.v_buffer
@@ -464,16 +451,6 @@ class MHATokenToKVPool(KVCache):
         self.k_buffer[layer_id - self.start_layer][loc] = cache_k
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v

-    def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
-        copy_all_layer_kv_cache[(len(self.data_ptrs),)](
-            self.data_ptrs,
-            self.data_strides,
-            tgt_loc,
-            src_loc,
-            len(tgt_loc),
-            next_power_of_2(len(tgt_loc)),
-        )

 @triton.jit
 def set_mla_kv_buffer_kernel(
@@ -764,41 +741,3 @@ class DoubleSparseTokenToKVPool(KVCache):
     def transfer_per_layer(self, indices, flat_data, layer_id):
         pass
-
-
-@triton.jit
-def copy_all_layer_kv_cache(
-    data_ptrs,
-    strides,
-    tgt_loc_ptr,
-    src_loc_ptr,
-    num_locs,
-    num_locs_upper: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 128
-
-    bid = tl.program_id(0)
-    stride = tl.load(strides + bid)
-    data_ptr = tl.load(data_ptrs + bid)
-    data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
-
-    num_locs_offset = tl.arange(0, num_locs_upper)
-    tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
-    src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
-
-    # NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
-    # because this copy is an inplace operation.
-
-    num_loop = tl.cdiv(stride, BLOCK_SIZE)
-    for i in range(num_loop):
-        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = (num_locs_offset < num_locs)[:, None] and (copy_offset < stride)[None, :]
-        value = tl.load(
-            data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
-        )
-        tl.store(
-            data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
-            value,
-            mask=mask,
-        )
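Note: the net effect of the reverted `move_kv_cache` / `copy_all_layer_kv_cache` pair can be sketched in plain PyTorch (a hypothetical reference helper, not code from either revision):

```python
import torch

def move_kv_cache_reference(k_buffer, v_buffer, tgt_loc, src_loc):
    # Move KV entries from src_loc to tgt_loc for every layer's K and V
    # buffer. Advanced indexing gathers the source rows into a temporary
    # before writing, so the in-place aliasing the Triton kernel guards
    # against (one CUDA block per buffer) is not a concern here.
    for buf in list(k_buffer) + list(v_buffer):
        buf[tgt_loc] = buf[src_loc]
```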
@@ -67,6 +67,8 @@ class EagleDraftInput:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None

+    all_padding_lens: Optional[torch.Tensor] = None
+
     def prepare_for_extend(self, batch: ScheduleBatch):
         # Prefill only generate 1 token.
         assert len(self.verified_id) == len(batch.seq_lens)
@@ -91,7 +93,6 @@ class EagleDraftInput:
         batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
         batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
         batch.return_logprob = False
-        batch.return_hidden_states = False
         self.capture_hidden_mode = CaptureHiddenMode.LAST
         self.accept_length.add_(1)
@@ -115,8 +116,10 @@ class EagleDraftInput:
         req_to_token: torch.Tensor,
     ):
         bs = self.accept_length.numel()
+
         qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
+
         cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
@@ -136,6 +139,7 @@ class EagleDraftInput:
             kv_indices,
             req_to_token.size(1),
         )
+
         return kv_indices, cum_kv_seq_len, qo_indptr, None

     def filter_batch(self, new_indices: torch.Tensor):
@@ -266,7 +270,7 @@ class EagleVerifyInput:
         logits_output: torch.Tensor,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         page_size: int,
-        vocab_mask: Optional[torch.Tensor] = None,  # For grammar
+        vocab_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         Verify and find accepted tokens based on logits output and batch
@@ -290,14 +294,6 @@ class EagleVerifyInput:
         )
         accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")

-        # Apply the custom logit processors if registered in the sampling info.
-        if sampling_info.has_custom_logit_processor:
-            apply_custom_logit_processor(
-                logits_output.next_token_logits,
-                sampling_info,
-                num_tokens_in_batch=self.draft_token_num,
-            )
-
         # Apply penalty
         if sampling_info.penalizer_orchestrator.is_required:
             # This is a relaxed version of penalties for speculative decoding.
@@ -359,13 +355,7 @@ class EagleVerifyInput:
             draft_probs = torch.zeros(
                 target_probs.shape, dtype=torch.float32, device="cuda"
            )
-
-            # coins for rejection sampling
             coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
-            # coins for final sampling
-            coins_for_final_sampling = torch.rand(
-                (bs,), dtype=torch.float32, device="cuda"
-            )
             tree_speculative_sampling_target_only(
                 predicts=predict,  # mutable
                 accept_index=accept_index,  # mutable
@@ -375,7 +365,6 @@ class EagleVerifyInput:
                 retrive_next_token=self.retrive_next_token.to(torch.int32),
                 retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
                 uniform_samples=coins,
-                # uniform_samples_for_final_sampling=coins_for_final_sampling,
                 target_probs=target_probs,
                 draft_probs=draft_probs,
                 threshold_single=global_server_args_dict[
@@ -398,8 +387,8 @@ class EagleVerifyInput:
                 spec_steps=self.spec_steps,
             )

+        new_accept_index = []
         unfinished_index = []
-        unfinished_accept_index = []
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
         has_finished = False
@@ -407,10 +396,12 @@ class EagleVerifyInput:
         # Iterate every accepted token and check if req has finished after append the token
         # should be checked BEFORE free kv cache slots
         for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
+            new_accept_index_ = []
            for j, idx in enumerate(accept_index_row):
                 if idx == -1:
                     break
                 id = predict_cpu[idx]
+                # if not found_finished:
                 req.output_ids.append(id)
                 req.check_finished()
                 if req.finished():
@@ -419,6 +410,8 @@ class EagleVerifyInput:
                     accept_index[i, j + 1 :] = -1
                     break
                 else:
+                    new_accept_index_.append(idx)
+                    # update grammar state
                     if req.grammar is not None:
                         try:
                             req.grammar.accept_token(id)
@@ -428,29 +421,20 @@ class EagleVerifyInput:
                             )
                             raise e
             if not req.finished():
+                new_accept_index.extend(new_accept_index_)
                 unfinished_index.append(i)
-                if idx == -1:
-                    unfinished_accept_index.append(accept_index[i, :j])
-                else:
-                    unfinished_accept_index.append(accept_index[i])
             req.spec_verify_ct += 1

         if has_finished:
             accept_length = (accept_index != -1).sum(dim=1) - 1

         # Free the KV cache for unaccepted tokens
-        # TODO: fuse them
         accept_index = accept_index[accept_index != -1]
         verified_id = predict[accept_index]
         evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
         evict_mask[accept_index] = False

-        if page_size == 1:
-            # TODO: boolean array index leads to a device sync. Remove it.
-            token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
-        else:
-            if self.topk == 1:
-                # Only evict full empty page. Do not evict partial empty page
+        if page_size != 1:
             align_evict_mask_to_page_size[len(batch.seq_lens),](
                 batch.seq_lens,
                 evict_mask,
@@ -458,55 +442,11 @@ class EagleVerifyInput:
                 self.draft_token_num,
                 next_power_of_2(self.draft_token_num),
             )
-                token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
-            else:
-                # Shift the accepted tokens to the beginning.
-                # Only evict the last part
-                src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
-                    batch.seq_lens,
-                    batch.out_cache_loc,
-                    accept_index,
-                    accept_length,
-                    self.draft_token_num,
-                    page_size,
-                )
-                to_free_slots = torch.empty(
-                    (to_free_num_slots.sum().item(),),
-                    dtype=torch.int64,
-                    device=to_free_num_slots.device,
-                )
-                # out_cache_loc: [0  1  2,  3  4  5,  6  7  8]
-                # accept_index:  [0 -1  2,  3  4 -1,  6 -1 -1]
-                # tgt_cache_loc: [0  1   ,  3  4   ,  6      ]
-                # to_free_slots: [      2,        5,     7  8]
-                # to_free_slots also needs to be page-aligned without the first partial page
-                #
-                # split each row of out_cache_loc into two parts.
-                # 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
-                # 2. the second part goes to to_free_slots.
-                get_target_cache_loc[(bs,)](
-                    tgt_cache_loc,
-                    to_free_slots,
-                    accept_length,
-                    to_free_num_slots,
-                    batch.out_cache_loc,
-                    self.draft_token_num,
-                    next_power_of_2(self.draft_token_num),
-                    next_power_of_2(bs),
-                )
-
-                # Free the kv cache
-                token_to_kv_pool_allocator.free(to_free_slots)
-
-                # Copy the kv cache
-                batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
-                    tgt_cache_loc, src_cache_loc
-                )
+
+        token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])

         # Construct EagleVerifyOutput
         if not has_finished:
-            if page_size == 1 or self.topk == 1:
             batch.out_cache_loc = batch.out_cache_loc[accept_index]
             assign_req_to_token_pool[(bs,)](
                 batch.req_pool_indices,
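Note: the worked example in the removed comments can be checked with a small CPU sketch of the reverted split (hypothetical helper; page alignment ignored):

```python
import torch

def split_out_cache_loc(out_cache_loc, accept_length, draft_token_num):
    # Row i of out_cache_loc contributes its first accept_length[i] + 1 slots
    # to tgt_cache_loc and frees the rest, mirroring get_target_cache_loc.
    rows = out_cache_loc.view(-1, draft_token_num)
    tgt, free = [], []
    for row, keep in zip(rows, (accept_length + 1).tolist()):
        tgt.append(row[:keep])
        free.append(row[keep:])
    return torch.cat(tgt), torch.cat(free)

out_cache_loc = torch.arange(9)          # [0 1 2, 3 4 5, 6 7 8]
accept_length = torch.tensor([1, 1, 0])  # accepted draft tokens per request
tgt, free = split_out_cache_loc(out_cache_loc, accept_length, 3)
# tgt  -> tensor([0, 1, 3, 4, 6])   free -> tensor([2, 5, 7, 8])
```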
@@ -517,15 +457,14 @@ class EagleVerifyInput:
                 batch.req_to_token_pool.req_to_token.shape[1],
                 next_power_of_2(bs),
             )
-            else:
-                batch.out_cache_loc = tgt_cache_loc
             batch.seq_lens.add_(accept_length + 1)
+            accept_length_cpu = accept_length.tolist()

             draft_input = EagleDraftInput()
             draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
             draft_input.verified_id = verified_id
             draft_input.accept_length = accept_length
-            draft_input.accept_length_cpu = accept_length.tolist()
+            draft_input.accept_length_cpu = accept_length_cpu
             draft_input.seq_lens_for_draft_extend = batch.seq_lens
             draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
@@ -533,11 +472,10 @@ class EagleVerifyInput:
                 draft_input=draft_input,
                 logits_output=logits_output,
                 verified_id=verified_id,
-                accept_length_per_req_cpu=draft_input.accept_length_cpu,
+                accept_length_per_req_cpu=accept_length_cpu,
                 accepted_indices=accept_index,
             )
         else:
-            if page_size == 1 or self.topk == 1:
             assign_req_to_token_pool[(bs,)](
                 batch.req_pool_indices,
                 batch.req_to_token_pool.req_to_token,
@@ -548,51 +486,33 @@ class EagleVerifyInput:
                 next_power_of_2(bs),
             )
             batch.seq_lens.add_(accept_length + 1)
             accept_length_cpu = accept_length.tolist()
-            draft_input = EagleDraftInput()
-            if len(unfinished_accept_index) > 0:
-                unfinished_accept_index = torch.cat(unfinished_accept_index)
-                unfinished_index_device = torch.tensor(
-                    unfinished_index, dtype=torch.int64, device=predict.device
-                )
-                draft_input_accept_length_cpu = [
-                    accept_length_cpu[i] for i in unfinished_index
-                ]
-                if page_size == 1 or self.topk == 1:
-                    batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
-                else:
-                    batch.out_cache_loc = torch.empty(
-                        len(unfinished_index) + sum(draft_input_accept_length_cpu),
-                        dtype=torch.int64,
-                        device=predict.device,
-                    )
-                    accept_length_filter = create_accept_length_filter(
-                        accept_length,
-                        unfinished_index_device,
-                        batch.seq_lens,
-                    )
-                    filter_finished_cache_loc_kernel[(bs,)](
-                        batch.out_cache_loc,
-                        tgt_cache_loc,
-                        accept_length,
-                        accept_length_filter,
-                        next_power_of_2(bs),
-                        next_power_of_2(self.draft_token_num),
-                    )
+
+            draft_input = EagleDraftInput()
+            if len(new_accept_index) > 0:
+                new_accept_index = torch.tensor(new_accept_index, device="cuda")
+                unfinished_index_device = torch.tensor(unfinished_index, device="cuda")
                 draft_input.hidden_states = batch.spec_info.hidden_states[
-                    unfinished_accept_index
+                    new_accept_index
                 ]
-                draft_input.verified_id = predict[unfinished_accept_index]
-                draft_input.accept_length_cpu = draft_input_accept_length_cpu
+                draft_input.verified_id = predict[new_accept_index]
+                draft_input.accept_length_cpu = [
+                    accept_length_cpu[i] for i in unfinished_index
+                ]
                 draft_input.accept_length = accept_length[unfinished_index_device]
-                draft_input.seq_lens_for_draft_extend = batch.seq_lens[
-                    unfinished_index_device
-                ]
-                draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
-                    unfinished_index_device
-                ]
+                if has_finished:
+                    draft_input.seq_lens_for_draft_extend = batch.seq_lens[
+                        unfinished_index_device
+                    ]
+                    draft_input.req_pool_indices_for_draft_extend = (
+                        batch.req_pool_indices[unfinished_index_device]
+                    )
+                else:
+                    draft_input.seq_lens_for_draft_extend = batch.seq_lens
+                    draft_input.req_pool_indices_for_draft_extend = (
+                        batch.req_pool_indices
+                    )
+                batch.out_cache_loc = batch.out_cache_loc[new_accept_index]

         return EagleVerifyOutput(
             draft_input=draft_input,
@@ -669,75 +589,36 @@ def assign_draft_cache_locs(
     req_pool_indices,
     req_to_token,
     seq_lens,
-    extend_lens,
-    num_new_pages_per_topk,
     out_cache_loc,
     pool_len: tl.constexpr,
     topk: tl.constexpr,
     speculative_num_steps: tl.constexpr,
     page_size: tl.constexpr,
-    bs_upper: tl.constexpr,
-    iter_upper: tl.constexpr,
 ):
-    BLOCK_SIZE: tl.constexpr = 128
+    BLOCK_SIZE: tl.constexpr = 32
     pid = tl.program_id(axis=0)

+    kv_start = tl.load(seq_lens + pid)
     if page_size == 1 or topk == 1:
-        copy_len = topk * speculative_num_steps
+        kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
         out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
     else:
-        bs_offset = tl.arange(0, bs_upper)
-        copy_len = tl.load(extend_lens + pid)
-        cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
-        out_cache_ptr = out_cache_loc + cum_copy_len
-
-    # Part 1: Copy from out_cache_loc to req_to_token
-    kv_start = tl.load(seq_lens + pid)
-    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
-    num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
-    for i in range(num_loop):
-        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = copy_offset < copy_len
-        data = tl.load(out_cache_ptr + copy_offset, mask=mask)
-        tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
-
-    if page_size == 1 or topk == 1:
-        return
-
-    # Part 2: Copy the indices for the last partial page
-    prefix_len = tl.load(seq_lens + pid)
-    last_page_len = prefix_len % page_size
-    offsets = tl.arange(0, page_size)
-    mask = offsets < last_page_len
-    num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
-    prefix_base = token_pool + prefix_len - last_page_len
-
-    for topk_id in range(topk):
-        value = tl.load(prefix_base + offsets, mask=mask)
-        tl.store(
-            prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
-            value,
-            mask=mask,
-        )
-
-    # Part 3: Remove the padding in out_cache_loc
-    iter_offest = tl.arange(0, iter_upper)
-    for topk_id in range(topk):
-        indices = tl.load(
-            prefix_base
-            + topk_id * num_new_pages_per_topk_ * page_size
-            + last_page_len
-            + iter_offest,
-            mask=iter_offest < speculative_num_steps,
-        )
-        tl.store(
-            out_cache_loc
-            + pid * topk * speculative_num_steps
-            + topk_id * speculative_num_steps
-            + iter_offest,
-            indices,
-            mask=iter_offest < speculative_num_steps,
-        )
+        prefix_len = tl.load(seq_lens + pid)
+        last_page_len = prefix_len % page_size
+        num_new_page = (
+            last_page_len + speculative_num_steps + page_size - 1
+        ) // page_size
+        kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
+
+    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len

+    num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
+    for i in range(num_loop):
+        save_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + kv_start
+        load_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = save_offset < kv_end
+        data = tl.load(out_cache_ptr + load_offset, mask=mask)
+        tl.store(token_pool + save_offset, data, mask=mask)

 @triton.jit
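Note: a quick numeric check of the restored `kv_end` computation for the paged, top-k > 1 layout (values chosen purely for illustration):

```python
# prefix_len = 10, page_size = 4, speculative_num_steps = 3, topk = 2
prefix_len, page_size, num_steps, topk = 10, 4, 3, 2
last_page_len = prefix_len % page_size                # 2 tokens on the partial page
num_new_page = (last_page_len + num_steps + page_size - 1) // page_size  # 2 pages
kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
assert (last_page_len, num_new_page, kv_end) == (2, 2, 24)
```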
@@ -748,23 +629,20 @@ def generate_draft_decode_kv_indices(
     kv_indices,
     kv_indptr,
     positions,
+    num_seqs: tl.constexpr,
+    topk: tl.constexpr,
     pool_len: tl.constexpr,
     kv_indices_stride: tl.constexpr,
     kv_indptr_stride: tl.constexpr,
     bs_upper: tl.constexpr,
     iter_upper: tl.constexpr,
     num_tokens_upper: tl.constexpr,
-    page_size: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 128
     iters = tl.program_id(axis=0)
     bid = tl.program_id(axis=1)
     topk_id = tl.program_id(axis=2)
-
-    num_steps = tl.num_programs(axis=0)
-    num_seqs = tl.num_programs(axis=1)
-    topk = tl.num_programs(axis=2)

     kv_indices += kv_indices_stride * iters
     kv_indptr += kv_indptr_stride * iters
     iters += 1
@@ -774,7 +652,6 @@ def generate_draft_decode_kv_indices(
     seq_len = tl.load(paged_kernel_lens + bid)
     cum_seq_len = tl.sum(seq_lens)

-    # Update kv_indices
     kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
     kv_ptr = kv_indices + kv_offset
     token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
@@ -788,26 +665,10 @@ def generate_draft_decode_kv_indices(
         kv_offset += BLOCK_SIZE

     extend_offset = tl.arange(0, iter_upper)
-    if page_size == 1 or topk == 1:
-        extend_data = tl.load(
-            token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
-            mask=extend_offset < iters,
-        )
-    else:
-        prefix_len = seq_len
-        last_page_len = prefix_len % page_size
-        num_new_pages_per_topk = (
-            last_page_len + num_steps + page_size - 1
-        ) // page_size
-        prefix_base = seq_len // page_size * page_size
-        start = (
-            prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
-        )
-        extend_data = tl.load(
-            token_pool_ptr + start + extend_offset,
-            mask=extend_offset < iters,
-        )
+    extend_data = tl.load(
+        token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id,
+        mask=extend_offset < iters,
+    )
     tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)

     # Update kv_indptr
@@ -846,116 +707,6 @@ def align_evict_mask_to_page_size(
         tl.store(evict_mask + bid * num_draft_tokens + i, False)

-
-@triton.jit
-def get_target_cache_loc(
-    tgt_cache_loc,
-    to_free_slots,
-    accept_length,
-    to_free_num_slots,
-    out_cache_loc,
-    num_verify_tokens: tl.constexpr,
-    num_verify_tokens_upper: tl.constexpr,
-    bs_upper: tl.constexpr,
-):
-    bid = tl.program_id(axis=0)
-    offset = tl.arange(0, num_verify_tokens_upper)
-    bs_offset = tl.arange(0, bs_upper)
-
-    # write the first part to tgt_cache_loc
-    accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
-    tgt_cache_loc_start = tl.sum(accept_len_all) + bid
-    copy_len = tl.load(accept_length + bid) + 1
-    out_cache_loc_row = tl.load(
-        out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
-    )
-    tl.store(
-        tgt_cache_loc + tgt_cache_loc_start + offset,
-        out_cache_loc_row,
-        mask=offset < copy_len,
-    )
-
-    # write the second part to to_free_num_pages
-    to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
-    to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
-    out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
-    to_free_slots_start = tl.sum(to_free_num_slots_all)
-
-    copy_len = to_free_num_slots_cur
-    out_cache_loc_row = tl.load(
-        out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
-        mask=offset < copy_len,
-    )
-    tl.store(
-        to_free_slots + to_free_slots_start + offset,
-        out_cache_loc_row,
-        mask=offset < copy_len,
-    )
-
-
-@torch.compile(dynamic=True)
-def get_src_tgt_cache_loc(
-    seq_lens: torch.Tensor,
-    out_cache_loc: torch.Tensor,
-    accept_index: torch.Tensor,
-    accept_length: torch.Tensor,
-    draft_token_num: int,
-    page_size: int,
-):
-    src_cache_loc = out_cache_loc[accept_index]
-    tgt_cache_loc = torch.empty_like(src_cache_loc)
-    extended_len = seq_lens + draft_token_num
-    keep_len = torch.minimum(
-        (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
-        extended_len,
-    )
-    to_free_num_slots = extended_len - keep_len
-    return src_cache_loc, tgt_cache_loc, to_free_num_slots
-
-
-@triton.jit
-def filter_finished_cache_loc_kernel(
-    out_cache_loc,
-    tgt_cache_loc,
-    accept_length,
-    accept_length_filter,
-    bs_upper: tl.constexpr,
-    num_verify_tokens_upper: tl.constexpr,
-):
-    bid = tl.program_id(0)
-    bs_offset = tl.arange(0, bs_upper)
-
-    accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
-    old_start = tl.sum(accept_length_all) + bid
-
-    accept_length_filter_all = tl.load(
-        accept_length_filter + bs_offset, mask=bs_offset < bid
-    )
-    new_start = tl.sum(accept_length_filter_all)
-
-    copy_len = tl.load(accept_length_filter + bid)
-    copy_offset = tl.arange(0, num_verify_tokens_upper)
-    value = tl.load(
-        tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
-    )
-    tl.store(
-        out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
-    )
-
-
-@torch.compile(dynamic=True)
-def create_accept_length_filter(
-    accept_length: torch.Tensor,
-    unfinished_index_device: torch.Tensor,
-    seq_lens: torch.Tensor,
-):
-    accept_length_filter = torch.zeros_like(accept_length)
-    accept_length_filter[unfinished_index_device] = (
-        accept_length[unfinished_index_device] + 1
-    )
-    seq_lens.add_(accept_length + 1)
-    return accept_length_filter
-

 @torch.compile(dynamic=True)
 def select_top_k_tokens(
     i: int,
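Note: the reverted `get_src_tgt_cache_loc` above frees only the draft slots past the page boundary that follows a request's accepted tokens. A numeric check of that bookkeeping (illustrative values):

```python
import torch

seq_lens = torch.tensor([10])
accept_length = torch.tensor([2])
draft_token_num, page_size = 4, 4
extended_len = seq_lens + draft_token_num                    # 14 slots in use
keep_len = torch.minimum(
    (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
    extended_len,
)                                                            # min(16, 14) = 14
to_free_num_slots = extended_len - keep_len                  # 0: partial page kept
assert to_free_num_slots.item() == 0
```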
@@ -1005,16 +756,6 @@ def select_top_k_tokens(
     return input_ids, hidden_states, scores, tree_info

-
-def fast_topk_torch(values, topk, dim):
-    if topk == 1:
-        # Use max along the specified dimension to get both value and index
-        max_value, max_index = torch.max(values, dim=dim)
-        return max_value.unsqueeze(1), max_index.unsqueeze(1)
-    else:
-        # Use topk for efficiency with larger k values
-        return torch.topk(values, topk, dim=dim)
-

 def _generate_simulated_accept_index(
     accept_index,
     predict,
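Note: the removed k == 1 fast path relied on `torch.max` returning values and indices in one pass instead of the general top-k; a quick equivalence check:

```python
import torch

x = torch.tensor([[0.1, 0.9, 0.4]])
v1, i1 = torch.max(x, dim=-1)          # single-pass max with indices
v2, i2 = torch.topk(x, 1, dim=-1)      # general top-k path
assert torch.equal(v1.unsqueeze(1), v2) and torch.equal(i1.unsqueeze(1), i2)
```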
@@ -1024,7 +765,6 @@ def _generate_simulated_accept_index(
     spec_steps,
 ):
     simulate_acc_len_float = float(simulate_acc_len)
-    if SIMULATE_ACC_METHOD == "multinomial":
     simulated_values = torch.normal(
         mean=simulate_acc_len_float,
         std=1.0,
@@ -1032,27 +772,8 @@ def _generate_simulated_accept_index(
         device="cpu",
     )
     # clamp simulated values to be between 1 and self.spec_steps
-        simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
-        simulate_acc_len = int(simulated_values.round().item())
-    elif SIMULATE_ACC_METHOD == "match-expected":
-        # multinomial sampling does not match the expected length
-        # we keep it for the sake of compatibility of existing tests
-        # but it's better to use "match-expected" for the cases that need to
-        # match the expected length, One caveat is that this will only sample
-        # either round down or round up of the expected length
-        simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
-        lower = int(simulate_acc_len_float // 1)
-        upper = lower + 1 if lower < spec_steps + 1 else lower
-        if lower == upper:
-            simulate_acc_len = lower
-        else:
-            weight_upper = simulate_acc_len_float - lower
-            weight_lower = 1.0 - weight_upper
-            probs = torch.tensor([weight_lower, weight_upper], device="cpu")
-            sampled_index = torch.multinomial(probs, num_samples=1)
-            simulate_acc_len = lower if sampled_index == 0 else upper
-    else:
-        raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}")
+    simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps)
+    simulate_acc_len = int(simulated_values.round().item())

     accept_indx_first_col = accept_index[:, 0].view(-1, 1)
     sim_accept_index = torch.full(
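Note: the removed "match-expected" branch draws either the floor or the ceiling of a fractional target length, weighted so the expectation equals the target. A CPU sketch of that idea (illustrative, not the original code):

```python
import torch

target = 2.3
lower = int(target)                                         # 2
probs = torch.tensor([1 + lower - target, target - lower])  # [0.7, 0.3]
idx = torch.multinomial(probs.repeat(10000, 1), num_samples=1)
mean_len = (lower + idx).float().mean().item()              # ~2.3 in expectation
```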
@@ -1143,9 +864,9 @@ def generate_token_bitmask(
     """
     Generate the logit mask for structured output.
     Draft model's token can be either valid or invalid with respect to the grammar.
-    We need to perform DFS to
-    1. figure out which tokens are accepted by the grammar.
-    2. if so, what is the corresponding logit mask.
+    We need to perform DFS to figure out:
+    1. which tokens are accepted by the grammar
+    2. what is the corresponding logit mask.
     """

     num_draft_tokens = draft_tokens_cpu.shape[-1]
@@ -1162,7 +883,6 @@ def generate_token_bitmask(
                 device="cpu",
             )
             grammar = req.grammar
-            s = time.perf_counter()
             traverse_tree(
                 retrieve_next_token_cpu[i],
                 retrieve_next_sibling_cpu[i],
@@ -1172,12 +892,6 @@ def generate_token_bitmask(
                 allocate_token_bitmask[
                     i * num_draft_tokens : (i + 1) * num_draft_tokens
                 ],
             )
-            tree_traverse_time = time.perf_counter() - s
-            if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
-                logger.warning(
-                    f"Bit mask generation took {tree_traverse_time} seconds with "
-                    f"grammar: {req.grammar}"
-                )

         verify_input.grammar = grammar
     return allocate_token_bitmask
@@ -35,17 +35,11 @@ from sglang.srt.speculative.eagle_utils import (
     EagleVerifyInput,
     EagleVerifyOutput,
     assign_draft_cache_locs,
-    fast_topk,
     generate_token_bitmask,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.utils import (
-    empty_context,
-    get_available_gpu_memory,
-    is_cuda,
-    next_power_of_2,
-)
+from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda

 if is_cuda():
     from sgl_kernel import segment_packbits
@@ -158,12 +152,6 @@ class EAGLEWorker(TpModelWorker):
         self.init_attention_backend()
         self.init_cuda_graphs()

-        # Some dummy tensors
-        self.num_new_pages_per_topk = torch.empty(
-            (), dtype=torch.int64, device=self.device
-        )
-        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
-
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
         if self.server_args.attention_backend == "flashinfer":
@@ -266,7 +254,7 @@ class EAGLEWorker(TpModelWorker):
             self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
+            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )

         # Capture extend
@@ -281,7 +269,7 @@ class EAGLEWorker(TpModelWorker):
             )
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
+            f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )

     @property
@@ -302,6 +290,7 @@ class EAGLEWorker(TpModelWorker):
             A tuple of the final logit output of the target model, next tokens accepted,
             the batch id (used for overlap schedule), and number of accepted tokens.
         """
+
         if batch.forward_mode.is_decode():
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
@@ -377,21 +366,14 @@ class EAGLEWorker(TpModelWorker):
         )

         # Allocate cache locations
-        # Layout of the out_cache_loc
-        # [       topk 0        ] [       topk 1        ]
-        # [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
         if self.page_size == 1:
             out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
-                num_seqs * self.speculative_num_steps * self.topk, backup_state=True
+                num_seqs * self.topk * self.speculative_num_steps, backup_state=True
             )
         else:
             if self.topk == 1:
-                prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1(
-                    batch.req_to_token_pool.req_to_token,
-                    batch.req_pool_indices,
-                    batch.seq_lens,
-                    self.speculative_num_steps,
-                )
+                prefix_lens = batch.seq_lens
+                seq_lens = prefix_lens + self.speculative_num_steps
                 extend_num_tokens = num_seqs * self.speculative_num_steps
             else:
                 # In this case, the last partial page needs to be duplicated.
@@ -404,33 +386,29 @@ class EAGLEWorker(TpModelWorker):
                 # "x" means speculative draft tokens
                 # "." means padded tokens

-                # TODO(lmzheng): The current implementation is still a fake support
-                # for page size > 1. In the `assign_draft_cache_locs` below,
-                # we directly move the indices instead of the real kv cache.
-                # This only works when the kernel backend runs with page size = 1.
-                # If the kernel backend runs with page size > 1, we need to
-                # duplicate the real KV cache. The overhead of duplicating KV
-                # cache seems okay because the draft KV cache only has one layer.
-                # see a related copy operation in MHATokenToKVPool::move_kv_cache.
+                # TODO: fuse these ops
+                prefix_lens = batch.seq_lens
+                last_page_lens = prefix_lens % self.page_size
+                num_new_pages = (
+                    last_page_lens + self.speculative_num_steps + self.page_size - 1
+                ) // self.page_size
+                seq_lens = (
+                    prefix_lens // self.page_size * self.page_size
+                    + num_new_pages * (self.page_size * self.topk)
+                )
+                extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
+                raise NotImplementedError(
+                    "page_size > 1 and top_k > 1 are not supported."
+                )
+                # TODO: Support page_size > 1 and top_k > 1
+                # 1. Duplicate the KV cache in the last partial page for all top-k segments
+                # 2. Modify generate_draft_decode_kv_indices accordingly

-            (
-                prefix_lens,
-                seq_lens,
-                last_loc,
-                self.num_new_pages_per_topk,
-                self.extend_lens,
-            ) = get_last_loc_large_page_size_large_top_k(
+            last_loc = get_last_loc(
                 batch.req_to_token_pool.req_to_token,
                 batch.req_pool_indices,
-                batch.seq_lens,
-                self.speculative_num_steps,
-                self.topk,
-                self.page_size,
+                prefix_lens,
             )

-            # TODO(lmzheng): remove this device sync
-            extend_num_tokens = torch.sum(self.extend_lens).item()
-
             out_cache_loc, token_to_kv_pool_state_backup = (
                 batch.alloc_paged_token_slots_extend(
                     prefix_lens,
@@ -445,31 +423,19 @@ class EAGLEWorker(TpModelWorker):
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
             batch.seq_lens,
-            self.extend_lens,
-            self.num_new_pages_per_topk,
             out_cache_loc,
             batch.req_to_token_pool.req_to_token.shape[1],
             self.topk,
             self.speculative_num_steps,
             self.page_size,
-            next_power_of_2(num_seqs),
-            next_power_of_2(self.speculative_num_steps),
         )

-        if self.page_size > 1 and self.topk > 1:
-            # Remove padded slots
-            out_cache_loc = out_cache_loc[
-                : num_seqs * self.topk * self.speculative_num_steps
-            ]
-
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
-        batch.return_hidden_states = False
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        batch.return_hidden_states = False

-        # Get forward batch
         model_worker_batch = batch.get_model_worker_batch()
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -538,13 +504,6 @@ class EAGLEWorker(TpModelWorker):
         if self.hot_token_id is not None:
             topk_index = self.hot_token_id[topk_index]

-        out_cache_loc = out_cache_loc.reshape(
-            forward_batch.batch_size, self.topk, self.speculative_num_steps
-        )
-        out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
-            self.speculative_num_steps, -1
-        )
-
         # Return values
         score_list: List[torch.Tensor] = []
         token_list: List[torch.Tensor] = []
@@ -566,7 +525,10 @@ class EAGLEWorker(TpModelWorker):
             # Set inputs
             forward_batch.input_ids = input_ids
-            forward_batch.out_cache_loc = out_cache_loc[i:]
+            out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
+            forward_batch.out_cache_loc = out_cache_loc[
+                :, self.topk * i : self.topk * (i + 1)
+            ].flatten()
             forward_batch.positions.add_(1)
             forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
             spec_info.hidden_states = hidden_states
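Note: the restored per-step slicing assumes the interleaved slot layout written by `assign_draft_cache_locs` above, where the slot for (step i, branch k) sits at column `i * topk + k` of each request's row. A small sketch with toy sizes:

```python
import torch

bs, topk, steps = 2, 3, 4
out_cache_loc = torch.arange(bs * topk * steps)
view = out_cache_loc.view(bs, -1)        # one row of steps * topk slots per request
i = 1                                    # draft step being executed
step_slots = view[:, topk * i : topk * (i + 1)].flatten()
# step_slots -> tensor([ 3,  4,  5, 15, 16, 17]): step 1's topk slots per request
```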
@@ -624,7 +586,7 @@ class EAGLEWorker(TpModelWorker):
         if vocab_mask is not None:
             assert spec_info.grammar is not None
             vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
-            # NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage
+            # otherwise, this vocab mask will be the one from the previous extend stage
             # and will be applied to produce wrong results
             batch.sampling_info.vocab_mask = None
@@ -645,13 +607,13 @@ class EAGLEWorker(TpModelWorker):
         ]
         logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]

-        if batch.return_logprob:
-            self.add_logprob_values(batch, res, logits_output)
-
         # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
         batch.spec_info = res.draft_input

+        if batch.return_logprob:
+            self.add_logprob_values(batch, res, logits_output)
+
         return logits_output, res, model_worker_batch, can_run_cuda_graph

     def add_logprob_values(
@@ -664,16 +626,8 @@ class EAGLEWorker(TpModelWorker):
         logits_output = res.logits_output
         top_logprobs_nums = batch.top_logprobs_nums
         token_ids_logprobs = batch.token_ids_logprobs
-        accepted_indices = res.accepted_indices
-        assert len(accepted_indices) == len(logits_output.next_token_logits)
-        temperatures = batch.sampling_info.temperatures
-        num_draft_tokens = batch.spec_info.draft_token_num
-        # acceptance indices are the indices in a "flattened" batch.
-        # dividing it to num_draft_tokens will yield the actual batch index.
-        temperatures = temperatures[accepted_indices // num_draft_tokens]
         logprobs = torch.nn.functional.log_softmax(
-            logits_output.next_token_logits / temperatures, dim=-1
+            logits_output.next_token_logits, dim=-1
         )
         batch_next_token_ids = res.verified_id
         num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
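Note: the removed temperature handling maps each accepted token, an index into the flattened `bs * num_draft_tokens` batch, back to its request before scaling the logits. A sketch with toy shapes (hypothetical values):

```python
import torch

bs, num_draft_tokens, vocab = 2, 3, 5
logits = torch.randn(4, vocab)                  # 4 accepted tokens in total
accepted_indices = torch.tensor([0, 2, 3, 5])   # positions in the flat batch
temperatures = torch.tensor([[0.7], [1.3]])     # one temperature per request
per_token_temp = temperatures[accepted_indices // num_draft_tokens]  # (4, 1)
logprobs = torch.log_softmax(logits / per_token_temp, dim=-1)
```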
@@ -708,7 +662,7 @@ class EAGLEWorker(TpModelWorker):
         pt = 0
         next_token_logprobs = logits_output.next_token_logprobs.tolist()
         verified_ids = batch_next_token_ids.tolist()
-        for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True):
+        for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
             for _ in range(num_tokens):
                 if req.return_logprob:
                     req.output_token_logprobs_val.append(next_token_logprobs[pt])
@@ -736,6 +690,7 @@ class EAGLEWorker(TpModelWorker):
             hidden_states: Hidden states from the target model forward
             next_token_ids: Next token ids generated from the target forward.
         """
+        # Sometimes we get hidden states produced by CaptureHiddenMode.FULL, so we have to select just the last
         batch.spec_info = EagleDraftInput(
             hidden_states=hidden_states,
             verified_id=next_token_ids,
@@ -746,6 +701,7 @@ class EAGLEWorker(TpModelWorker):
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=seq_lens_cpu
         )
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -768,7 +724,9 @@ class EAGLEWorker(TpModelWorker):
             batch,
             self.speculative_num_steps,
         )
+        batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -832,47 +790,3 @@ def load_token_map(token_map_path: str) -> List[int]:
         token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
     hot_token_id = torch.load(token_map_path, weights_only=True)
     return torch.tensor(hot_token_id, dtype=torch.int32)
-
-
-@torch.compile(dynamic=True)
-def get_last_loc_large_page_size_top_k_1(
-    req_to_token: torch.Tensor,
-    req_pool_indices: torch.Tensor,
-    seq_lens,
-    speculative_num_steps: int,
-):
-    prefix_lens = seq_lens
-    seq_lens = prefix_lens + speculative_num_steps
-    last_loc = get_last_loc(
-        req_to_token,
-        req_pool_indices,
-        prefix_lens,
-    )
-    return prefix_lens, seq_lens, last_loc
-
-
-@torch.compile(dynamic=True)
-def get_last_loc_large_page_size_large_top_k(
-    req_to_token: torch.Tensor,
-    req_pool_indices: torch.Tensor,
-    seq_lens: torch.Tensor,
-    speculative_num_steps: int,
-    topk: int,
-    page_size: int,
-):
-    prefix_lens = seq_lens
-    last_page_lens = prefix_lens % page_size
-    num_new_pages_per_topk = (
-        last_page_lens + speculative_num_steps + page_size - 1
-    ) // page_size
-    seq_lens = prefix_lens // page_size * page_size + num_new_pages_per_topk * (
-        page_size * topk
-    )
-    extend_lens = seq_lens - prefix_lens
-    last_loc = get_last_loc(
-        req_to_token,
-        req_pool_indices,
-        prefix_lens,
-    )
-    return prefix_lens, seq_lens, last_loc, num_new_pages_per_topk, extend_lens
@@ -441,71 +441,5 @@ class TestEAGLEServerTriton(TestEAGLEServer):
         )

-
-class TestEAGLEServerPageSize(TestEAGLEServer):
-    @classmethod
-    def setUpClass(cls):
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.process = popen_launch_server(
-            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
-            cls.base_url,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=[
-                "--speculative-algorithm",
-                "EAGLE",
-                "--speculative-draft-model-path",
-                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
-                "--speculative-num-steps",
-                5,
-                "--speculative-eagle-topk",
-                1,
-                "--speculative-num-draft-tokens",
-                6,
-                "--mem-fraction-static",
-                0.7,
-                "--chunked-prefill-size",
-                128,
-                "--max-running-requests",
-                8,
-                "--page-size",
-                4,
-                "--attention-backend",
-                "flashinfer",
-            ],
-        )
-
-
-class TestEAGLEServerPageSizeTopk(TestEAGLEServer):
-    @classmethod
-    def setUpClass(cls):
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.process = popen_launch_server(
-            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
-            cls.base_url,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=[
-                "--speculative-algorithm",
-                "EAGLE",
-                "--speculative-draft-model-path",
-                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
-                "--speculative-num-steps",
-                5,
-                "--speculative-eagle-topk",
-                8,
-                "--speculative-num-draft-tokens",
-                64,
-                "--mem-fraction-static",
-                0.7,
-                "--chunked-prefill-size",
-                128,
-                "--max-running-requests",
-                8,
-                "--page-size",
-                4,
-                "--attention-backend",
-                "flashinfer",
-            ],
-        )
-

 if __name__ == "__main__":
     unittest.main()