Unverified commit 5f1ab327, authored by Lianmin Zheng, committed by GitHub

[EAGLE] Refactor code for page size > 1 & more simplifications (#7163)

parent 7df7c679
......@@ -1049,14 +1049,13 @@ class FlashInferMultiStepDraftBackend:
kv_indices_buffer,
self.kv_indptr,
forward_batch.positions,
num_seqs,
self.topk,
self.pool_len,
kv_indices_buffer.shape[1],
self.kv_indptr.shape[1],
next_power_of_2(num_seqs),
next_power_of_2(self.speculative_num_steps),
next_power_of_2(bs),
self.page_size,
)
assert forward_batch.spec_info is not None
......
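Note: this same change is repeated for all three multi-step draft backends (FlashInfer, FlashInfer-MLA, and Triton below): the call sites now pass next_power_of_2(num_seqs) and next_power_of_2(self.speculative_num_steps) instead of the raw counts, and additionally forward self.page_size. The rounding is needed because Triton requires tl.arange bounds to be compile-time powers of two; the excess lanes are masked inside the kernel. A minimal sketch of such a helper, assuming the semantics of sglang.srt.utils.next_power_of_2 (the real implementation may differ):

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n. Triton's tl.arange(0, bound) requires the
    # bound to be a power of two, so kernels take this rounded value as a
    # tl.constexpr and mask out the extra lanes with the real count.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()
```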
......@@ -789,6 +789,7 @@ class FlashInferMLAMultiStepDraftBackend:
# Cached variables for generate_draft_decode_kv_indices
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
self.page_size = model_runner.server_args.page_size
def common_template(
self,
......@@ -809,14 +810,13 @@ class FlashInferMLAMultiStepDraftBackend:
kv_indices_buffer,
self.kv_indptr,
forward_batch.positions,
num_seqs,
self.topk,
self.pool_len,
kv_indices_buffer.shape[1],
self.kv_indptr.shape[1],
next_power_of_2(num_seqs),
next_power_of_2(self.speculative_num_steps),
next_power_of_2(bs),
self.page_size,
)
assert forward_batch.spec_info is not None
......
......@@ -784,14 +784,13 @@ class TritonMultiStepDraftBackend:
kv_indices_buffer,
self.kv_indptr,
forward_batch.positions,
num_seqs,
self.topk,
self.pool_len,
kv_indices_buffer.shape[1],
self.kv_indptr.shape[1],
next_power_of_2(num_seqs),
next_power_of_2(self.speculative_num_steps),
next_power_of_2(bs),
self.page_size,
)
for i in range(self.speculative_num_steps):
......
......@@ -294,6 +294,19 @@ class MHATokenToKVPool(KVCache):
for _ in range(self.layer_num)
]
self.data_ptrs = torch.tensor(
[x.data_ptr() for x in self.k_buffer + self.v_buffer],
dtype=torch.uint64,
device=self.device,
)
self.data_strides = torch.tensor(
[
np.prod(x.shape[1:]) * x.dtype.itemsize
for x in self.k_buffer + self.v_buffer
],
device=self.device,
)
def _clear_buffers(self):
del self.k_buffer
del self.v_buffer
......@@ -451,6 +464,16 @@ class MHATokenToKVPool(KVCache):
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
copy_all_layer_kv_cache[(len(self.data_ptrs),)](
self.data_ptrs,
self.data_strides,
tgt_loc,
src_loc,
len(tgt_loc),
next_power_of_2(len(tgt_loc)),
)
@triton.jit
def set_mla_kv_buffer_kernel(
......@@ -741,3 +764,41 @@ class DoubleSparseTokenToKVPool(KVCache):
def transfer_per_layer(self, indices, flat_data, layer_id):
pass
@triton.jit
def copy_all_layer_kv_cache(
data_ptrs,
strides,
tgt_loc_ptr,
src_loc_ptr,
num_locs,
num_locs_upper: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 128
bid = tl.program_id(0)
stride = tl.load(strides + bid)
data_ptr = tl.load(data_ptrs + bid)
data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
num_locs_offset = tl.arange(0, num_locs_upper)
tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
    # NOTE: we cannot parallelize over the tgt_loc_ptr dimension with CUDA blocks
    # because this copy is an in-place operation.
num_loop = tl.cdiv(stride, BLOCK_SIZE)
for i in range(num_loop):
copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = (num_locs_offset < num_locs)[:, None] and (copy_offset < stride)[None, :]
value = tl.load(
data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
)
tl.store(
data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
value,
mask=mask,
)
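The new copy_all_layer_kv_cache kernel is launched with one program per K/V layer buffer (the grid is (len(self.data_ptrs),) in move_kv_cache above); each program copies stride bytes per slot from src_locs to tgt_locs in 128-byte tiles, and all slots are handled inside a single program because, as the NOTE says, source and target slots may overlap for an in-place move. A rough PyTorch sketch of the intended semantics, ignoring the raw byte-pointer arithmetic (illustrative only, not the kernel itself):

```python
import torch

def move_kv_cache_reference(layer_buffers, tgt_loc, src_loc):
    # Reference semantics of copy_all_layer_kv_cache: for every per-layer
    # K/V buffer of shape [num_slots, ...], copy the rows at src_loc into
    # the rows at tgt_loc. The Triton kernel does the same on raw data
    # pointers with per-buffer byte strides, one buffer per program.
    for buf in layer_buffers:            # corresponds to k_buffer + v_buffer
        buf[tgt_loc] = buf[src_loc]      # in-place row copy
```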
......@@ -35,11 +35,17 @@ from sglang.srt.speculative.eagle_utils import (
EagleVerifyInput,
EagleVerifyOutput,
assign_draft_cache_locs,
fast_topk,
generate_token_bitmask,
select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda
from sglang.srt.utils import (
empty_context,
get_available_gpu_memory,
is_cuda,
next_power_of_2,
)
if is_cuda():
from sgl_kernel import segment_packbits
......@@ -152,6 +158,12 @@ class EAGLEWorker(TpModelWorker):
self.init_attention_backend()
self.init_cuda_graphs()
# Some dummy tensors
self.num_new_pages_per_topk = torch.empty(
(), dtype=torch.int64, device=self.device
)
self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
def init_attention_backend(self):
# Create multi-step attn backends and cuda graph runners
if self.server_args.attention_backend == "flashinfer":
......@@ -254,7 +266,7 @@ class EAGLEWorker(TpModelWorker):
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
)
# Capture extend
......@@ -269,7 +281,7 @@ class EAGLEWorker(TpModelWorker):
)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
)
@property
......@@ -290,7 +302,6 @@ class EAGLEWorker(TpModelWorker):
A tuple of the final logit output of the target model, next tokens accepted,
the batch id (used for overlap schedule), and number of accepted tokens.
"""
if batch.forward_mode.is_decode():
with self.draft_tp_context(self.draft_model_runner.tp_group):
spec_info = self.draft(batch)
......@@ -366,14 +377,21 @@ class EAGLEWorker(TpModelWorker):
)
# Allocate cache locations
# Layout of the out_cache_loc
# [ topk 0 ] [ topk 1 ]
# [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
if self.page_size == 1:
out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
num_seqs * self.topk * self.speculative_num_steps, backup_state=True
num_seqs * self.speculative_num_steps * self.topk, backup_state=True
)
else:
if self.topk == 1:
prefix_lens = batch.seq_lens
seq_lens = prefix_lens + self.speculative_num_steps
prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1(
batch.req_to_token_pool.req_to_token,
batch.req_pool_indices,
batch.seq_lens,
self.speculative_num_steps,
)
extend_num_tokens = num_seqs * self.speculative_num_steps
else:
# In this case, the last partial page needs to be duplicated.
......@@ -386,29 +404,33 @@ class EAGLEWorker(TpModelWorker):
# "x" means speculative draft tokens
# "." means padded tokens
# TODO: fuse these ops
prefix_lens = batch.seq_lens
last_page_lens = prefix_lens % self.page_size
num_new_pages = (
last_page_lens + self.speculative_num_steps + self.page_size - 1
) // self.page_size
seq_lens = (
prefix_lens // self.page_size * self.page_size
+ num_new_pages * (self.page_size * self.topk)
)
extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
raise NotImplementedError(
"page_size > 1 and top_k > 1 are not supported."
                # TODO(lmzheng): The current implementation still only provides fake
                # support for page size > 1. In the `assign_draft_cache_locs` below,
                # we directly move the indices instead of the real KV cache.
                # This only works when the kernel backend runs with page size = 1.
                # If the kernel backend runs with page size > 1, we need to
                # duplicate the real KV cache. The overhead of duplicating the KV
                # cache seems okay because the draft KV cache has only one layer.
                # See the related copy operation in MHATokenToKVPool::move_kv_cache.
(
prefix_lens,
seq_lens,
last_loc,
self.num_new_pages_per_topk,
self.extend_lens,
) = get_last_loc_large_page_size_large_top_k(
batch.req_to_token_pool.req_to_token,
batch.req_pool_indices,
batch.seq_lens,
self.speculative_num_steps,
self.topk,
self.page_size,
)
# TODO: Support page_size > 1 and top_k > 1
# 1. Duplicate the KV cache in the last partial page for all top-k segments
# 2. Modify generate_draft_decode_kv_indices accordingly
last_loc = get_last_loc(
batch.req_to_token_pool.req_to_token,
batch.req_pool_indices,
prefix_lens,
)
# TODO(lmzheng): remove this device sync
extend_num_tokens = torch.sum(self.extend_lens).item()
out_cache_loc, token_to_kv_pool_state_backup = (
batch.alloc_paged_token_slots_extend(
prefix_lens,
......@@ -423,19 +445,31 @@ class EAGLEWorker(TpModelWorker):
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
self.extend_lens,
self.num_new_pages_per_topk,
out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
self.topk,
self.speculative_num_steps,
self.page_size,
next_power_of_2(num_seqs),
next_power_of_2(self.speculative_num_steps),
)
if self.page_size > 1 and self.topk > 1:
# Remove padded slots
out_cache_loc = out_cache_loc[
: num_seqs * self.topk * self.speculative_num_steps
]
batch.out_cache_loc = out_cache_loc
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
batch.return_hidden_states = False
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
batch.return_hidden_states = False
# Get forward batch
model_worker_batch = batch.get_model_worker_batch()
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
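A note on the `# Remove padded slots` step above: for page_size > 1 and topk > 1, the paged allocator returns sum(extend_lens) slots, which includes padding up to full pages for each top-k branch; after assign_draft_cache_locs has written the draft indices, only the first num_seqs * topk * speculative_num_steps entries are kept as real draft locations. Continuing the worked numbers shown after get_last_loc_large_page_size_large_top_k below (page_size=4, steps=5, topk=2, one request with prefix length 6), this is just arithmetic on the allocation size (illustrative sketch):

```python
num_seqs, topk, steps = 1, 2, 5
allocated = 14                      # sum of extend_lens from the paged allocator
used = num_seqs * topk * steps      # 10 real draft slots kept by the slice
padding_dropped = allocated - used  # 4 padded slots trimmed from out_cache_loc
```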
......@@ -504,6 +538,13 @@ class EAGLEWorker(TpModelWorker):
if self.hot_token_id is not None:
topk_index = self.hot_token_id[topk_index]
out_cache_loc = out_cache_loc.reshape(
forward_batch.batch_size, self.topk, self.speculative_num_steps
)
out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
self.speculative_num_steps, -1
)
# Return values
score_list: List[torch.Tensor] = []
token_list: List[torch.Tensor] = []
......@@ -525,10 +566,7 @@ class EAGLEWorker(TpModelWorker):
# Set inputs
forward_batch.input_ids = input_ids
out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
forward_batch.out_cache_loc = out_cache_loc[
:, self.topk * i : self.topk * (i + 1)
].flatten()
            forward_batch.out_cache_loc = out_cache_loc[i]
forward_batch.positions.add_(1)
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
spec_info.hidden_states = hidden_states
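The reshape/permute earlier in this function turns out_cache_loc from the allocation layout [batch_size, topk, speculative_num_steps] (matching the "[ topk 0 ] [ topk 1 ]" layout comment in prepare_for_draft) into [speculative_num_steps, batch_size * topk], so each draft step i can simply take row out_cache_loc[i] as done above. A small shape check with made-up sizes (batch_size=2, topk=2, steps=3):

```python
import torch

bs, topk, steps = 2, 2, 3
loc = torch.arange(bs * topk * steps)           # pretend cache slots 0..11
loc = loc.reshape(bs, topk, steps)              # [batch, topk, step]
loc = loc.permute(2, 0, 1).reshape(steps, -1)   # [step, batch * topk]
print(loc[0].tolist())  # slots written at draft step 0 -> [0, 3, 6, 9]
print(loc[1].tolist())  # slots written at draft step 1 -> [1, 4, 7, 10]
```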
......@@ -586,7 +624,7 @@ class EAGLEWorker(TpModelWorker):
if vocab_mask is not None:
assert spec_info.grammar is not None
vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
# otherwise, this vocab mask will be the one from the previous extend stage
# NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage
# and will be applied to produce wrong results
batch.sampling_info.vocab_mask = None
......@@ -607,13 +645,13 @@ class EAGLEWorker(TpModelWorker):
]
logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
if batch.return_logprob:
self.add_logprob_values(batch, res, logits_output)
# Prepare the batch for the next draft forwards.
batch.forward_mode = ForwardMode.DECODE
batch.spec_info = res.draft_input
if batch.return_logprob:
self.add_logprob_values(batch, res, logits_output)
return logits_output, res, model_worker_batch, can_run_cuda_graph
def add_logprob_values(
......@@ -626,8 +664,16 @@ class EAGLEWorker(TpModelWorker):
logits_output = res.logits_output
top_logprobs_nums = batch.top_logprobs_nums
token_ids_logprobs = batch.token_ids_logprobs
accepted_indices = res.accepted_indices
assert len(accepted_indices) == len(logits_output.next_token_logits)
temperatures = batch.sampling_info.temperatures
num_draft_tokens = batch.spec_info.draft_token_num
# accepted_indices index into the flattened batch of size
# batch_size * num_draft_tokens; integer-dividing them by num_draft_tokens
# yields the original batch index for each accepted token.
temperatures = temperatures[accepted_indices // num_draft_tokens]
logprobs = torch.nn.functional.log_softmax(
logits_output.next_token_logits, dim=-1
logits_output.next_token_logits / temperatures, dim=-1
)
batch_next_token_ids = res.verified_id
num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
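In add_logprob_values, accepted_indices index into the flattened [batch_size * num_draft_tokens] verification output, so integer division by num_draft_tokens recovers which request each accepted token came from; that is what lets the per-request temperature be applied before log_softmax. A small illustration with made-up numbers (names mirror the code above):

```python
import torch

num_draft_tokens = 4                               # draft tokens verified per request
accepted_indices = torch.tensor([0, 1, 4, 5, 6])   # accepted positions in the flat batch
temperatures = torch.tensor([[0.7], [1.0]])        # one sampling temperature per request

req_idx = accepted_indices // num_draft_tokens     # -> [0, 0, 1, 1, 1]
per_token_temp = temperatures[req_idx]             # temperature for each accepted token
# The logits are then divided by per_token_temp before log_softmax.
print(req_idx.tolist())
```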
......@@ -662,7 +708,7 @@ class EAGLEWorker(TpModelWorker):
pt = 0
next_token_logprobs = logits_output.next_token_logprobs.tolist()
verified_ids = batch_next_token_ids.tolist()
for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True):
for _ in range(num_tokens):
if req.return_logprob:
req.output_token_logprobs_val.append(next_token_logprobs[pt])
......@@ -690,7 +736,6 @@ class EAGLEWorker(TpModelWorker):
hidden_states: Hidden states from the target model forward
next_token_ids: Next token ids generated from the target forward.
"""
# Sometimes we get hidden states produced by CaptureHiddenMode.FULL, so we have to select just the last
batch.spec_info = EagleDraftInput(
hidden_states=hidden_states,
verified_id=next_token_ids,
......@@ -701,7 +746,6 @@ class EAGLEWorker(TpModelWorker):
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=seq_lens_cpu
)
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
......@@ -724,9 +768,7 @@ class EAGLEWorker(TpModelWorker):
batch,
self.speculative_num_steps,
)
batch.return_hidden_states = False
model_worker_batch = batch.get_model_worker_batch()
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
......@@ -790,3 +832,47 @@ def load_token_map(token_map_path: str) -> List[int]:
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
hot_token_id = torch.load(token_map_path, weights_only=True)
return torch.tensor(hot_token_id, dtype=torch.int32)
@torch.compile(dynamic=True)
def get_last_loc_large_page_size_top_k_1(
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
seq_lens,
speculative_num_steps: int,
):
prefix_lens = seq_lens
seq_lens = prefix_lens + speculative_num_steps
last_loc = get_last_loc(
req_to_token,
req_pool_indices,
prefix_lens,
)
return prefix_lens, seq_lens, last_loc
@torch.compile(dynamic=True)
def get_last_loc_large_page_size_large_top_k(
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
speculative_num_steps: int,
topk: int,
page_size: int,
):
prefix_lens = seq_lens
last_page_lens = prefix_lens % page_size
num_new_pages_per_topk = (
last_page_lens + speculative_num_steps + page_size - 1
) // page_size
seq_lens = prefix_lens // page_size * page_size + num_new_pages_per_topk * (
page_size * topk
)
extend_lens = seq_lens - prefix_lens
last_loc = get_last_loc(
req_to_token,
req_pool_indices,
prefix_lens,
)
return prefix_lens, seq_lens, last_loc, num_new_pages_per_topk, extend_lens
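For the topk > 1 path, get_last_loc_large_page_size_large_top_k rounds the prefix down to a page boundary and reserves num_new_pages_per_topk fresh pages for every top-k branch, so each branch gets its own copy of the last partial page plus room for all draft steps, padded up to full pages. A worked example with assumed values (page_size=4, speculative_num_steps=5, topk=2, one request with prefix length 6):

```python
page_size, steps, topk = 4, 5, 2
prefix_len = 6

last_page_len = prefix_len % page_size                 # 2 tokens on the partial page
num_new_pages_per_topk = (last_page_len + steps + page_size - 1) // page_size  # ceil(7/4) = 2
new_seq_len = prefix_len // page_size * page_size + num_new_pages_per_topk * page_size * topk
extend_len = new_seq_len - prefix_len                  # newly allocated slots for this request

print(num_new_pages_per_topk, new_seq_len, extend_len)  # 2 20 14
```

This matches the layout comments in prepare_for_draft above: the whole-page part of the prefix stays shared, while each top-k branch owns duplicated partial-page tokens, its draft slots, and any padding.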
......@@ -441,5 +441,71 @@ class TestEAGLEServerTriton(TestEAGLEServer):
)
class TestEAGLEServerPageSize(TestEAGLEServer):
@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--speculative-algorithm",
"EAGLE",
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
5,
"--speculative-eagle-topk",
1,
"--speculative-num-draft-tokens",
6,
"--mem-fraction-static",
0.7,
"--chunked-prefill-size",
128,
"--max-running-requests",
8,
"--page-size",
4,
"--attention-backend",
"flashinfer",
],
)
class TestEAGLEServerPageSizeTopk(TestEAGLEServer):
@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--speculative-algorithm",
"EAGLE",
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
5,
"--speculative-eagle-topk",
8,
"--speculative-num-draft-tokens",
64,
"--mem-fraction-static",
0.7,
"--chunked-prefill-size",
128,
"--max-running-requests",
8,
"--page-size",
4,
"--attention-backend",
"flashinfer",
],
)
if __name__ == "__main__":
unittest.main()