Unverified Commit b26bc86b authored by Lianmin Zheng, committed by GitHub

Support page size > 1 + eagle (#4908)

parent 5ec5eaf7
......@@ -33,6 +33,7 @@ runtime_common = [
"prometheus-client>=0.20.0",
"psutil",
"pydantic",
"pynvml",
"python-multipart",
"pyzmq>=25.1.2",
"soundfile==0.13.1",
......
......@@ -14,7 +14,6 @@ from functools import partial
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import torch
import triton
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
......@@ -22,7 +21,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
from sglang.srt.utils import is_flashinfer_available, next_power_of_2
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
......@@ -932,6 +931,7 @@ class FlashInferMultiStepDraftBackend:
self.topk = topk
self.speculative_num_steps = speculative_num_steps
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
self.page_size = model_runner.page_size
max_bs = model_runner.req_to_token_pool.size * self.topk
self.kv_indptr = torch.zeros(
......@@ -985,9 +985,9 @@ class FlashInferMultiStepDraftBackend:
self.pool_len,
kv_indices_buffer.shape[1],
self.kv_indptr.shape[1],
triton.next_power_of_2(num_seqs),
triton.next_power_of_2(self.speculative_num_steps),
triton.next_power_of_2(bs),
next_power_of_2(num_seqs),
next_power_of_2(self.speculative_num_steps),
next_power_of_2(bs),
)
assert forward_batch.spec_info is not None
......@@ -1018,8 +1018,6 @@ class FlashInferMultiStepDraftBackend:
)
def call_fn(i, forward_batch):
assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)
forward_batch.spec_info.kv_indptr = (
forward_batch.spec_info.kv_indptr.clone()
)
......
......@@ -740,11 +740,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
return req_pool_indices
def alloc_token_slots(self, num_tokens: int):
def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
if self.token_to_kv_pool_allocator.available_size() < num_tokens:
if self.tree_cache is not None:
self.tree_cache.evict(num_tokens)
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
if out_cache_loc is None:
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
......@@ -758,7 +761,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.tree_cache.pretty_print()
raise RuntimeError(error_msg)
return out_cache_loc
if backup_state:
return out_cache_loc, state
else:
return out_cache_loc
def alloc_paged_token_slots_extend(
self,
......@@ -766,6 +772,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
seq_lens: torch.Tensor,
last_loc: torch.Tensor,
extend_num_tokens: int,
backup_state: bool = False,
):
if (
self.token_to_kv_pool_allocator.available_size()
......@@ -778,6 +785,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
)
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
prefix_lens, seq_lens, last_loc, extend_num_tokens
)
......@@ -791,12 +801,17 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
logger.error(error_msg)
raise RuntimeError(error_msg)
return out_cache_loc
if backup_state:
return out_cache_loc, state
else:
return out_cache_loc
def alloc_paged_token_slots_decode(
self,
seq_lens: torch.Tensor,
last_loc: torch.Tensor,
backup_state: bool = False,
):
if (
self.token_to_kv_pool_allocator.available_size()
......@@ -806,8 +821,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.tree_cache.evict(
len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
)
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
if out_cache_loc is None:
error_msg = (
f"Decode out of memory. Try to lower your batch size.\n"
......@@ -818,7 +836,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
logger.error(error_msg)
raise RuntimeError(error_msg)
return out_cache_loc
if backup_state:
return out_cache_loc, state
else:
return out_cache_loc
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
self.encoder_lens_cpu = []
......
......@@ -1110,7 +1110,7 @@ class Scheduler(
)
if memory_leak:
msg = (
"KV cache pool leak detected! "
"token_to_kv_pool_allocator memory leak detected! "
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
f"{self.tree_cache.evictable_size()=}\n"
......@@ -1121,7 +1121,7 @@ class Scheduler(
if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
msg = (
"Memory pool leak detected!"
"req_to_token_pool memory leak detected!"
f"available_size={len(self.req_to_token_pool.free_slots)}, "
f"total_size={self.req_to_token_pool.size}\n"
)
......
......@@ -185,6 +185,12 @@ class TokenToKVPoolAllocator:
if self.free_group:
self.free(torch.cat(self.free_group))
def backup_state(self):
return self.free_slots
def restore_state(self, free_slots):
self.free_slots = free_slots
def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_slots = torch.arange(
......
......@@ -218,6 +218,9 @@ class PagedTokenToKVPoolAllocator:
next_power_of_2(extend_num_tokens),
)
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
merged_value = self.ret_values.item()
num_new_pages = merged_value >> 32
if num_new_pages > len(self.free_pages):
......@@ -248,6 +251,9 @@ class PagedTokenToKVPoolAllocator:
self.page_size,
)
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
num_new_pages = self.ret_values.item()
if num_new_pages > len(self.free_pages):
return None
......@@ -265,6 +271,9 @@ class PagedTokenToKVPoolAllocator:
else:
self.free_group.append(free_index)
if self.debug_mode:
assert len(torch.unique(self.free_pages)) == len(self.free_pages)
def free_group_begin(self):
self.is_not_in_free_group = False
self.free_group = []
......@@ -274,6 +283,12 @@ class PagedTokenToKVPoolAllocator:
if self.free_group:
self.free(torch.cat(self.free_group))
def backup_state(self):
return self.free_pages
def restore_state(self, free_pages):
self.free_pages = free_pages
def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_pages = torch.arange(
......
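For orientation, the new backup_state/restore_state hooks let the EAGLE draft stage snapshot the allocator before borrowing KV slots for draft tokens and roll everything back afterwards, instead of freeing those slots one by one. A minimal sketch of the pattern (names and sizes are hypothetical, not the literal worker code):

# allocator: a TokenToKVPoolAllocator (the paged variant exposes the same two hooks)
backup = allocator.backup_state()                         # snapshot of free_slots / free_pages
draft_loc = allocator.alloc(num_seqs * topk * num_steps)  # temporary slots for draft tokens
run_draft_forward(draft_loc)                              # hypothetical draft forward passes
allocator.restore_state(backup)                           # drop all draft allocations at once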
......@@ -116,16 +116,18 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
if capture_bs is None:
if server_args.speculative_algorithm is None:
if server_args.disable_cuda_graph_padding:
capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
capture_bs = list(range(1, 33)) + list(range(40, 161, 16))
else:
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
else:
# Since speculative decoding requires more cuda graph memory, we
# capture less.
capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
capture_bs = (
list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
)
if _is_hip:
capture_bs += [i * 8 for i in range(21, 33)]
capture_bs += list(range(160, 257, 8))
if max(capture_bs) > model_runner.req_to_token_pool.size:
# In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
......
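For reference, the new default capture lists expand as follows (a standalone check, not part of the diff; names are hypothetical):

normal = [1, 2, 4, 8] + list(range(16, 161, 8))   # 1, 2, 4, 8, 16, 24, ..., 160
spec = list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
# spec -> 1..8, 10, 12, ..., 32, 40, 56, ..., 152
hip_extra = list(range(160, 257, 8))              # 160, 168, ..., 256 (ROCm only)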
......@@ -17,7 +17,7 @@
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import logging
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
......
......@@ -15,6 +15,7 @@
import argparse
import dataclasses
import json
import logging
import os
import random
......@@ -132,9 +133,9 @@ class ServerArgs:
# Speculative decoding
speculative_algorithm: Optional[str] = None
speculative_draft_model_path: Optional[str] = None
speculative_num_steps: int = 5
speculative_eagle_topk: int = 4
speculative_num_draft_tokens: int = 8
speculative_num_steps: Optional[int] = None
speculative_eagle_topk: Optional[int] = None
speculative_num_draft_tokens: Optional[int] = None
speculative_accept_threshold_single: float = 1.0
speculative_accept_threshold_acc: float = 1.0
speculative_token_map: Optional[str] = None
......@@ -313,12 +314,29 @@ class ServerArgs:
or self.speculative_algorithm == "EAGLE3"
):
if self.max_running_requests is None:
self.max_running_requests = 32
self.max_running_requests = 48
self.disable_overlap_schedule = True
logger.info(
"Overlap scheduler is disabled because of using "
"eagle speculative decoding."
)
# Auto choose parameters
if self.speculative_num_steps is None:
assert (
self.speculative_eagle_topk is None
and self.speculative_num_draft_tokens is None
)
(
self.speculative_num_steps,
self.speculative_eagle_topk,
self.speculative_num_draft_tokens,
) = auto_choose_speculative_params(self)
if self.page_size > 1 and self.speculative_eagle_topk > 1:
self.speculative_eagle_topk = 1
logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
# The token generated from the verify step is counted.
# If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
# assert self.speculative_num_steps < self.speculative_num_draft_tokens
......@@ -1253,3 +1271,33 @@ class DeprecatedAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
raise ValueError(self.help)
def auto_choose_speculative_params(self: ServerArgs):
"""
Automatically choose the parameters for speculative decoding.
You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
"""
if self.decrypted_config_file:
config_path = self.decrypted_config_file
else:
config_path = os.path.join(self.model_path, "config.json")
if not os.path.exists(config_path):
raise ValueError(f"{config_path} is not found.")
config = json.load(open(config_path))
arch = config.get("architectures", ["Unknown"])[0]
if arch in ["LlamaForCausalLM"]:
# The default value for llama
return (5, 4, 8)
elif arch in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
# The default value for deepseek
return (5, 4, 8)
elif arch in ["Grok1ForCausalLM", "Grok1VForCausalLM"]:
return (5, 4, 8)
else:
# The default value for all other models
return (5, 4, 8)
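Taken together with the server-args hunk above: when none of the three speculative flags is given, they are resolved by auto_choose_speculative_params, and a page size above 1 then forces the top-k back to 1. Roughly (a sketch of the flow, not the literal ServerArgs code):

steps, topk, draft_tokens = auto_choose_speculative_params(args)  # e.g. (5, 4, 8)
if args.page_size > 1 and topk > 1:
    topk = 1  # page_size > 1 currently supports only topk == 1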
from __future__ import annotations
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional
......@@ -10,11 +11,15 @@ import triton.language as tl
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.managers.schedule_batch import (
ScheduleBatch,
get_last_loc,
global_server_args_dict,
)
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.utils import is_cuda_available, is_hip
from sglang.srt.utils import is_cuda_available, is_hip, next_power_of_2
if is_cuda_available():
from sgl_kernel import (
......@@ -34,6 +39,9 @@ import logging
logger = logging.getLogger(__name__)
SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
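# If set (e.g. SIMULATE_ACC_LEN=3), verify() replaces the measured accept lengths with
# values drawn from N(SIMULATE_ACC_LEN, 1) and clamped to [1, spec_steps]; see
# _generate_simulated_accept_index below.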
@dataclass
class EagleDraftInput:
# The inputs for decode
......@@ -93,7 +101,7 @@ class EagleDraftInput:
torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
self.positions,
new_verified_id,
triton.next_power_of_2(speculative_num_steps + 1),
next_power_of_2(speculative_num_steps + 1),
)
batch.seq_lens_sum = sum(seq_lens_cpu)
......@@ -225,18 +233,34 @@ class EagleVerifyInput:
CaptureHiddenMode.FULL,
)
def prepare_for_verify(self, batch: ScheduleBatch):
def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
batch.input_ids = self.draft_token
batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
if page_size == 1:
batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
end_offset = batch.seq_lens + self.draft_token_num
else:
prefix_lens = batch.seq_lens
end_offset = prefix_lens + self.draft_token_num
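# last_loc holds the KV-cache slot of each request's last prefix token, which the
# paged allocator needs in order to keep filling the current partial page.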
last_loc = get_last_loc(
batch.req_to_token_pool.req_to_token,
batch.req_pool_indices,
prefix_lens,
)
batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
prefix_lens, end_offset, last_loc, len(batch.input_ids)
)
self.last_loc = last_loc
bs = batch.batch_size()
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + self.draft_token_num,
end_offset,
batch.out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
triton.next_power_of_2(bs),
next_power_of_2(bs),
)
def generate_attn_arg_prefill(
......@@ -282,6 +306,7 @@ class EagleVerifyInput:
batch: ScheduleBatch,
logits_output: torch.Tensor,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
page_size: int,
) -> torch.Tensor:
"""
Verify and find accepted tokens based on logits output and batch
......@@ -305,6 +330,7 @@ class EagleVerifyInput:
)
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
# Apply penalty
if sampling_info.penalizer_orchestrator.is_required:
# This is a relaxed version of penalties for speculative decoding.
linear_penalty = torch.zeros(
......@@ -317,6 +343,7 @@ class EagleVerifyInput:
torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
)
# Sample tokens
if batch.sampling_info.is_all_greedy:
target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
target_predict = target_predict.reshape(bs, self.draft_token_num)
......@@ -378,13 +405,24 @@ class EagleVerifyInput:
deterministic=True,
)
if SIMULATE_ACC_LEN:
# Do simulation
accept_index = _generate_simulated_accept_index(
accept_index=accept_index,
predict=predict, # mutable
accept_length=accept_length, # mutable
simulate_acc_len=SIMULATE_ACC_LEN,
bs=bs,
spec_steps=self.spec_steps,
)
new_accept_index = []
unfinished_index = []
accept_index_cpu = accept_index.tolist()
predict_cpu = predict.tolist()
has_finished = False
# iterate every accepted token and check if req has finished after append the token
# Iterate over every accepted token and check whether the req has finished after appending it.
# This must be checked BEFORE freeing the KV cache slots.
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
new_accept_index_ = []
......@@ -407,13 +445,28 @@ class EagleVerifyInput:
unfinished_index.append(i)
req.spec_verify_ct += 1
if has_finished:
accept_length = (accept_index != -1).sum(dim=1) - 1
# Free the KV cache for unaccepted tokens
accept_index = accept_index[accept_index != -1]
verified_id = predict[accept_index]
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
if page_size != 1:
align_evict_mask_to_page_size[len(batch.seq_lens),](
batch.seq_lens,
evict_mask,
page_size,
self.draft_token_num,
next_power_of_2(self.draft_token_num),
)
token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
# Construct EagleVerifyOutput
if not has_finished:
accept_index = accept_index[accept_index != -1]
verified_id = predict[accept_index]
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
mem_need_free_idx = batch.out_cache_loc[evict_mask]
token_to_kv_pool_allocator.free(mem_need_free_idx)
batch.out_cache_loc = batch.out_cache_loc[accept_index]
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
......@@ -422,7 +475,7 @@ class EagleVerifyInput:
batch.seq_lens + accept_length + 1,
batch.out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
triton.next_power_of_2(bs),
next_power_of_2(bs),
)
batch.seq_lens.add_(accept_length + 1)
accept_length_cpu = accept_length.tolist()
......@@ -443,13 +496,6 @@ class EagleVerifyInput:
accepeted_indices=accept_index,
)
else:
accept_length = (accept_index != -1).sum(dim=1) - 1
accept_index = accept_index[accept_index != -1]
verified_id = predict[accept_index]
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
mem_need_free_idx = batch.out_cache_loc[evict_mask]
token_to_kv_pool_allocator.free(mem_need_free_idx)
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
......@@ -457,7 +503,7 @@ class EagleVerifyInput:
batch.seq_lens + accept_length + 1,
batch.out_cache_loc[accept_index],
batch.req_to_token_pool.req_to_token.shape[1],
triton.next_power_of_2(bs),
next_power_of_2(bs),
)
batch.seq_lens.add_(accept_length + 1)
accept_length_cpu = accept_length.tolist()
......@@ -465,20 +511,21 @@ class EagleVerifyInput:
draft_input = EagleDraftInput()
if len(new_accept_index) > 0:
new_accept_index = torch.tensor(new_accept_index, device="cuda")
unfinished_index_device = torch.tensor(unfinished_index, device="cuda")
draft_input.hidden_states = batch.spec_info.hidden_states[
new_accept_index
]
draft_input.verified_id = predict[new_accept_index]
draft_input.accept_length = accept_length[unfinished_index]
draft_input.accept_length_cpu = [
accept_length_cpu[i] for i in unfinished_index
]
draft_input.accept_length = accept_length[unfinished_index_device]
if has_finished:
draft_input.seq_lens_for_draft_extend = batch.seq_lens[
unfinished_index
unfinished_index_device
]
draft_input.req_pool_indices_for_draft_extend = (
batch.req_pool_indices[unfinished_index]
batch.req_pool_indices[unfinished_index_device]
)
else:
draft_input.seq_lens_for_draft_extend = batch.seq_lens
......@@ -564,13 +611,24 @@ def assign_draft_cache_locs(
pool_len: tl.constexpr,
topk: tl.constexpr,
speculative_num_steps: tl.constexpr,
page_size: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 32
pid = tl.program_id(axis=0)
kv_start = tl.load(seq_lens + pid)
kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
if page_size == 1 or topk == 1:
kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
else:
prefix_len = tl.load(seq_lens + pid)
last_page_len = prefix_len % page_size
num_new_page = (
last_page_len + speculative_num_steps + page_size - 1
) // page_size
kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
for i in range(num_loop):
......@@ -642,6 +700,29 @@ def generate_draft_decode_kv_indices(
tl.store(kv_indptr + zid, base + zid * iters)
@triton.jit
def align_evict_mask_to_page_size(
seq_lens,
evict_mask,
page_size: tl.constexpr,
num_draft_tokens: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
t_range = tl.arange(0, BLOCK_SIZE)
bid = tl.program_id(axis=0)
seq_len = tl.load(seq_lens + bid)
io_mask = t_range < num_draft_tokens
mask_row = tl.load(evict_mask + bid * num_draft_tokens + t_range, mask=io_mask)
num_trues = tl.sum(mask_row)
num_false = num_draft_tokens - num_trues
start = (seq_len + num_false - 1) // page_size * page_size - seq_len
for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
tl.store(evict_mask + bid * num_draft_tokens + i, False)
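# Worked example (hypothetical numbers): page_size=4, num_draft_tokens=8, seq_len=10,
# 3 tokens accepted -> num_false=3; the last kept token sits at absolute position 12,
# whose page starts at 12, so start = 12 - 10 = 2 and the loop clears the evict flag
# for draft positions 2..5. The partially filled page [12, 16) is thus kept whole and
# only fully unused pages after it are freed, keeping the freed slots page-aligned.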
@torch.compile(dynamic=True)
def select_top_k_tokens(
i: int,
......@@ -699,3 +780,34 @@ def fast_topk(values, topk, dim):
else:
# Use topk for efficiency with larger k values
return torch.topk(values, topk, dim=dim)
def _generate_simulated_accept_index(
accept_index,
predict,
accept_length,
simulate_acc_len,
bs,
spec_steps,
):
simulate_acc_len_float = float(simulate_acc_len)
simulated_values = torch.normal(
mean=simulate_acc_len_float,
std=1.0,
size=(1,),
device="cpu",
)
# Clamp simulated values to be between 1 and spec_steps
simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps)
simulate_acc_len = int(simulated_values.round().item())
accept_indx_first_col = accept_index[:, 0].view(-1, 1)
sim_accept_index = torch.full(
(bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
)
sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
simulate_acc_len, device=accept_index.device
)
accept_length.fill_(simulate_acc_len - 1)
predict.fill_(100) # some legit token id
return sim_accept_index
......@@ -11,7 +11,7 @@ from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.layers.dp_attention import disable_dp_size
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.schedule_batch import ScheduleBatch, get_last_loc
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
......@@ -67,6 +67,7 @@ class EAGLEWorker(TpModelWorker):
self.gpu_id = gpu_id
self.device = server_args.device
self.target_worker = target_worker
self.page_size = server_args.page_size
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
......@@ -234,14 +235,11 @@ class EAGLEWorker(TpModelWorker):
"""
if batch.forward_mode.is_decode():
with self.draft_tp_context(self.draft_model_runner.tp_group):
spec_info, to_free_cache_loc = self.draft(batch)
spec_info = self.draft(batch)
logits_output, verify_output, model_worker_batch = self.verify(
batch, spec_info
)
# Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
self.token_to_kv_pool_allocator.free(to_free_cache_loc)
# If it is None, it means all requests are finished
if batch.spec_info.verified_id is not None:
with self.draft_tp_context(self.draft_model_runner.tp_group):
......@@ -305,9 +303,59 @@ class EAGLEWorker(TpModelWorker):
)
# Allocate cache locations
out_cache_loc = batch.alloc_token_slots(
num_seqs * self.topk * self.speculative_num_steps
)
if self.page_size == 1:
out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
num_seqs * self.topk * self.speculative_num_steps, backup_state=True
)
else:
if self.topk == 1:
prefix_lens = batch.seq_lens
seq_lens = prefix_lens + self.speculative_num_steps
extend_num_tokens = num_seqs * self.speculative_num_steps
else:
# In this case, the last partial page needs to be duplicated.
# KV cache layout in batch.req_to_token_pool.req_to_token:
#
# | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
# prefix top-k = 0 top-k = 1 top-k = 2
#
# "-" means prefix tokens
# "x" means speculative draft tokens
# "." means padded tokens
# TODO: fuse these ops
prefix_lens = batch.seq_lens
last_page_lens = prefix_lens % self.page_size
num_new_pages = (
last_page_lens + self.speculative_num_steps + self.page_size - 1
) // self.page_size
seq_lens = (
prefix_lens // self.page_size * self.page_size
+ num_new_pages * (self.page_size * self.topk)
)
extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
raise NotImplementedError(
"page_size > 1 and top_k > 1 are not supported."
)
# TODO: Support page_size > 1 and top_k > 1
# 1. Duplicate the KV cache in the last partial page for all top-k segments
# 2. Modify generate_draft_decode_kv_indices accordingly
last_loc = get_last_loc(
batch.req_to_token_pool.req_to_token,
batch.req_pool_indices,
prefix_lens,
)
out_cache_loc, token_to_kv_pool_state_backup = (
batch.alloc_paged_token_slots_extend(
prefix_lens,
seq_lens,
last_loc,
extend_num_tokens,
backup_state=True,
)
)
assign_draft_cache_locs[(num_seqs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
......@@ -316,6 +364,7 @@ class EAGLEWorker(TpModelWorker):
batch.req_to_token_pool.req_to_token.shape[1],
self.topk,
self.speculative_num_steps,
self.page_size,
)
batch.out_cache_loc = out_cache_loc
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
......@@ -343,6 +392,8 @@ class EAGLEWorker(TpModelWorker):
# Run forward steps
score_list, token_list, parents_list = self.draft_forward(forward_batch)
self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
ret = EagleVerifyInput.create(
spec_info.verified_id,
score_list,
......@@ -354,7 +405,7 @@ class EAGLEWorker(TpModelWorker):
self.speculative_num_steps,
self.server_args.speculative_num_draft_tokens,
)
return ret, out_cache_loc
return ret
def draft_forward(self, forward_batch: ForwardBatch):
# Parse args
......@@ -411,7 +462,7 @@ class EAGLEWorker(TpModelWorker):
return score_list, token_list, parents_list
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
spec_info.prepare_for_verify(batch)
spec_info.prepare_for_verify(batch, self.page_size)
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.spec_info = spec_info
model_worker_batch = batch.get_model_worker_batch()
......@@ -421,7 +472,10 @@ class EAGLEWorker(TpModelWorker):
self._detect_nan_if_needed(logits_output)
spec_info.hidden_states = logits_output.hidden_states
res: EagleVerifyOutput = spec_info.verify(
batch, logits_output, self.token_to_kv_pool_allocator
batch,
logits_output,
self.token_to_kv_pool_allocator,
self.page_size,
)
# Post process based on verified outputs.
......
......@@ -76,11 +76,14 @@ def is_in_ci():
if is_in_ci():
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 5157
DEFAULT_URL_FOR_TEST = "http://127.0.0.1:6157"
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
5000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
)
else:
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 1157
DEFAULT_URL_FOR_TEST = "http://127.0.0.1:2157"
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
7000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
)
DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"
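# Example: CUDA_VISIBLE_DEVICES="3,4" -> first character "3" -> port 7300 locally
# (5300 in CI); DEFAULT_URL_FOR_TEST then targets port 8300 (= 7300 + 1000).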
def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
......@@ -1009,6 +1012,9 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
class CustomTestCase(unittest.TestCase):
pass
"""
def _callTestMethod(self, method):
max_retry = int(
os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0")
......@@ -1017,3 +1023,4 @@ class CustomTestCase(unittest.TestCase):
lambda: super(CustomTestCase, self)._callTestMethod(method),
max_retry=max_retry,
)
"""
......@@ -18,7 +18,7 @@ pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --force-rei
pip install sgl-kernel==0.0.5.post4 --force-reinstall
pip install torch_memory_saver
pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets timm
pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets timm torchaudio
# For compiling xgrammar kernels
pip install cuda-python nvidia-cuda-nvrtc-cu12
......
......@@ -26,7 +26,7 @@ suites = {
TestFile("test_abort.py", 51),
TestFile("test_block_int8.py", 22),
TestFile("test_chunked_prefill.py", 336),
TestFile("test_eagle_infer.py", 447),
TestFile("test_eagle_infer.py", 500),
TestFile("test_ebnf_constrained.py"),
TestFile("test_fp8_kernel.py", 2),
TestFile("test_embedding_openai_server.py", 36),
......
......@@ -298,10 +298,16 @@ class TestEAGLEServer(CustomTestCase):
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.20)
server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
server_info = requests.get(self.base_url + "/get_server_info").json()
avg_spec_accept_length = server_info["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 3.5)
speculative_eagle_topk = server_info["speculative_eagle_topk"]
if speculative_eagle_topk == 1:
self.assertGreater(avg_spec_accept_length, 2.5)
else:
self.assertGreater(avg_spec_accept_length, 3.5)
# Wait a little bit so that the memory check happens.
time.sleep(4)
......@@ -535,5 +541,36 @@ class TestEAGLEServerTriton(TestEAGLEServer):
)
class TestEAGLEServerPageSize(TestEAGLEServer):
@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--speculative-algorithm",
"EAGLE",
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
5,
"--speculative-eagle-topk",
1,
"--speculative-num-draft-tokens",
6,
"--mem-fraction-static",
0.7,
"--chunked-prefill-size",
128,
"--max-running-requests",
8,
"--page-size",
4,
],
)
if __name__ == "__main__":
unittest.main()
......@@ -157,6 +157,7 @@ class TestFlashinferMLAMTP(CustomTestCase):
self.assertGreater(metrics["accuracy"], 0.60)
server_info = requests.get(self.base_url + "/get_server_info")
print(f"{server_info=}")
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 2.5)
......