Unverified Commit fff10809, authored by Lianmin Zheng, committed by GitHub

Revert "[EAGLE] Refactor code for page size > 1 & more simplifications" (#7210)

parent 5f1ab327
@@ -1049,13 +1049,14 @@ class FlashInferMultiStepDraftBackend:
kv_indices_buffer,
self.kv_indptr,
forward_batch.positions,
num_seqs,
self.topk,
self.pool_len,
kv_indices_buffer.shape[1],
self.kv_indptr.shape[1],
next_power_of_2(num_seqs),
next_power_of_2(self.speculative_num_steps),
next_power_of_2(bs),
self.page_size,
)
assert forward_batch.spec_info is not None
@@ -789,7 +789,6 @@ class FlashInferMLAMultiStepDraftBackend:
# Cached variables for generate_draft_decode_kv_indices
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
self.page_size = model_runner.server_args.page_size
def common_template(
self,
@@ -810,13 +809,14 @@ class FlashInferMLAMultiStepDraftBackend:
kv_indices_buffer,
self.kv_indptr,
forward_batch.positions,
num_seqs,
self.topk,
self.pool_len,
kv_indices_buffer.shape[1],
self.kv_indptr.shape[1],
next_power_of_2(num_seqs),
next_power_of_2(self.speculative_num_steps),
next_power_of_2(bs),
self.page_size,
)
assert forward_batch.spec_info is not None
@@ -784,13 +784,14 @@ class TritonMultiStepDraftBackend:
kv_indices_buffer,
self.kv_indptr,
forward_batch.positions,
num_seqs,
self.topk,
self.pool_len,
kv_indices_buffer.shape[1],
self.kv_indptr.shape[1],
next_power_of_2(num_seqs),
next_power_of_2(self.speculative_num_steps),
next_power_of_2(bs),
self.page_size,
)
for i in range(self.speculative_num_steps):
@@ -294,19 +294,6 @@ class MHATokenToKVPool(KVCache):
for _ in range(self.layer_num)
]
self.data_ptrs = torch.tensor(
[x.data_ptr() for x in self.k_buffer + self.v_buffer],
dtype=torch.uint64,
device=self.device,
)
self.data_strides = torch.tensor(
[
np.prod(x.shape[1:]) * x.dtype.itemsize
for x in self.k_buffer + self.v_buffer
],
device=self.device,
)
def _clear_buffers(self):
del self.k_buffer
del self.v_buffer
@@ -464,16 +451,6 @@ class MHATokenToKVPool(KVCache):
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
copy_all_layer_kv_cache[(len(self.data_ptrs),)](
self.data_ptrs,
self.data_strides,
tgt_loc,
src_loc,
len(tgt_loc),
next_power_of_2(len(tgt_loc)),
)
@triton.jit
def set_mla_kv_buffer_kernel(
@@ -764,41 +741,3 @@ class DoubleSparseTokenToKVPool(KVCache):
def transfer_per_layer(self, indices, flat_data, layer_id):
pass
@triton.jit
def copy_all_layer_kv_cache(
data_ptrs,
strides,
tgt_loc_ptr,
src_loc_ptr,
num_locs,
num_locs_upper: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 128
bid = tl.program_id(0)
stride = tl.load(strides + bid)
data_ptr = tl.load(data_ptrs + bid)
data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
num_locs_offset = tl.arange(0, num_locs_upper)
tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
# NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
# because this copy is an inplace operation.
num_loop = tl.cdiv(stride, BLOCK_SIZE)
for i in range(num_loop):
copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = (num_locs_offset < num_locs)[:, None] and (copy_offset < stride)[None, :]
value = tl.load(
data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
)
tl.store(
data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
value,
mask=mask,
)
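For orientation: the removed move_kv_cache method and this copy_all_layer_kv_cache kernel together relocate KV entries for every layer in a single fused launch. A rough, unfused PyTorch sketch of the same effect (illustration only; it assumes the per-layer k_buffer / v_buffer layout of MHATokenToKVPool):

import torch

def move_kv_cache_reference(k_buffer, v_buffer, tgt_loc, src_loc):
    # k_buffer / v_buffer: one [num_tokens, head_num, head_dim] tensor per layer.
    # Copy the rows at src_loc onto the rows at tgt_loc for every layer.
    for buf in list(k_buffer) + list(v_buffer):
        buf[tgt_loc] = buf[src_loc]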
@@ -67,6 +67,8 @@ class EagleDraftInput:
kv_indptr: torch.Tensor = None
kv_indices: torch.Tensor = None
all_padding_lens: Optional[torch.Tensor] = None
def prepare_for_extend(self, batch: ScheduleBatch):
# Prefill only generates 1 token.
assert len(self.verified_id) == len(batch.seq_lens)
@@ -91,7 +93,6 @@ class EagleDraftInput:
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
batch.return_logprob = False
batch.return_hidden_states = False
self.capture_hidden_mode = CaptureHiddenMode.LAST
self.accept_length.add_(1)
@@ -115,8 +116,10 @@ class EagleDraftInput:
req_to_token: torch.Tensor,
):
bs = self.accept_length.numel()
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
@@ -136,6 +139,7 @@ class EagleDraftInput:
kv_indices,
req_to_token.size(1),
)
return kv_indices, cum_kv_seq_len, qo_indptr, None
def filter_batch(self, new_indices: torch.Tensor):
@@ -266,7 +270,7 @@ class EagleVerifyInput:
logits_output: torch.Tensor,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
page_size: int,
vocab_mask: Optional[torch.Tensor] = None, # For grammar
vocab_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Verify and find accepted tokens based on logits output and batch
@@ -290,14 +294,6 @@ class EagleVerifyInput:
)
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
# Apply the custom logit processors if registered in the sampling info.
if sampling_info.has_custom_logit_processor:
apply_custom_logit_processor(
logits_output.next_token_logits,
sampling_info,
num_tokens_in_batch=self.draft_token_num,
)
# Apply penalty
if sampling_info.penalizer_orchestrator.is_required:
# This is a relaxed version of penalties for speculative decoding.
@@ -359,13 +355,7 @@ class EagleVerifyInput:
draft_probs = torch.zeros(
target_probs.shape, dtype=torch.float32, device="cuda"
)
# coins for rejection sampling
coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
# coins for final sampling
coins_for_final_sampling = torch.rand(
(bs,), dtype=torch.float32, device="cuda"
)
tree_speculative_sampling_target_only(
predicts=predict, # mutable
accept_index=accept_index, # mutable
@@ -375,7 +365,6 @@ class EagleVerifyInput:
retrive_next_token=self.retrive_next_token.to(torch.int32),
retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
uniform_samples=coins,
# uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=global_server_args_dict[
@@ -398,8 +387,8 @@ class EagleVerifyInput:
spec_steps=self.spec_steps,
)
new_accept_index = []
unfinished_index = []
unfinished_accept_index = []
accept_index_cpu = accept_index.tolist()
predict_cpu = predict.tolist()
has_finished = False
@@ -407,10 +396,12 @@ class EagleVerifyInput:
# Iterate over every accepted token and check whether the req has finished after appending the token.
# This must be checked BEFORE freeing the KV cache slots.
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
new_accept_index_ = []
for j, idx in enumerate(accept_index_row):
if idx == -1:
break
id = predict_cpu[idx]
# if not found_finished:
req.output_ids.append(id)
req.check_finished()
if req.finished():
@@ -419,6 +410,8 @@ class EagleVerifyInput:
accept_index[i, j + 1 :] = -1
break
else:
new_accept_index_.append(idx)
# update grammar state
if req.grammar is not None:
try:
req.grammar.accept_token(id)
@@ -428,104 +421,50 @@ class EagleVerifyInput:
)
raise e
if not req.finished():
new_accept_index.extend(new_accept_index_)
unfinished_index.append(i)
if idx == -1:
unfinished_accept_index.append(accept_index[i, :j])
else:
unfinished_accept_index.append(accept_index[i])
req.spec_verify_ct += 1
if has_finished:
accept_length = (accept_index != -1).sum(dim=1) - 1
# Free the KV cache for unaccepted tokens
# TODO: fuse them
accept_index = accept_index[accept_index != -1]
verified_id = predict[accept_index]
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
if page_size == 1:
# TODO: boolean array index leads to a device sync. Remove it.
token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
else:
if self.topk == 1:
# Only evict full empty page. Do not evict partial empty page
align_evict_mask_to_page_size[len(batch.seq_lens),](
batch.seq_lens,
evict_mask,
page_size,
self.draft_token_num,
next_power_of_2(self.draft_token_num),
)
token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
else:
# Shift the accepted tokens to the beginning.
# Only evict the last part
src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
batch.seq_lens,
batch.out_cache_loc,
accept_index,
accept_length,
self.draft_token_num,
page_size,
)
to_free_slots = torch.empty(
(to_free_num_slots.sum().item(),),
dtype=torch.int64,
device=to_free_num_slots.device,
)
# out_cache_loc: [0 1 2, 3 4 5, 6 7 8]
# accept_index: [0 -1 2, 3 4 -1, 6 -1 -1]
# tgt_cache_loc: [0 1 , 3 4 , 6 ]
# to_free_slots: [ 2, 5, 7 8]
# to_free_slots also needs to be page-aligned without the first partial page
#
# split each row of out_cache_loc into two parts.
# 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
# 2. the second part goes to to_free_slots.
get_target_cache_loc[(bs,)](
tgt_cache_loc,
to_free_slots,
accept_length,
to_free_num_slots,
batch.out_cache_loc,
self.draft_token_num,
next_power_of_2(self.draft_token_num),
next_power_of_2(bs),
)
# Free the kv cache
token_to_kv_pool_allocator.free(to_free_slots)
if page_size != 1:
align_evict_mask_to_page_size[len(batch.seq_lens),](
batch.seq_lens,
evict_mask,
page_size,
self.draft_token_num,
next_power_of_2(self.draft_token_num),
)
# Copy the kv cache
batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
tgt_cache_loc, src_cache_loc
)
token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
# Construct EagleVerifyOutput
if not has_finished:
if page_size == 1 or self.topk == 1:
batch.out_cache_loc = batch.out_cache_loc[accept_index]
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + accept_length + 1,
batch.out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
else:
batch.out_cache_loc = tgt_cache_loc
batch.out_cache_loc = batch.out_cache_loc[accept_index]
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + accept_length + 1,
batch.out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
batch.seq_lens.add_(accept_length + 1)
accept_length_cpu = accept_length.tolist()
draft_input = EagleDraftInput()
draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
draft_input.verified_id = verified_id
draft_input.accept_length = accept_length
draft_input.accept_length_cpu = accept_length.tolist()
draft_input.accept_length_cpu = accept_length_cpu
draft_input.seq_lens_for_draft_extend = batch.seq_lens
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
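The split described in the comments above (each request keeps its first accept_length + 1 slots and frees the tail) can be reproduced with plain tensor ops. A minimal sketch that ignores the page alignment the real kernels apply to to_free_slots; the helper name is made up:

import torch

def split_out_cache_loc(out_cache_loc, accept_length, draft_token_num):
    # out_cache_loc: flat [bs * draft_token_num] slot ids; accept_length: [bs].
    rows = out_cache_loc.view(-1, draft_token_num)
    kept, freed = [], []
    for row, acc in zip(rows, accept_length.tolist()):
        kept.append(row[: acc + 1])   # destination of the accepted tokens (tgt_cache_loc)
        freed.append(row[acc + 1 :])  # slots handed back to the allocator (to_free_slots)
    return torch.cat(kept), torch.cat(freed)

# out_cache_loc [0 1 2, 3 4 5, 6 7 8] with accept_length [1, 1, 0]
# -> kept [0 1, 3 4, 6] and freed [2, 5, 7 8], matching the worked example above.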
@@ -533,66 +472,47 @@ class EagleVerifyInput:
draft_input=draft_input,
logits_output=logits_output,
verified_id=verified_id,
accept_length_per_req_cpu=draft_input.accept_length_cpu,
accept_length_per_req_cpu=accept_length_cpu,
accepted_indices=accept_index,
)
else:
if page_size == 1 or self.topk == 1:
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + accept_length + 1,
batch.out_cache_loc[accept_index],
batch.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
batch.seq_lens.add_(accept_length + 1)
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + accept_length + 1,
batch.out_cache_loc[accept_index],
batch.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
batch.seq_lens.add_(accept_length + 1)
accept_length_cpu = accept_length.tolist()
draft_input = EagleDraftInput()
if len(unfinished_accept_index) > 0:
unfinished_accept_index = torch.cat(unfinished_accept_index)
unfinished_index_device = torch.tensor(
unfinished_index, dtype=torch.int64, device=predict.device
)
draft_input_accept_length_cpu = [
if len(new_accept_index) > 0:
new_accept_index = torch.tensor(new_accept_index, device="cuda")
unfinished_index_device = torch.tensor(unfinished_index, device="cuda")
draft_input.hidden_states = batch.spec_info.hidden_states[
new_accept_index
]
draft_input.verified_id = predict[new_accept_index]
draft_input.accept_length_cpu = [
accept_length_cpu[i] for i in unfinished_index
]
if page_size == 1 or self.topk == 1:
batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
else:
batch.out_cache_loc = torch.empty(
len(unfinished_index) + sum(draft_input_accept_length_cpu),
dtype=torch.int64,
device=predict.device,
)
accept_length_filter = create_accept_length_filter(
accept_length,
unfinished_index_device,
batch.seq_lens,
draft_input.accept_length = accept_length[unfinished_index_device]
if has_finished:
draft_input.seq_lens_for_draft_extend = batch.seq_lens[
unfinished_index_device
]
draft_input.req_pool_indices_for_draft_extend = (
batch.req_pool_indices[unfinished_index_device]
)
filter_finished_cache_loc_kernel[(bs,)](
batch.out_cache_loc,
tgt_cache_loc,
accept_length,
accept_length_filter,
next_power_of_2(bs),
next_power_of_2(self.draft_token_num),
else:
draft_input.seq_lens_for_draft_extend = batch.seq_lens
draft_input.req_pool_indices_for_draft_extend = (
batch.req_pool_indices
)
draft_input.hidden_states = batch.spec_info.hidden_states[
unfinished_accept_index
]
draft_input.verified_id = predict[unfinished_accept_index]
draft_input.accept_length_cpu = draft_input_accept_length_cpu
draft_input.accept_length = accept_length[unfinished_index_device]
draft_input.seq_lens_for_draft_extend = batch.seq_lens[
unfinished_index_device
]
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
unfinished_index_device
]
batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
return EagleVerifyOutput(
draft_input=draft_input,
@@ -669,75 +589,36 @@ def assign_draft_cache_locs(
req_pool_indices,
req_to_token,
seq_lens,
extend_lens,
num_new_pages_per_topk,
out_cache_loc,
pool_len: tl.constexpr,
topk: tl.constexpr,
speculative_num_steps: tl.constexpr,
page_size: tl.constexpr,
bs_upper: tl.constexpr,
iter_upper: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 128
BLOCK_SIZE: tl.constexpr = 32
pid = tl.program_id(axis=0)
kv_start = tl.load(seq_lens + pid)
if page_size == 1 or topk == 1:
copy_len = topk * speculative_num_steps
kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
else:
bs_offset = tl.arange(0, bs_upper)
copy_len = tl.load(extend_lens + pid)
cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
out_cache_ptr = out_cache_loc + cum_copy_len
prefix_len = tl.load(seq_lens + pid)
last_page_len = prefix_len % page_size
num_new_page = (
last_page_len + speculative_num_steps + page_size - 1
) // page_size
kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
# Part 1: Copy from out_cache_loc to req_to_token
kv_start = tl.load(seq_lens + pid)
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
for i in range(num_loop):
copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = copy_offset < copy_len
data = tl.load(out_cache_ptr + copy_offset, mask=mask)
tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
if page_size == 1 or topk == 1:
return
# Part 2: Copy the indices for the last partial page
prefix_len = tl.load(seq_lens + pid)
last_page_len = prefix_len % page_size
offsets = tl.arange(0, page_size)
mask = offsets < last_page_len
num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
prefix_base = token_pool + prefix_len - last_page_len
for topk_id in range(topk):
value = tl.load(prefix_base + offsets, mask=mask)
tl.store(
prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
value,
mask=mask,
)
# Part 3: Remove the padding in out_cache_loc
iter_offest = tl.arange(0, iter_upper)
for topk_id in range(topk):
indices = tl.load(
prefix_base
+ topk_id * num_new_pages_per_topk_ * page_size
+ last_page_len
+ iter_offest,
mask=iter_offest < speculative_num_steps,
)
tl.store(
out_cache_loc
+ pid * topk * speculative_num_steps
+ topk_id * speculative_num_steps
+ iter_offest,
indices,
mask=iter_offest < speculative_num_steps,
)
num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
for i in range(num_loop):
save_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + kv_start
load_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = save_offset < kv_end
data = tl.load(out_cache_ptr + load_offset, mask=mask)
tl.store(token_pool + save_offset, data, mask=mask)
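For reference, on the simple path (page_size == 1 or topk == 1) this kernel writes each request's freshly allocated draft slots into its row of req_to_token starting at seq_len. A plain PyTorch sketch of that behaviour; the function name is hypothetical:

def assign_draft_cache_locs_reference(
    req_pool_indices, req_to_token, seq_lens, out_cache_loc, topk, speculative_num_steps
):
    per_req = out_cache_loc.view(len(req_pool_indices), topk * speculative_num_steps)
    for i, (pool_idx, seq_len) in enumerate(
        zip(req_pool_indices.tolist(), seq_lens.tolist())
    ):
        end = seq_len + topk * speculative_num_steps
        req_to_token[pool_idx, seq_len:end] = per_req[i]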
@triton.jit
@@ -748,23 +629,20 @@ def generate_draft_decode_kv_indices(
kv_indices,
kv_indptr,
positions,
num_seqs: tl.constexpr,
topk: tl.constexpr,
pool_len: tl.constexpr,
kv_indices_stride: tl.constexpr,
kv_indptr_stride: tl.constexpr,
bs_upper: tl.constexpr,
iter_upper: tl.constexpr,
num_tokens_upper: tl.constexpr,
page_size: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 128
iters = tl.program_id(axis=0)
bid = tl.program_id(axis=1)
topk_id = tl.program_id(axis=2)
num_steps = tl.num_programs(axis=0)
num_seqs = tl.num_programs(axis=1)
topk = tl.num_programs(axis=2)
kv_indices += kv_indices_stride * iters
kv_indptr += kv_indptr_stride * iters
iters += 1
@@ -774,7 +652,6 @@ def generate_draft_decode_kv_indices(
seq_len = tl.load(paged_kernel_lens + bid)
cum_seq_len = tl.sum(seq_lens)
# Update kv_indices
kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
kv_ptr = kv_indices + kv_offset
token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
@@ -788,26 +665,10 @@ def generate_draft_decode_kv_indices(
kv_offset += BLOCK_SIZE
extend_offset = tl.arange(0, iter_upper)
if page_size == 1 or topk == 1:
extend_data = tl.load(
token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
mask=extend_offset < iters,
)
else:
prefix_len = seq_len
last_page_len = prefix_len % page_size
num_new_pages_per_topk = (
last_page_len + num_steps + page_size - 1
) // page_size
prefix_base = seq_len // page_size * page_size
start = (
prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
)
extend_data = tl.load(
token_pool_ptr + start + extend_offset,
mask=extend_offset < iters,
)
extend_data = tl.load(
token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id,
mask=extend_offset < iters,
)
tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
# Update kv_indptr
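On the kv_indptr side: at draft step iters, each of the bs * topk rows covers its sequence prefix plus iters draft tokens, so kv_indptr is the exclusive cumulative sum of those row lengths. A minimal sketch of that arithmetic, with a hypothetical helper name:

import torch
import torch.nn.functional as F

def draft_decode_kv_indptr(seq_lens, topk, iters):
    # seq_lens: [bs] integer tensor -> kv_indptr: [bs * topk + 1]
    row_lens = (seq_lens + iters).repeat_interleave(topk)
    return F.pad(torch.cumsum(row_lens, dim=0), (1, 0))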
@@ -846,116 +707,6 @@ def align_evict_mask_to_page_size(
tl.store(evict_mask + bid * num_draft_tokens + i, False)
@triton.jit
def get_target_cache_loc(
tgt_cache_loc,
to_free_slots,
accept_length,
to_free_num_slots,
out_cache_loc,
num_verify_tokens: tl.constexpr,
num_verify_tokens_upper: tl.constexpr,
bs_upper: tl.constexpr,
):
bid = tl.program_id(axis=0)
offset = tl.arange(0, num_verify_tokens_upper)
bs_offset = tl.arange(0, bs_upper)
# write the first part to tgt_cache_loc
accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
tgt_cache_loc_start = tl.sum(accept_len_all) + bid
copy_len = tl.load(accept_length + bid) + 1
out_cache_loc_row = tl.load(
out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
)
tl.store(
tgt_cache_loc + tgt_cache_loc_start + offset,
out_cache_loc_row,
mask=offset < copy_len,
)
# write the second part to to_free_num_pages
to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
to_free_slots_start = tl.sum(to_free_num_slots_all)
copy_len = to_free_num_slots_cur
out_cache_loc_row = tl.load(
out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
mask=offset < copy_len,
)
tl.store(
to_free_slots + to_free_slots_start + offset,
out_cache_loc_row,
mask=offset < copy_len,
)
@torch.compile(dynamic=True)
def get_src_tgt_cache_loc(
seq_lens: torch.Tensor,
out_cache_loc: torch.Tensor,
accept_index: torch.Tensor,
accept_length: torch.Tensor,
draft_token_num: int,
page_size: int,
):
src_cache_loc = out_cache_loc[accept_index]
tgt_cache_loc = torch.empty_like(src_cache_loc)
extended_len = seq_lens + draft_token_num
keep_len = torch.minimum(
(seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
extended_len,
)
to_free_num_slots = extended_len - keep_len
return src_cache_loc, tgt_cache_loc, to_free_num_slots
@triton.jit
def filter_finished_cache_loc_kernel(
out_cache_loc,
tgt_cache_loc,
accept_length,
accept_length_filter,
bs_upper: tl.constexpr,
num_verify_tokens_upper: tl.constexpr,
):
bid = tl.program_id(0)
bs_offset = tl.arange(0, bs_upper)
accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
old_start = tl.sum(accept_length_all) + bid
accept_length_filter_all = tl.load(
accept_length_filter + bs_offset, mask=bs_offset < bid
)
new_start = tl.sum(accept_length_filter_all)
copy_len = tl.load(accept_length_filter + bid)
copy_offset = tl.arange(0, num_verify_tokens_upper)
value = tl.load(
tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
)
tl.store(
out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
)
@torch.compile(dynamic=True)
def create_accept_length_filter(
accept_length: torch.Tensor,
unfinished_index_device: torch.Tensor,
seq_lens: torch.Tensor,
):
accept_length_filter = torch.zeros_like(accept_length)
accept_length_filter[unfinished_index_device] = (
accept_length[unfinished_index_device] + 1
)
seq_lens.add_(accept_length + 1)
return accept_length_filter
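Taken together, create_accept_length_filter and filter_finished_cache_loc_kernel compact tgt_cache_loc so that only unfinished requests keep their accepted slots. A loop-based sketch of the same result (illustration only; the real path also bumps seq_lens in place and avoids the Python loop):

import torch
import torch.nn.functional as F

def compact_unfinished_cache_loc(tgt_cache_loc, accept_length, unfinished_index):
    # tgt_cache_loc stores accept_length[i] + 1 slots per request, back to back.
    starts = F.pad(torch.cumsum(accept_length + 1, dim=0), (1, 0)).tolist()
    kept = [tgt_cache_loc[starts[i] : starts[i + 1]] for i in unfinished_index]
    return torch.cat(kept)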
@torch.compile(dynamic=True)
def select_top_k_tokens(
i: int,
@@ -1005,16 +756,6 @@ def select_top_k_tokens(
return input_ids, hidden_states, scores, tree_info
def fast_topk_torch(values, topk, dim):
if topk == 1:
# Use max along the specified dimension to get both value and index
max_value, max_index = torch.max(values, dim=dim)
return max_value.unsqueeze(1), max_index.unsqueeze(1)
else:
# Use topk for efficiency with larger k values
return torch.topk(values, topk, dim=dim)
def _generate_simulated_accept_index(
accept_index,
predict,
@@ -1024,35 +765,15 @@ def _generate_simulated_accept_index(
spec_steps,
):
simulate_acc_len_float = float(simulate_acc_len)
if SIMULATE_ACC_METHOD == "multinomial":
simulated_values = torch.normal(
mean=simulate_acc_len_float,
std=1.0,
size=(1,),
device="cpu",
)
# clamp simulated values to be between 1 and self.spec_steps
simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
simulate_acc_len = int(simulated_values.round().item())
elif SIMULATE_ACC_METHOD == "match-expected":
# Multinomial sampling does not match the expected length. We keep it for
# compatibility with existing tests, but it is better to use "match-expected"
# when the expected length must be matched. One caveat: this only samples
# either the floor or the ceiling of the expected length.
simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
lower = int(simulate_acc_len_float // 1)
upper = lower + 1 if lower < spec_steps + 1 else lower
if lower == upper:
simulate_acc_len = lower
else:
weight_upper = simulate_acc_len_float - lower
weight_lower = 1.0 - weight_upper
probs = torch.tensor([weight_lower, weight_upper], device="cpu")
sampled_index = torch.multinomial(probs, num_samples=1)
simulate_acc_len = lower if sampled_index == 0 else upper
else:
raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}")
simulated_values = torch.normal(
mean=simulate_acc_len_float,
std=1.0,
size=(1,),
device="cpu",
)
# clamp simulated values to be between 1 and self.spec_steps
simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps)
simulate_acc_len = int(simulated_values.round().item())
accept_indx_first_col = accept_index[:, 0].view(-1, 1)
sim_accept_index = torch.full(
@@ -1143,9 +864,9 @@ def generate_token_bitmask(
"""
Generate the logit mask for structured output.
Draft model's token can be either valid or invalid with respect to the grammar.
We need to perform DFS to
1. figure out which tokens are accepted by the grammar.
2. if so, what is the corresponding logit mask.
We need to perform DFS to figure out:
1. which tokens are accepted by the grammar
2. what is the corresponding logit mask.
"""
num_draft_tokens = draft_tokens_cpu.shape[-1]
@@ -1162,7 +883,6 @@ def generate_token_bitmask(
device="cpu",
)
grammar = req.grammar
s = time.perf_counter()
traverse_tree(
retrieve_next_token_cpu[i],
retrieve_next_sibling_cpu[i],
......@@ -1172,12 +892,6 @@ def generate_token_bitmask(
i * num_draft_tokens : (i + 1) * num_draft_tokens
],
)
tree_traverse_time = time.perf_counter() - s
if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
logger.warning(
f"Bit mask generation took {tree_traverse_time} seconds with "
f"grammar: {req.grammar}"
)
verify_input.grammar = grammar
return allocate_token_bitmask
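The docstring above describes a DFS over the draft tree. The retrive_next_token / retrive_next_sibling arrays appear to use a first-child / next-sibling encoding, with -1 meaning none, so the visiting order can be sketched as below. This shows only the traversal order, not the grammar bookkeeping that traverse_tree performs on top of it (arrays given as plain Python lists here):

def dfs_order(retrive_next_token, retrive_next_sibling, node=0, depth=0):
    # Yield (node, depth) pairs in the order a depth-first traversal visits them.
    yield node, depth
    child = retrive_next_token[node]
    while child != -1:
        yield from dfs_order(retrive_next_token, retrive_next_sibling, child, depth + 1)
        child = retrive_next_sibling[child]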
@@ -35,17 +35,11 @@ from sglang.srt.speculative.eagle_utils import (
EagleVerifyInput,
EagleVerifyOutput,
assign_draft_cache_locs,
fast_topk,
generate_token_bitmask,
select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
empty_context,
get_available_gpu_memory,
is_cuda,
next_power_of_2,
)
from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda
if is_cuda():
from sgl_kernel import segment_packbits
@@ -158,12 +152,6 @@ class EAGLEWorker(TpModelWorker):
self.init_attention_backend()
self.init_cuda_graphs()
# Some dummy tensors
self.num_new_pages_per_topk = torch.empty(
(), dtype=torch.int64, device=self.device
)
self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
def init_attention_backend(self):
# Create multi-step attn backends and cuda graph runners
if self.server_args.attention_backend == "flashinfer":
@@ -266,7 +254,7 @@ class EAGLEWorker(TpModelWorker):
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
)
# Capture extend
@@ -281,7 +269,7 @@ class EAGLEWorker(TpModelWorker):
)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
)
@property
@@ -302,6 +290,7 @@ class EAGLEWorker(TpModelWorker):
A tuple of the final logit output of the target model, next tokens accepted,
the batch id (used for overlap schedule), and number of accepted tokens.
"""
if batch.forward_mode.is_decode():
with self.draft_tp_context(self.draft_model_runner.tp_group):
spec_info = self.draft(batch)
@@ -377,21 +366,14 @@ class EAGLEWorker(TpModelWorker):
)
# Allocate cache locations
# Layout of the out_cache_loc
# [ topk 0 ] [ topk 1 ]
# [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
if self.page_size == 1:
out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
num_seqs * self.speculative_num_steps * self.topk, backup_state=True
num_seqs * self.topk * self.speculative_num_steps, backup_state=True
)
else:
if self.topk == 1:
prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1(
batch.req_to_token_pool.req_to_token,
batch.req_pool_indices,
batch.seq_lens,
self.speculative_num_steps,
)
prefix_lens = batch.seq_lens
seq_lens = prefix_lens + self.speculative_num_steps
extend_num_tokens = num_seqs * self.speculative_num_steps
else:
# In this case, the last partial page needs to be duplicated.
@@ -404,33 +386,29 @@ class EAGLEWorker(TpModelWorker):
# "x" means speculative draft tokens
# "." means padded tokens
# TODO(lmzheng): The current implementation is still a fake support
# for page size > 1. In the `assign_draft_cache_locs` below,
# we directly move the indices instead of the real kv cache.
# This only works when the kernel backend runs with page size = 1.
# If the kernel backend runs with page size > 1, we need to
# duplicate the real KV cache. The overhead of duplicating KV
# cache seems okay because the draft KV cache only has one layer.
# see a related copy operation in MHATokenToKVPool::move_kv_cache.
(
prefix_lens,
seq_lens,
last_loc,
self.num_new_pages_per_topk,
self.extend_lens,
) = get_last_loc_large_page_size_large_top_k(
batch.req_to_token_pool.req_to_token,
batch.req_pool_indices,
batch.seq_lens,
self.speculative_num_steps,
self.topk,
self.page_size,
# TODO: fuse these ops
prefix_lens = batch.seq_lens
last_page_lens = prefix_lens % self.page_size
num_new_pages = (
last_page_lens + self.speculative_num_steps + self.page_size - 1
) // self.page_size
seq_lens = (
prefix_lens // self.page_size * self.page_size
+ num_new_pages * (self.page_size * self.topk)
)
# TODO(lmzheng): remove this device sync
extend_num_tokens = torch.sum(self.extend_lens).item()
extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
raise NotImplementedError(
"page_size > 1 and top_k > 1 are not supported."
)
# TODO: Support page_size > 1 and top_k > 1
# 1. Duplicate the KV cache in the last partial page for all top-k segments
# 2. Modify generate_draft_decode_kv_indices accordingly
last_loc = get_last_loc(
batch.req_to_token_pool.req_to_token,
batch.req_pool_indices,
prefix_lens,
)
out_cache_loc, token_to_kv_pool_state_backup = (
batch.alloc_paged_token_slots_extend(
prefix_lens,
@@ -445,31 +423,19 @@ class EAGLEWorker(TpModelWorker):
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
self.extend_lens,
self.num_new_pages_per_topk,
out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
self.topk,
self.speculative_num_steps,
self.page_size,
next_power_of_2(num_seqs),
next_power_of_2(self.speculative_num_steps),
)
if self.page_size > 1 and self.topk > 1:
# Remove padded slots
out_cache_loc = out_cache_loc[
: num_seqs * self.topk * self.speculative_num_steps
]
batch.out_cache_loc = out_cache_loc
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
batch.return_hidden_states = False
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
# Get forward batch
batch.return_hidden_states = False
model_worker_batch = batch.get_model_worker_batch()
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
@@ -538,13 +504,6 @@ class EAGLEWorker(TpModelWorker):
if self.hot_token_id is not None:
topk_index = self.hot_token_id[topk_index]
out_cache_loc = out_cache_loc.reshape(
forward_batch.batch_size, self.topk, self.speculative_num_steps
)
out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
self.speculative_num_steps, -1
)
# Return values
score_list: List[torch.Tensor] = []
token_list: List[torch.Tensor] = []
@@ -566,7 +525,10 @@ class EAGLEWorker(TpModelWorker):
# Set inputs
forward_batch.input_ids = input_ids
forward_batch.out_cache_loc = out_cache_loc[i:]
out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
forward_batch.out_cache_loc = out_cache_loc[
:, self.topk * i : self.topk * (i + 1)
].flatten()
forward_batch.positions.add_(1)
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
spec_info.hidden_states = hidden_states
@@ -624,7 +586,7 @@ class EAGLEWorker(TpModelWorker):
if vocab_mask is not None:
assert spec_info.grammar is not None
vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
# NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage
# otherwise, this vocab mask will be the one from the previous extend stage
# and will be applied to produce wrong results
batch.sampling_info.vocab_mask = None
@@ -645,13 +607,13 @@ class EAGLEWorker(TpModelWorker):
]
logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
if batch.return_logprob:
self.add_logprob_values(batch, res, logits_output)
# Prepare the batch for the next draft forwards.
batch.forward_mode = ForwardMode.DECODE
batch.spec_info = res.draft_input
if batch.return_logprob:
self.add_logprob_values(batch, res, logits_output)
return logits_output, res, model_worker_batch, can_run_cuda_graph
def add_logprob_values(
@@ -664,16 +626,8 @@ class EAGLEWorker(TpModelWorker):
logits_output = res.logits_output
top_logprobs_nums = batch.top_logprobs_nums
token_ids_logprobs = batch.token_ids_logprobs
accepted_indices = res.accepted_indices
assert len(accepted_indices) == len(logits_output.next_token_logits)
temperatures = batch.sampling_info.temperatures
num_draft_tokens = batch.spec_info.draft_token_num
# acceptance indices are the indices in a "flattened" batch.
# dividing it by num_draft_tokens yields the actual batch index.
temperatures = temperatures[accepted_indices // num_draft_tokens]
logprobs = torch.nn.functional.log_softmax(
logits_output.next_token_logits / temperatures, dim=-1
logits_output.next_token_logits, dim=-1
)
batch_next_token_ids = res.verified_id
num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
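The accepted_indices // num_draft_tokens arithmetic mentioned in the comment above maps flattened row indices back to request indices. A tiny worked example with made-up numbers:

import torch

# Rows of the flattened [bs * draft_token_num, vocab] logits, with draft_token_num = 4:
accepted_indices = torch.tensor([0, 2, 5, 9])
req_index = accepted_indices // 4   # tensor([0, 0, 1, 2]): requests 0, 0, 1 and 2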
@@ -708,7 +662,7 @@ class EAGLEWorker(TpModelWorker):
pt = 0
next_token_logprobs = logits_output.next_token_logprobs.tolist()
verified_ids = batch_next_token_ids.tolist()
for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True):
for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
for _ in range(num_tokens):
if req.return_logprob:
req.output_token_logprobs_val.append(next_token_logprobs[pt])
@@ -736,6 +690,7 @@ class EAGLEWorker(TpModelWorker):
hidden_states: Hidden states from the target model forward
next_token_ids: Next token ids generated from the target forward.
"""
# Sometimes we get hidden states produced by CaptureHiddenMode.FULL, so we have to select just the last
batch.spec_info = EagleDraftInput(
hidden_states=hidden_states,
verified_id=next_token_ids,
@@ -746,6 +701,7 @@ class EAGLEWorker(TpModelWorker):
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=seq_lens_cpu
)
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
@@ -768,7 +724,9 @@ class EAGLEWorker(TpModelWorker):
batch,
self.speculative_num_steps,
)
batch.return_hidden_states = False
model_worker_batch = batch.get_model_worker_batch()
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
@@ -832,47 +790,3 @@ def load_token_map(token_map_path: str) -> List[int]:
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
hot_token_id = torch.load(token_map_path, weights_only=True)
return torch.tensor(hot_token_id, dtype=torch.int32)
@torch.compile(dynamic=True)
def get_last_loc_large_page_size_top_k_1(
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
seq_lens,
speculative_num_steps: int,
):
prefix_lens = seq_lens
seq_lens = prefix_lens + speculative_num_steps
last_loc = get_last_loc(
req_to_token,
req_pool_indices,
prefix_lens,
)
return prefix_lens, seq_lens, last_loc
@torch.compile(dynamic=True)
def get_last_loc_large_page_size_large_top_k(
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
speculative_num_steps: int,
topk: int,
page_size: int,
):
prefix_lens = seq_lens
last_page_lens = prefix_lens % page_size
num_new_pages_per_topk = (
last_page_lens + speculative_num_steps + page_size - 1
) // page_size
seq_lens = prefix_lens // page_size * page_size + num_new_pages_per_topk * (
page_size * topk
)
extend_lens = seq_lens - prefix_lens
last_loc = get_last_loc(
req_to_token,
req_pool_indices,
prefix_lens,
)
return prefix_lens, seq_lens, last_loc, num_new_pages_per_topk, extend_lens
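The page arithmetic in get_last_loc_large_page_size_large_top_k is easiest to check with concrete numbers. A small self-checking sketch using scalar ints instead of tensors:

def draft_alloc_lens(prefix_len: int, steps: int, topk: int, page_size: int):
    last_page_len = prefix_len % page_size
    num_new_pages_per_topk = (last_page_len + steps + page_size - 1) // page_size
    seq_len = prefix_len // page_size * page_size + num_new_pages_per_topk * page_size * topk
    return seq_len, seq_len - prefix_len

# page_size=4, prefix_len=6, steps=5, topk=2: the last partial page holds 2 tokens,
# so each top-k branch needs 2 new pages; the allocated length is 4 + 2 * 4 * 2 = 20
# and the extend length is 20 - 6 = 14.
assert draft_alloc_lens(6, 5, 2, 4) == (20, 14)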
@@ -441,71 +441,5 @@ class TestEAGLEServerTriton(TestEAGLEServer):
)
class TestEAGLEServerPageSize(TestEAGLEServer):
@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--speculative-algorithm",
"EAGLE",
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
5,
"--speculative-eagle-topk",
1,
"--speculative-num-draft-tokens",
6,
"--mem-fraction-static",
0.7,
"--chunked-prefill-size",
128,
"--max-running-requests",
8,
"--page-size",
4,
"--attention-backend",
"flashinfer",
],
)
class TestEAGLEServerPageSizeTopk(TestEAGLEServer):
@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--speculative-algorithm",
"EAGLE",
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
5,
"--speculative-eagle-topk",
8,
"--speculative-num-draft-tokens",
64,
"--mem-fraction-static",
0.7,
"--chunked-prefill-size",
128,
"--max-running-requests",
8,
"--page-size",
4,
"--attention-backend",
"flashinfer",
],
)
if __name__ == "__main__":
unittest.main()