Unverified commit b1286a11, authored by Lianmin Zheng and committed by GitHub

[EAGLE] Refactor code for page size > 1 & more simplifications (#7213)

parent 21615cc3
@@ -1049,14 +1049,13 @@ class FlashInferMultiStepDraftBackend:
             kv_indices_buffer,
             self.kv_indptr,
             forward_batch.positions,
-            num_seqs,
-            self.topk,
             self.pool_len,
             kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
             next_power_of_2(num_seqs),
             next_power_of_2(self.speculative_num_steps),
             next_power_of_2(bs),
+            self.page_size,
         )
         assert forward_batch.spec_info is not None
...
@@ -789,6 +789,7 @@ class FlashInferMLAMultiStepDraftBackend:
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+        self.page_size = model_runner.server_args.page_size

     def common_template(
         self,
@@ -809,14 +810,13 @@ class FlashInferMLAMultiStepDraftBackend:
             kv_indices_buffer,
             self.kv_indptr,
             forward_batch.positions,
-            num_seqs,
-            self.topk,
             self.pool_len,
             kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
             next_power_of_2(num_seqs),
             next_power_of_2(self.speculative_num_steps),
             next_power_of_2(bs),
+            self.page_size,
         )
         assert forward_batch.spec_info is not None
...
@@ -2,9 +2,6 @@ from __future__ import annotations
 """
 Support attention backend for FlashMLA.
-
-#TODO
-Enable speculative sampling in FlashMLA
 """

 from dataclasses import dataclass
...
@@ -784,14 +784,13 @@ class TritonMultiStepDraftBackend:
             kv_indices_buffer,
             self.kv_indptr,
             forward_batch.positions,
-            num_seqs,
-            self.topk,
             self.pool_len,
             kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
             next_power_of_2(num_seqs),
             next_power_of_2(self.speculative_num_steps),
             next_power_of_2(bs),
+            self.page_size,
         )

         for i in range(self.speculative_num_steps):
...
@@ -294,6 +294,19 @@ class MHATokenToKVPool(KVCache):
             for _ in range(self.layer_num)
         ]

+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_buffer + self.v_buffer],
+            dtype=torch.uint64,
+            device=self.device,
+        )
+        self.data_strides = torch.tensor(
+            [
+                np.prod(x.shape[1:]) * x.dtype.itemsize
+                for x in self.k_buffer + self.v_buffer
+            ],
+            device=self.device,
+        )
+
     def _clear_buffers(self):
         del self.k_buffer
         del self.v_buffer
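For reference, each entry of `data_strides` is the number of bytes one token slot occupies in the corresponding K or V layer buffer. A minimal sketch of the same computation, assuming a hypothetical fp16 buffer of shape `[num_slots, head_num, head_dim]`:

import numpy as np
import torch

# Hypothetical K buffer: 1024 slots, 8 heads, head_dim 128, fp16.
x = torch.empty(1024, 8, 128, dtype=torch.float16)
stride_bytes = np.prod(x.shape[1:]) * x.dtype.itemsize  # 8 * 128 * 2 = 2048 bytes
print(stride_bytes)  # per-slot byte stride, as stored in self.data_strides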
@@ -451,6 +464,16 @@ class MHATokenToKVPool(KVCache):
         self.k_buffer[layer_id - self.start_layer][loc] = cache_k
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v

+    def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
+        copy_all_layer_kv_cache[(len(self.data_ptrs),)](
+            self.data_ptrs,
+            self.data_strides,
+            tgt_loc,
+            src_loc,
+            len(tgt_loc),
+            next_power_of_2(len(tgt_loc)),
+        )
+

 @triton.jit
 def set_mla_kv_buffer_kernel(
@@ -741,3 +764,41 @@ class DoubleSparseTokenToKVPool(KVCache):
     def transfer_per_layer(self, indices, flat_data, layer_id):
         pass
+
+
+@triton.jit
+def copy_all_layer_kv_cache(
+    data_ptrs,
+    strides,
+    tgt_loc_ptr,
+    src_loc_ptr,
+    num_locs,
+    num_locs_upper: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 128
+
+    bid = tl.program_id(0)
+    stride = tl.load(strides + bid)
+    data_ptr = tl.load(data_ptrs + bid)
+    data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
+
+    num_locs_offset = tl.arange(0, num_locs_upper)
+    tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
+    src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
+
+    # NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
+    # because this copy is an inplace operation.
+
+    num_loop = tl.cdiv(stride, BLOCK_SIZE)
+    for i in range(num_loop):
+        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = (num_locs_offset < num_locs)[:, None] and (copy_offset < stride)[None, :]
+        value = tl.load(
+            data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
+        )
+        tl.store(
+            data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
+            value,
+            mask=mask,
+        )
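For intuition, `move_kv_cache` launches one Triton program per K/V layer buffer (the grid is `(len(self.data_ptrs),)`), and each program copies the rows addressed by `src_loc` onto the rows addressed by `tgt_loc`, tiling the per-slot byte stride in blocks of 128. A pure-PyTorch reference of the same semantics, written here only as a mental model with hypothetical buffer shapes:

import torch

def move_kv_cache_reference(k_buffers, v_buffers, tgt_loc, src_loc):
    # Copy the KV rows at src_loc onto tgt_loc in every layer buffer.
    # The Triton kernel keeps all locations inside a single program (one per
    # buffer) because the copy is in-place and must not race with itself.
    for buf in list(k_buffers) + list(v_buffers):
        buf[tgt_loc] = buf[src_loc]

# Toy usage: one layer, two slots moved.
k = [torch.randn(16, 2, 4)]
v = [torch.randn(16, 2, 4)]
move_kv_cache_reference(k, v, tgt_loc=torch.tensor([6, 7]), src_loc=torch.tensor([2, 3]))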
@@ -35,11 +35,17 @@ from sglang.srt.speculative.eagle_utils import (
     EagleVerifyInput,
     EagleVerifyOutput,
     assign_draft_cache_locs,
+    fast_topk,
     generate_token_bitmask,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda
+from sglang.srt.utils import (
+    empty_context,
+    get_available_gpu_memory,
+    is_cuda,
+    next_power_of_2,
+)

 if is_cuda():
     from sgl_kernel import segment_packbits
@@ -152,6 +158,12 @@ class EAGLEWorker(TpModelWorker):
         self.init_attention_backend()
         self.init_cuda_graphs()

+        # Some dummy tensors
+        self.num_new_pages_per_topk = torch.empty(
+            (), dtype=torch.int64, device=self.device
+        )
+        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
+
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
         if self.server_args.attention_backend == "flashinfer":
@@ -254,7 +266,7 @@ class EAGLEWorker(TpModelWorker):
             self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
             after_mem = get_available_gpu_memory(self.device, self.gpu_id)
             logger.info(
-                f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+                f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
             )

         # Capture extend
@@ -269,7 +281,7 @@ class EAGLEWorker(TpModelWorker):
             )
             after_mem = get_available_gpu_memory(self.device, self.gpu_id)
             logger.info(
-                f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+                f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
             )

     @property
@@ -290,7 +302,6 @@ class EAGLEWorker(TpModelWorker):
             A tuple of the final logit output of the target model, next tokens accepted,
             the batch id (used for overlap schedule), and number of accepted tokens.
         """
-
         if batch.forward_mode.is_decode():
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
@@ -366,14 +377,21 @@ class EAGLEWorker(TpModelWorker):
         )

         # Allocate cache locations
+        # Layout of the out_cache_loc
+        #   [       topk 0        ] [       topk 1        ]
+        #   [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
         if self.page_size == 1:
             out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
-                num_seqs * self.topk * self.speculative_num_steps, backup_state=True
+                num_seqs * self.speculative_num_steps * self.topk, backup_state=True
             )
         else:
             if self.topk == 1:
-                prefix_lens = batch.seq_lens
-                seq_lens = prefix_lens + self.speculative_num_steps
+                prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1(
+                    batch.req_to_token_pool.req_to_token,
+                    batch.req_pool_indices,
+                    batch.seq_lens,
+                    self.speculative_num_steps,
+                )
                 extend_num_tokens = num_seqs * self.speculative_num_steps
             else:
                 # In this case, the last partial page needs to be duplicated.
@@ -386,29 +404,33 @@ class EAGLEWorker(TpModelWorker):
                 # "x" means speculative draft tokens
                 # "." means padded tokens

-                # TODO: fuse these ops
-                prefix_lens = batch.seq_lens
-                last_page_lens = prefix_lens % self.page_size
-                num_new_pages = (
-                    last_page_lens + self.speculative_num_steps + self.page_size - 1
-                ) // self.page_size
-                seq_lens = (
-                    prefix_lens // self.page_size * self.page_size
-                    + num_new_pages * (self.page_size * self.topk)
-                )
-                extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
-                raise NotImplementedError(
-                    "page_size > 1 and top_k > 1 are not supported."
-                )
-                # TODO: Support page_size > 1 and top_k > 1
-                # 1. Duplicate the KV cache in the last partial page for all top-k segments
-                # 2. Modify generate_draft_decode_kv_indices accordingly
-
-            last_loc = get_last_loc(
-                batch.req_to_token_pool.req_to_token,
-                batch.req_pool_indices,
-                prefix_lens,
-            )
+                # TODO(lmzheng): The current implementation is still a fake support
+                # for page size > 1. In the `assign_draft_cache_locs` below,
+                # we directly move the indices instead of the real kv cache.
+                # This only works when the kernel backend runs with page size = 1.
+                # If the kernel backend runs with page size > 1, we need to
+                # duplicate the real KV cache. The overhead of duplicating KV
+                # cache seems okay because the draft KV cache only has one layer.
+                # see a related copy operation in MHATokenToKVPool::move_kv_cache.
+                (
+                    prefix_lens,
+                    seq_lens,
+                    last_loc,
+                    self.num_new_pages_per_topk,
+                    self.extend_lens,
+                ) = get_last_loc_large_page_size_large_top_k(
+                    batch.req_to_token_pool.req_to_token,
+                    batch.req_pool_indices,
+                    batch.seq_lens,
+                    self.speculative_num_steps,
+                    self.topk,
+                    self.page_size,
+                )
+
+                # TODO(lmzheng): remove this device sync
+                extend_num_tokens = torch.sum(self.extend_lens).item()
+
             out_cache_loc, token_to_kv_pool_state_backup = (
                 batch.alloc_paged_token_slots_extend(
                     prefix_lens,
...
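The TODO(lmzheng) note above draws a line between two strategies for the last partial page when topk > 1: today the code only remaps slot indices, which is valid only as long as the attention kernels effectively run with page size 1; a true page-size > 1 backend would have to duplicate the KV rows themselves (cf. `MHATokenToKVPool.move_kv_cache`). A toy sketch of the distinction, with hypothetical tensor names and sizes:

import torch

# One request with prefix length 10, page_size 4, topk 2: the last page holds
# two valid tokens whose KV lives in slots 8 and 9 of a single draft layer.
kv_pool = torch.randn(64, 8)
partial_page_slots = torch.tensor([8, 9])

# Strategy used today ("fake" page-size support): every top-k branch reuses the
# same slot indices in its page table, so no KV bytes are copied.
branch_page_tables = [partial_page_slots.clone() for _ in range(2)]

# Strategy a real page-size > 1 backend needs: give each branch its own page
# and duplicate the partial page's KV rows into it before appending new tokens.
branch_new_slots = [torch.tensor([16, 17]), torch.tensor([20, 21])]
for slots in branch_new_slots:
    kv_pool[slots] = kv_pool[partial_page_slots]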
@@ -423,19 +445,31 @@ class EAGLEWorker(TpModelWorker):
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
             batch.seq_lens,
+            self.extend_lens,
+            self.num_new_pages_per_topk,
             out_cache_loc,
             batch.req_to_token_pool.req_to_token.shape[1],
             self.topk,
             self.speculative_num_steps,
             self.page_size,
+            next_power_of_2(num_seqs),
+            next_power_of_2(self.speculative_num_steps),
         )

+        if self.page_size > 1 and self.topk > 1:
+            # Remove padded slots
+            out_cache_loc = out_cache_loc[
+                : num_seqs * self.topk * self.speculative_num_steps
+            ]
+
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
-        batch.return_hidden_states = False
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        batch.return_hidden_states = False
+
+        # Get forward batch
         model_worker_batch = batch.get_model_worker_batch()
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -449,9 +483,6 @@ class EAGLEWorker(TpModelWorker):
         else:
             # Initialize attention backend
             self.draft_attn_backend.init_forward_metadata(forward_batch)
-            forward_batch = ForwardBatch.init_new(
-                model_worker_batch, self.draft_model_runner
-            )

         # Run forward steps
         score_list, token_list, parents_list = self.draft_forward(forward_batch)
@@ -504,6 +535,13 @@ class EAGLEWorker(TpModelWorker):
         if self.hot_token_id is not None:
             topk_index = self.hot_token_id[topk_index]

+        out_cache_loc = out_cache_loc.reshape(
+            forward_batch.batch_size, self.topk, self.speculative_num_steps
+        )
+        out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
+            self.speculative_num_steps, -1
+        )
+
         # Return values
         score_list: List[torch.Tensor] = []
         token_list: List[torch.Tensor] = []
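The reshape/permute above converts `out_cache_loc` from the per-request layout documented earlier (all iterations of topk 0, then all iterations of topk 1, and so on) into one row per draft step, so the forward loop can take `out_cache_loc[i]` directly. A small sanity-check sketch with made-up sizes:

import torch

bs, topk, steps = 2, 2, 3
out_cache_loc = torch.arange(bs * topk * steps)  # slots in allocation order

per_step = out_cache_loc.reshape(bs, topk, steps).permute((2, 0, 1)).reshape(steps, -1)
# per_step[i] holds the bs * topk slots written at draft step i; for this toy
# layout, per_step[0] == tensor([0, 3, 6, 9]).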
@@ -525,10 +563,7 @@ class EAGLEWorker(TpModelWorker):
             # Set inputs
             forward_batch.input_ids = input_ids
-            out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
-            forward_batch.out_cache_loc = out_cache_loc[
-                :, self.topk * i : self.topk * (i + 1)
-            ].flatten()
+            forward_batch.out_cache_loc = out_cache_loc[i]
             forward_batch.positions.add_(1)
             forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
             spec_info.hidden_states = hidden_states
@@ -586,7 +621,7 @@ class EAGLEWorker(TpModelWorker):
         if vocab_mask is not None:
             assert spec_info.grammar is not None
             vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
-            # otherwise, this vocab mask will be the one from the previous extend stage
+            # NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage
             # and will be applied to produce wrong results
             batch.sampling_info.vocab_mask = None
@@ -607,13 +642,13 @@ class EAGLEWorker(TpModelWorker):
             ]
         logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]

+        if batch.return_logprob:
+            self.add_logprob_values(batch, res, logits_output)
+
         # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
         batch.spec_info = res.draft_input

-        if batch.return_logprob:
-            self.add_logprob_values(batch, res, logits_output)
-
         return logits_output, res, model_worker_batch, can_run_cuda_graph

     def add_logprob_values(
@@ -626,8 +661,16 @@ class EAGLEWorker(TpModelWorker):
         logits_output = res.logits_output
         top_logprobs_nums = batch.top_logprobs_nums
         token_ids_logprobs = batch.token_ids_logprobs
+        accepted_indices = res.accepted_indices
+        assert len(accepted_indices) == len(logits_output.next_token_logits)
+        temperatures = batch.sampling_info.temperatures
+        num_draft_tokens = batch.spec_info.draft_token_num
+        # acceptance indices are the indices in a "flattened" batch.
+        # dividing it to num_draft_tokens will yield the actual batch index.
+        temperatures = temperatures[accepted_indices // num_draft_tokens]
+
         logprobs = torch.nn.functional.log_softmax(
-            logits_output.next_token_logits, dim=-1
+            logits_output.next_token_logits / temperatures, dim=-1
         )
         batch_next_token_ids = res.verified_id
         num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
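Because verification flattens the batch into `draft_token_num` candidate positions per request, `accepted_indices // num_draft_tokens` recovers each accepted token's request index, and that request's sampling temperature is used to scale the logits before `log_softmax`. A self-contained sketch of the indexing with made-up numbers:

import torch

num_draft_tokens = 4
temperatures = torch.tensor([[0.5], [1.0]])        # per-request sampling temps
accepted_indices = torch.tensor([0, 1, 5])         # flattened accepted positions
next_token_logits = torch.randn(3, 32)             # logits kept for those positions

temps = temperatures[accepted_indices // num_draft_tokens]  # -> requests [0, 0, 1]
logprobs = torch.nn.functional.log_softmax(next_token_logits / temps, dim=-1)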
@@ -662,7 +705,7 @@ class EAGLEWorker(TpModelWorker):
         pt = 0
         next_token_logprobs = logits_output.next_token_logprobs.tolist()
         verified_ids = batch_next_token_ids.tolist()
-        for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
+        for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True):
             for _ in range(num_tokens):
                 if req.return_logprob:
                     req.output_token_logprobs_val.append(next_token_logprobs[pt])
@@ -690,7 +733,6 @@ class EAGLEWorker(TpModelWorker):
             hidden_states: Hidden states from the target model forward
             next_token_ids: Next token ids generated from the target forward.
         """
-        # Sometimes we get hidden states produced by CaptureHiddenMode.FULL, so we have to select just the last
         batch.spec_info = EagleDraftInput(
             hidden_states=hidden_states,
             verified_id=next_token_ids,
@@ -701,7 +743,6 @@ class EAGLEWorker(TpModelWorker):
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=seq_lens_cpu
         )
-        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -724,9 +765,7 @@ class EAGLEWorker(TpModelWorker):
             batch,
             self.speculative_num_steps,
         )
-        batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
-        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -790,3 +829,47 @@ def load_token_map(token_map_path: str) -> List[int]:
         token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
     hot_token_id = torch.load(token_map_path, weights_only=True)
     return torch.tensor(hot_token_id, dtype=torch.int32)
+
+
+@torch.compile(dynamic=True)
+def get_last_loc_large_page_size_top_k_1(
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    seq_lens,
+    speculative_num_steps: int,
+):
+    prefix_lens = seq_lens
+    seq_lens = prefix_lens + speculative_num_steps
+    last_loc = get_last_loc(
+        req_to_token,
+        req_pool_indices,
+        prefix_lens,
+    )
+    return prefix_lens, seq_lens, last_loc
+
+
+@torch.compile(dynamic=True)
+def get_last_loc_large_page_size_large_top_k(
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    seq_lens: torch.Tensor,
+    speculative_num_steps: int,
+    topk: int,
+    page_size: int,
+):
+    prefix_lens = seq_lens
+    last_page_lens = prefix_lens % page_size
+    num_new_pages_per_topk = (
+        last_page_lens + speculative_num_steps + page_size - 1
+    ) // page_size
+    seq_lens = prefix_lens // page_size * page_size + num_new_pages_per_topk * (
+        page_size * topk
+    )
+    extend_lens = seq_lens - prefix_lens
+    last_loc = get_last_loc(
+        req_to_token,
+        req_pool_indices,
+        prefix_lens,
+    )
+    return prefix_lens, seq_lens, last_loc, num_new_pages_per_topk, extend_lens
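A worked example of the page arithmetic in `get_last_loc_large_page_size_large_top_k`, with numbers chosen purely for illustration: for `prefix_len = 10`, `page_size = 4`, `speculative_num_steps = 5`, and `topk = 2`, the last page holds 2 tokens, each top-k branch needs `ceil((2 + 5) / 4) = 2` fresh pages, and the slot count grows from the page-aligned prefix (8) by `2 * (4 * 2) = 16`:

import torch

prefix_lens = torch.tensor([10])
page_size, speculative_num_steps, topk = 4, 5, 2

last_page_lens = prefix_lens % page_size                        # tensor([2])
num_new_pages_per_topk = (
    last_page_lens + speculative_num_steps + page_size - 1
) // page_size                                                   # tensor([2])
seq_lens = prefix_lens // page_size * page_size + num_new_pages_per_topk * (
    page_size * topk
)                                                                # tensor([24])
extend_lens = seq_lens - prefix_lens                             # tensor([14])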
@@ -441,5 +441,71 @@ class TestEAGLEServerTriton(TestEAGLEServer):
             )


+class TestEAGLEServerPageSize(TestEAGLEServer):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--speculative-algorithm",
+                "EAGLE",
+                "--speculative-draft-model-path",
+                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+                "--speculative-num-steps",
+                5,
+                "--speculative-eagle-topk",
+                1,
+                "--speculative-num-draft-tokens",
+                6,
+                "--mem-fraction-static",
+                0.7,
+                "--chunked-prefill-size",
+                128,
+                "--max-running-requests",
+                8,
+                "--page-size",
+                4,
+                "--attention-backend",
+                "flashinfer",
+            ],
+        )
+
+
+class TestEAGLEServerPageSizeTopk(TestEAGLEServer):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--speculative-algorithm",
+                "EAGLE",
+                "--speculative-draft-model-path",
+                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+                "--speculative-num-steps",
+                5,
+                "--speculative-eagle-topk",
+                8,
+                "--speculative-num-draft-tokens",
+                64,
+                "--mem-fraction-static",
+                0.7,
+                "--chunked-prefill-size",
+                128,
+                "--max-running-requests",
+                8,
+                "--page-size",
+                4,
+                "--attention-backend",
+                "flashinfer",
+            ],
+        )
+
+
 if __name__ == "__main__":
     unittest.main()