Unverified Commit bc1534ff authored by Lianmin Zheng and committed by GitHub

Fix a draft model accuracy bug in eagle; support step=1; return logprob in eagle (#4134)


Co-authored-by: Sehoon Kim <kssteven418@gmail.com>
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: Sehoon Kim <sehoon@x.ai>
parent 3a391812
@@ -95,7 +95,7 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
-        range: [0-6, 6-15, 15-22, 22-32, 32-40, 40-100]
+        range: [0-6, 6-15, 15-22, 22-32, 32-40, 40-48, 48-100]
    steps:
      - name: Checkout code
        uses: actions/checkout@v3
......
@@ -7,16 +7,14 @@ FlashInfer is faster and Triton is easier to customize.
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
"""

-import math
import os
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Callable, List, Optional, Union

import torch
import triton
-import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -37,7 +35,7 @@ if is_flashinfer_available():
        BatchPrefillWithRaggedKVCacheWrapper,
    )
    from flashinfer.cascade import merge_state
-    from flashinfer.decode import PosEncodingMode
+    from flashinfer.decode import _get_range_buf, get_seq_lens


class WrapperDispatch(Enum):
@@ -73,8 +71,6 @@ class FlashInferAttnBackend(AttentionBackend):
    ):
        super().__init__()

-        self.is_multimodal = model_runner.model_config.is_multimodal

        # Parse constants
        self.decode_use_tensor_cores = should_use_tensor_core(
            kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -86,6 +82,7 @@ class FlashInferAttnBackend(AttentionBackend):
        )
        self.max_context_len = model_runner.model_config.context_len
        self.skip_prefill = skip_prefill
+        self.is_multimodal = model_runner.model_config.is_multimodal

        assert not (
            model_runner.sliding_window_size is not None
@@ -115,7 +112,6 @@ class FlashInferAttnBackend(AttentionBackend):
                device=model_runner.device,
            )
            self.workspace_buffer = global_workspace_buffer
        max_bs = model_runner.req_to_token_pool.size
        if kv_indptr_buf is None:
            self.kv_indptr = [
@@ -163,9 +159,11 @@ class FlashInferAttnBackend(AttentionBackend):
                    )
                )
                self.prefill_wrappers_verify.append(
-                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                    )
                )
                self.decode_wrappers.append(
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.workspace_buffer,
@@ -178,13 +176,14 @@ class FlashInferAttnBackend(AttentionBackend):
        if not skip_prefill:
            self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
                model_runner, self
-            )
+            )  # for verify
        self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)

        # Other metadata
        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
        self.decode_cuda_graph_metadata = {}
-        self.prefill_cuda_graph_metadata = {}
+        self.prefill_cuda_graph_metadata = {}  # For verify
+        self.draft_extend_cuda_graph_metadata = {}  # For draft extend

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode_or_idle():
@@ -300,7 +299,6 @@ class FlashInferAttnBackend(AttentionBackend):
                    ],
                )
            )
        seq_lens_sum = seq_lens.sum().item()
        self.indices_updater_decode.update(
            req_pool_indices,
@@ -312,6 +310,10 @@ class FlashInferAttnBackend(AttentionBackend):
            )
            self.decode_cuda_graph_metadata[bs] = decode_wrappers
            self.forward_metadata = DecodeMetadata(decode_wrappers)
+            for i in range(self.num_wrappers):
+                decode_wrappers[i].begin_forward = partial(
+                    fast_decode_plan, decode_wrappers[i]
+                )
        elif forward_mode.is_target_verify():
            prefill_wrappers = []
            for i in range(self.num_wrappers):
@@ -437,7 +439,7 @@ class FlashInferAttnBackend(AttentionBackend):
                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                causal=False,
                sm_scale=layer.scaling,
-                logits_soft_cap=layer.logit_cap,
+                logits_soft_cap=logits_soft_cap,
            )

            o, _ = merge_state(o1, s1, o2, s2)
@@ -636,9 +638,15 @@ class FlashInferIndicesUpdaterDecode:
            bs = len(req_pool_indices)
            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
-                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
-            )
+            if wrapper.is_cuda_graph_enabled:
+                # Directly write to the cuda graph input buffer
+                kv_indices = wrapper._paged_kv_indices_buf
+            else:
+                kv_indices = torch.empty(
+                    paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+                )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
@@ -649,9 +657,9 @@ class FlashInferIndicesUpdaterDecode:
                self.req_to_token.shape[1],
            )
        else:
            assert isinstance(spec_info, EagleDraftInput)
            kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
            bs = kv_indptr.shape[0] - 1

        wrapper.begin_forward(
            kv_indptr,
            kv_indices,
@@ -699,7 +707,7 @@ class FlashInferIndicesUpdaterPrefill:
    def update(
        self,
-        req_pool_indices: torch.Tnesor,
+        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
@@ -713,7 +721,7 @@ class FlashInferIndicesUpdaterPrefill:
    def update_single_wrapper(
        self,
-        req_pool_indices: torch.Tnesor,
+        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
@@ -858,7 +866,6 @@ class FlashInferIndicesUpdaterPrefill:
                kv_indices,
                self.req_to_token.shape[1],
            )
        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
        qo_indptr = qo_indptr[: bs + 1]
        custom_mask = None
@@ -954,7 +961,10 @@ class FlashInferMultiStepDraftBackend:
        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]

    def common_template(
-        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+        self,
+        forward_batch: ForwardBatch,
+        kv_indices_buffer: torch.Tensor,
+        call_fn: Callable,
    ):
        num_seqs = forward_batch.batch_size
        bs = self.topk * num_seqs
@@ -1042,17 +1052,15 @@ class FlashInferMultiStepDraftBackend:
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )
-            decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[
-                forward_batch.batch_size
-            ][0]
-            decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

-    def init_forward_metadata_replay_cuda_graph(self, forward_batch):
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
-                forward_batch.batch_size,
+                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                seq_lens_sum=-1,
@@ -1113,6 +1121,11 @@ def should_use_tensor_core(
        return False


+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global_override_indptr_cpu = None


def fast_decode_plan(
    self,
    indptr: torch.Tensor,
@@ -1142,6 +1155,9 @@ def fast_decode_plan(
    if logits_soft_cap is None:
        logits_soft_cap = 0.0

+    if self.use_tensor_cores:
+        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

    if self.is_cuda_graph_enabled:
        if batch_size != self._fixed_batch_size:
            raise ValueError(
@@ -1154,7 +1170,7 @@ def fast_decode_plan(
            raise ValueError(
                "The size of indices should be less than or equal to the allocated buffer"
            )
-        # Skip these copies
+        # Skip these copies because we directly write to them during prepartion
        # self._paged_kv_indptr_buf.copy_(indptr)
        # self._paged_kv_indices_buf[: len(indices)] = indices
        # self._paged_kv_last_page_len_buf.copy_(last_page_len)
@@ -1162,6 +1178,7 @@ def fast_decode_plan(
        self._paged_kv_indptr_buf = indptr
        self._paged_kv_indices_buf = indices
        self._paged_kv_last_page_len_buf = last_page_len
+        self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)

    # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
    if not q_data_type:
@@ -1184,27 +1201,55 @@ def fast_decode_plan(
        )
        self.last_page_len = torch.ones(32768, dtype=torch.int32)

-    empty_q_data = self.empty_q_data
-    empty_kv_cache = self.empty_kv_cache
-    stream = torch.cuda.current_stream()
-    self._cached_module.plan(
-        self._float_workspace_buffer,
-        self._int_workspace_buffer,
-        self._pin_memory_int_workspace_buffer,
-        indptr.to("cpu"),
-        batch_size,
-        num_qo_heads,
-        num_kv_heads,
-        page_size,
-        self.is_cuda_graph_enabled,
-        window_left,
-        logits_soft_cap,
-        head_dim,
-        head_dim,
-        empty_q_data,
-        empty_kv_cache,
-        stream.cuda_stream,
-    )
+    indptr_host = (
+        global_override_indptr_cpu
+        if global_override_indptr_cpu is not None
+        else indptr.cpu()
+    )
+
+    if self.use_tensor_cores:
+        kv_lens_arr_host = get_seq_lens(
+            indptr_host, self.last_page_len[:batch_size], page_size
+        )
+
+        self._plan_info = self._cached_module.plan(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_host,
+            indptr_host,
+            kv_lens_arr_host,
+            batch_size,  # total_num_rows
+            batch_size,
+            num_qo_heads,
+            num_kv_heads,
+            page_size,
+            self.is_cuda_graph_enabled,
+            head_dim,
+            head_dim,
+            False,  # causal
+            torch.cuda.current_stream().cuda_stream,
+        )
+    else:
+        self._plan_info = self._cached_module.plan(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            indptr_host,
+            batch_size,
+            num_qo_heads,
+            num_kv_heads,
+            page_size,
+            self.is_cuda_graph_enabled,
+            window_left,
+            logits_soft_cap,
+            head_dim,
+            head_dim,
+            self.empty_q_data,
+            self.empty_kv_cache,
+            torch.cuda.current_stream().cuda_stream,
+        )

    self._pos_encoding_mode = pos_encoding_mode
    self._window_left = window_left
    self._logits_soft_cap = logits_soft_cap
......
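The hunks above monkey-patch each decode wrapper's begin_forward with partial(fast_decode_plan, wrapper) and add a module-level global_override_indptr_cpu so that CUDA-graph replay can reuse a host-side indptr instead of paying a blocking indptr.cpu() copy on every step. A minimal, self-contained sketch of that override pattern (toy class and names are illustrative stand-ins, not the flashinfer API):

import functools

import torch

# Stand-in for the module-level override used in the diff; a caller sets this to
# a CPU tensor before replay so plan() can skip the device-to-host copy.
global_override_indptr_cpu = None


class ToyDecodeWrapper:
    def plan(self, indptr: torch.Tensor):
        # Baseline path: always copies indptr to the host.
        self.indptr_host = indptr.cpu()


def fast_plan(self: ToyDecodeWrapper, indptr: torch.Tensor):
    # Fast path: prefer the pre-staged CPU indptr when it is available.
    self.indptr_host = (
        global_override_indptr_cpu
        if global_override_indptr_cpu is not None
        else indptr.cpu()
    )


wrapper = ToyDecodeWrapper()
# Same trick as the diff: bind the replacement as if it were a bound method.
wrapper.plan = functools.partial(fast_plan, wrapper)

indptr = torch.tensor([0, 3, 7], dtype=torch.int32)
global_override_indptr_cpu = indptr.clone()
wrapper.plan(indptr)  # uses the pre-staged CPU buffer, no extra copy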
@@ -578,10 +578,12 @@ class TritonMultiStepDraftBackend:

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

-    def init_forward_metadata_replay_cuda_graph(self, forward_batch):
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
-                forward_batch.batch_size,
+                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                seq_lens_sum=-1,
......
@@ -396,16 +396,10 @@ class CudaGraphRunner:

        run_once()

-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()

        global global_graph_memory_pool
        with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
            out = run_once()

-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()

        global_graph_memory_pool = graph.pool()
        return graph, out
......
@@ -26,7 +26,12 @@ def build_tree_kernel_efficient_preprocess(
    draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
    draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()

-    parent_list = torch.cat(parents_list[:-1], dim=1)
+    if len(parents_list) > 1:
+        parent_list = torch.cat(parents_list[:-1], dim=1)
+    else:
+        batch_size = parents_list[0].shape[0]
+        parent_list = torch.empty(batch_size, 0, device=parents_list[0].device)

    return parent_list, top_scores_index, draft_tokens
......
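The guard added above is what enables speculative_num_steps=1: with a single draft step, parents_list holds only one tensor, so parents_list[:-1] is empty and torch.cat would raise. A standalone sketch of the same logic (the function name here is mine, not sglang's):

import torch


def preprocess_parent_list(parents_list):
    # Mirror of the guard in the diff: concatenate all but the last per-step
    # parent tensor, or return an empty (batch_size, 0) tensor when there is
    # only one step and nothing to concatenate.
    if len(parents_list) > 1:
        return torch.cat(parents_list[:-1], dim=1)
    batch_size = parents_list[0].shape[0]
    return torch.empty(batch_size, 0, device=parents_list[0].device)


# Multiple draft steps: the last entry is unused, the rest are concatenated.
multi_step = [torch.zeros(2, 4, dtype=torch.long) for _ in range(3)]
print(preprocess_parent_list(multi_step).shape)  # torch.Size([2, 8])

# speculative_num_steps == 1: a single entry yields an empty parent list.
single_step = [torch.zeros(2, 4, dtype=torch.long)]
print(preprocess_parent_list(single_step).shape)  # torch.Size([2, 0])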
from __future__ import annotations

import bisect
-import time
from typing import TYPE_CHECKING, Callable

import torch
@@ -162,20 +161,11 @@ class EAGLEDraftCudaGraphRunner:

        run_once()

-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()

        with torch.cuda.graph(
            graph, pool=get_global_graph_memory_pool(), stream=stream
        ):
            out = run_once()

-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()

        set_global_graph_memory_pool(graph.pool())
        return graph, out
@@ -204,7 +194,7 @@ class EAGLEDraftCudaGraphRunner:

        # Attention backend
        self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
-            forward_batch
+            forward_batch, forward_batch.batch_size
        )

        # Replay
......
from __future__ import annotations

from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List
+from typing import TYPE_CHECKING, List

import torch
import torch.nn.functional as F
@@ -62,6 +62,7 @@ class EagleDraftInput:
            batch.input_ids[pt : pt + extend_len] = torch.concat(
                (input_ids[1:], self.verified_id[i].reshape(1))
            )
+            pt += extend_len

    def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
        assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
......
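The one-line `pt += extend_len` above is the draft-model accuracy fix from the commit title: the loop writes each request's shifted input ids into a flat buffer, and without advancing the pointer every request overwrites the first request's slice. A self-contained sketch of the corrected loop (helper name and shapes are illustrative, not sglang's actual code):

import torch


def fill_shifted_input_ids(input_ids_buf, per_req_input_ids, verified_ids):
    # Each request writes its ids shifted by one plus its verified token into
    # the flat buffer; pt must advance by extend_len after each request,
    # otherwise every request stomps on the first request's slice.
    pt = 0
    for i, input_ids in enumerate(per_req_input_ids):
        extend_len = len(input_ids)
        input_ids_buf[pt : pt + extend_len] = torch.cat(
            (input_ids[1:], verified_ids[i].reshape(1))
        )
        pt += extend_len
    return input_ids_buf


buf = torch.zeros(6, dtype=torch.long)
reqs = [torch.tensor([10, 11, 12]), torch.tensor([20, 21, 22])]
verified = torch.tensor([13, 23])
print(fill_shifted_input_ids(buf, reqs, verified))
# tensor([11, 12, 13, 21, 22, 23])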
import logging
import os
import time
-from typing import Dict, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple

import torch
from huggingface_hub import snapshot_download

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
)
-from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
    EAGLEDraftCudaGraphRunner,
@@ -27,7 +26,6 @@ from sglang.srt.speculative.eagle_utils import (
    fast_topk,
    select_top_k_tokens,
)
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import get_available_gpu_memory

logger = logging.getLogger(__name__)
@@ -44,16 +42,30 @@ class EAGLEWorker(TpModelWorker):
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
+        # Parse arguments
+        self.server_args = server_args
+        self.topk = server_args.speculative_eagle_topk
+        self.speculative_num_steps = server_args.speculative_num_steps
+        self.padded_static_len = self.speculative_num_steps + 1
+        self.enable_nan_detection = server_args.enable_nan_detection
+        self.gpu_id = gpu_id
+        self.device = server_args.device
+        self.target_worker = target_worker
+
        # Override context length with target model's context length
        server_args.context_length = target_worker.model_runner.model_config.context_len
+        os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"

        # Do not capture cuda graph in `super().__init__()`
-        # We will capture it later
+        # It will be captured later.
        backup_disable_cuda_graph = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True

+        # Share the allocator with a target worker.
+        # Draft and target worker own their own KV cache pools.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )

-        # Lossy optimization by using hot tokens
+        # Load hot token ids
        if server_args.speculative_token_map is not None:
            self.hot_token_id = load_token_map(server_args.speculative_token_map)
            server_args.json_model_override_args = (
@@ -62,13 +74,7 @@ class EAGLEWorker(TpModelWorker):
        else:
            self.hot_token_id = None

-        # We share the allocator with a target worker. Draft/target worker
-        # owns its own KV cache.
-        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
-            target_worker.get_memory_pool()
-        )
-        # Init target worker
+        # Init draft worker
        super().__init__(
            gpu_id=gpu_id,
            tp_rank=tp_rank,
@@ -79,18 +85,6 @@ class EAGLEWorker(TpModelWorker):
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
        )
-        self.target_worker = target_worker
-
-        # Parse arguments
-        self.topk = server_args.speculative_eagle_topk
-        self.speculative_num_steps = server_args.speculative_num_steps
-        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
-            server_args.speculative_algorithm
-        )
-        self.server_args = server_args
-        self.use_nan_detection = self.server_args.enable_nan_detection
-        self.device = self.model_runner.device
-        self.gpu_id = self.model_runner.gpu_id

        # Share the embedding and lm_head
        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
@@ -103,8 +97,12 @@ class EAGLEWorker(TpModelWorker):
            backup_disable_cuda_graph
        )

+        self.init_attention_backend()
+        self.init_cuda_graphs()
+
+    def init_attention_backend(self):
        # Create multi-step attn backends and cuda graph runners
-        if server_args.attention_backend == "flashinfer":
+        if self.server_args.attention_backend == "flashinfer":
            from sglang.srt.layers.attention.flashinfer_backend import (
                FlashInferMultiStepDraftBackend,
            )
@@ -114,7 +112,7 @@ class EAGLEWorker(TpModelWorker):
                self.topk,
                self.speculative_num_steps,
            )
-        elif server_args.attention_backend == "triton":
+        elif self.server_args.attention_backend == "triton":
            from sglang.srt.layers.attention.triton_backend import (
                TritonMultiStepDraftBackend,
            )
@@ -126,11 +124,9 @@ class EAGLEWorker(TpModelWorker):
            )
        else:
            raise ValueError(
-                f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
+                f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
            )

        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
-        self.init_cuda_graphs()

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
@@ -356,6 +352,41 @@ class EAGLEWorker(TpModelWorker):
        batch.forward_mode = ForwardMode.DECODE
        batch.spec_info = res.draft_input

+        if batch.return_logprob:
+            # Compute output logprobs using the sampler.
+            num_tokens_per_req = [
+                accept + 1 for accept in res.accept_length_per_req_cpu
+            ]
+            self.target_worker.model_runner.update_output_logprobs(
+                logits_output,
+                batch.sampling_info,
+                batch.top_logprobs_nums,
+                batch.token_ids_logprobs,
+                res.verified_id,
+                # +1 for bonus token.
+                num_tokens_per_req=num_tokens_per_req,
+            )
+
+            # Add output logprobs to the request.
+            pt = 0
+            # NOTE: tolist() of these values are skipped when output is processed
+            next_token_logprobs = res.logits_output.next_token_logprobs.tolist()
+            verified_ids = res.verified_id.tolist()
+            for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
+                for _ in range(num_tokens):
+                    if req.return_logprob:
+                        token_id = verified_ids[pt]
+                        req.output_token_logprobs_val.append(next_token_logprobs[pt])
+                        req.output_token_logprobs_idx.append(token_id)
+                        if req.top_logprobs_num > 0:
+                            req.output_top_logprobs_val.append(
+                                res.logits_output.next_token_top_logprobs_val[pt]
+                            )
+                            req.output_top_logprobs_idx.append(
+                                res.logits_output.next_token_top_logprobs_idx[pt]
+                            )
+                    pt += 1
+
        return logits_output, res, model_worker_batch

    def forward_draft_extend(
@@ -381,6 +412,7 @@ class EAGLEWorker(TpModelWorker):
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
+        forward_batch.return_logprob = False
        logits_output = self.draft_model_runner.forward(forward_batch)
        self._detect_nan_if_needed(logits_output)
        assert isinstance(forward_batch.spec_info, EagleDraftInput)
@@ -393,6 +425,8 @@ class EAGLEWorker(TpModelWorker):
        batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        # We don't need logprob for this extend.
+        original_return_logprob = batch.return_logprob
+        batch.return_logprob = False
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
@@ -404,6 +438,7 @@ class EAGLEWorker(TpModelWorker):
        # Restore backup.
        # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
+        batch.return_logprob = original_return_logprob
        batch.forward_mode = ForwardMode.DECODE
        batch.seq_lens = seq_lens_backup
@@ -415,7 +450,7 @@ class EAGLEWorker(TpModelWorker):
        draft_input.hidden_states = logits_output.hidden_states

    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
-        if self.use_nan_detection:
+        if self.enable_nan_detection:
            logits = logits_output.next_token_logits
            if torch.any(torch.isnan(logits)):
                logger.warning("Detected errors during sampling! NaN in the logits.")
......
@@ -165,7 +165,7 @@ class TestBenchServing(unittest.TestCase):
                f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n'
                f'accept_length : {res["accept_length"]:.2f} \n'
            )
-            self.assertLess(res["median_e2e_latency_ms"], 1100)
+            self.assertLess(res["median_e2e_latency_ms"], 900)
            self.assertGreater(res["accept_length"], 2.99)

    def test_moe_offline_throughput_default(self):
......
+import json
import multiprocessing as mp
import os
import random
import threading
import time
import unittest
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
from types import SimpleNamespace
from typing import List, Optional

+import numpy as np
import requests
import torch
@@ -21,6 +25,7 @@ from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
+    run_logprob_check,
)

torch_dtype = torch.float16
@@ -260,11 +265,132 @@ class TestEAGLEServer(unittest.TestCase):
        server_info = requests.get(self.base_url + "/get_server_info")
        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
        print(f"{avg_spec_accept_length=}")
-        self.assertGreater(avg_spec_accept_length, 2.9)
+        self.assertGreater(avg_spec_accept_length, 3.5)

        # Wait a little bit so that the memory check happens.
        time.sleep(4)
+    def test_logprob_start_len(self):
+        logprob_start_len = 4
+        new_tokens = 4
+        prompts = [
+            "I have a very good idea on",
+            "Today is a sunndy day and",
+        ]
+
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": prompts,
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": new_tokens,
+                },
+                "return_logprob": True,
+                "top_logprobs_num": 5,
+                "logprob_start_len": logprob_start_len,
+            },
+        )
+        response_json = response.json()
+        print(json.dumps(response_json, indent=2))
+
+        for res in response_json:
+            self.assertEqual(
+                res["meta_info"]["prompt_tokens"],
+                logprob_start_len + len(res["meta_info"]["input_token_logprobs"]),
+            )
+            self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
+            self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
+
+    def test_logprob_match(self):
+        """Test the output logprobs are close to the input logprobs if we run a prefill again."""
+
+        def run_generate(
+            prompt, return_logprob=False, max_new_tokens=512, logprob_start_len=-1
+        ):
+            if isinstance(prompt, str):
+                prompt_kwargs = {"text": prompt}
+            else:
+                prompt_kwargs = {"input_ids": prompt}
+
+            response = requests.post(
+                self.base_url + "/generate",
+                json={
+                    **prompt_kwargs,
+                    "sampling_params": {
+                        "temperature": 1.0,
+                        "max_new_tokens": max_new_tokens,
+                        "ignore_eos": True,
+                    },
+                    "return_logprob": return_logprob,
+                    "return_text_in_logprobs": True,
+                    "logprob_start_len": logprob_start_len,
+                },
+            )
+            return response.json()
+
+        prompt = "I have a very good idea on how to"
+
+        gen = run_generate(prompt, return_logprob=True, logprob_start_len=0)
+        output_logprobs = np.array(
+            [x[0] for x in gen["meta_info"]["output_token_logprobs"]]
+        )
+        num_prompts_tokens = gen["meta_info"]["prompt_tokens"]
+
+        input_tokens = [x[1] for x in gen["meta_info"]["input_token_logprobs"]]
+        output_tokens = [x[1] for x in gen["meta_info"]["output_token_logprobs"]]
+
+        new_prompt = input_tokens + output_tokens
+        score = run_generate(
+            new_prompt, return_logprob=True, logprob_start_len=0, max_new_tokens=0
+        )
+        output_logprobs_score = np.array(
+            [
+                x[0]
+                for x in score["meta_info"]["input_token_logprobs"][num_prompts_tokens:]
+            ]
+        )
+
+        print(f"{output_logprobs[-10:]=}")
+        print(f"{output_logprobs_score[-10:]=}")
+
+        diff = np.abs(output_logprobs - output_logprobs_score)
+        max_diff = np.max(diff)
+        self.assertLess(max_diff, 0.25)
+
+    def test_logprob_mixed(self):
+        args = []
+        temperature = 0
+        # input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
+        # Llama 2 context length seems to be only 2k, so we can only test small length.
+        for input_len in [200, 500, 1000, 2000]:
+            for output_len in [4, 8]:
+                for logprob_start_len in [0, 100, 300, 800, 1998]:
+                    for return_logprob in [True, False]:
+                        for top_logprobs_num in [0, 5]:
+                            if logprob_start_len >= input_len:
+                                continue
+
+                            args.append(
+                                (
+                                    input_len,
+                                    output_len,
+                                    temperature,
+                                    logprob_start_len,
+                                    return_logprob,
+                                    top_logprobs_num,
+                                )
+                            )
+
+        random.shuffle(args)
+
+        func = partial(run_logprob_check, self)
+        with ThreadPoolExecutor(8) as executor:
+            list(executor.map(func, args))


class TestEAGLERetract(TestEAGLEServer):
    @classmethod
......
@@ -143,11 +143,11 @@ class TestGPTQModelDynamic(unittest.TestCase):

        print(f"result = `{result}`")
-        assert "paris" in result["text"].lower()
+        self.assertIn("paris", result["text"].lower())

        throughput = max_tokens / (tok - tic)
        print(f"Throughput: {throughput} tokens/s")
-        assert throughput >= 140
+        self.assertGreaterEqual(throughput, 140)

    def test_gptq_module(self):
        check_quant_method(self.MODEL_PATH, use_marlin_kernel=False)
......