Unverified Commit bc1534ff authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix a draft model accuracy bug in eagle; support step=1; return logprob in eagle (#4134)


Co-authored-by: default avatarSehoon Kim <kssteven418@gmail.com>
Co-authored-by: default avatarSangBin Cho <rkooo567@gmail.com>
Co-authored-by: default avatarSehoon Kim <sehoon@x.ai>
parent 3a391812
......@@ -95,7 +95,7 @@ jobs:
strategy:
fail-fast: false
matrix:
range: [0-6, 6-15, 15-22, 22-32, 32-40, 40-100]
range: [0-6, 6-15, 15-22, 22-32, 32-40, 40-48, 48-100]
steps:
- name: Checkout code
uses: actions/checkout@v3
......
......@@ -7,16 +7,14 @@ FlashInfer is faster and Triton is easier to customize.
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
"""
import math
import os
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import torch
import triton
import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
......@@ -37,7 +35,7 @@ if is_flashinfer_available():
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
from flashinfer.decode import PosEncodingMode
from flashinfer.decode import _get_range_buf, get_seq_lens
class WrapperDispatch(Enum):
......@@ -73,8 +71,6 @@ class FlashInferAttnBackend(AttentionBackend):
):
super().__init__()
self.is_multimodal = model_runner.model_config.is_multimodal
# Parse constants
self.decode_use_tensor_cores = should_use_tensor_core(
kv_cache_dtype=model_runner.kv_cache_dtype,
......@@ -86,6 +82,7 @@ class FlashInferAttnBackend(AttentionBackend):
)
self.max_context_len = model_runner.model_config.context_len
self.skip_prefill = skip_prefill
self.is_multimodal = model_runner.model_config.is_multimodal
assert not (
model_runner.sliding_window_size is not None
......@@ -115,7 +112,6 @@ class FlashInferAttnBackend(AttentionBackend):
device=model_runner.device,
)
self.workspace_buffer = global_workspace_buffer
max_bs = model_runner.req_to_token_pool.size
if kv_indptr_buf is None:
self.kv_indptr = [
......@@ -163,9 +159,11 @@ class FlashInferAttnBackend(AttentionBackend):
)
)
self.prefill_wrappers_verify.append(
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
)
)
self.decode_wrappers.append(
BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
......@@ -178,13 +176,14 @@ class FlashInferAttnBackend(AttentionBackend):
if not skip_prefill:
self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
model_runner, self
)
) # for verify
self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
# Other metadata
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
self.decode_cuda_graph_metadata = {}
self.prefill_cuda_graph_metadata = {}
self.prefill_cuda_graph_metadata = {} # For verify
self.draft_extend_cuda_graph_metadata = {} # For draft extend
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode_or_idle():
......@@ -300,7 +299,6 @@ class FlashInferAttnBackend(AttentionBackend):
],
)
)
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_decode.update(
req_pool_indices,
......@@ -312,6 +310,10 @@ class FlashInferAttnBackend(AttentionBackend):
)
self.decode_cuda_graph_metadata[bs] = decode_wrappers
self.forward_metadata = DecodeMetadata(decode_wrappers)
for i in range(self.num_wrappers):
decode_wrappers[i].begin_forward = partial(
fast_decode_plan, decode_wrappers[i]
)
elif forward_mode.is_target_verify():
prefill_wrappers = []
for i in range(self.num_wrappers):
......@@ -437,7 +439,7 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=False,
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
logits_soft_cap=logits_soft_cap,
)
o, _ = merge_state(o1, s1, o2, s2)
......@@ -636,9 +638,15 @@ class FlashInferIndicesUpdaterDecode:
bs = len(req_pool_indices)
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
if wrapper.is_cuda_graph_enabled:
# Directly write to the cuda graph input buffer
kv_indices = wrapper._paged_kv_indices_buf
else:
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
......@@ -649,9 +657,9 @@ class FlashInferIndicesUpdaterDecode:
self.req_to_token.shape[1],
)
else:
assert isinstance(spec_info, EagleDraftInput)
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1
wrapper.begin_forward(
kv_indptr,
kv_indices,
......@@ -699,7 +707,7 @@ class FlashInferIndicesUpdaterPrefill:
def update(
self,
req_pool_indices: torch.Tnesor,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
......@@ -713,7 +721,7 @@ class FlashInferIndicesUpdaterPrefill:
def update_single_wrapper(
self,
req_pool_indices: torch.Tnesor,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
......@@ -858,7 +866,6 @@ class FlashInferIndicesUpdaterPrefill:
kv_indices,
self.req_to_token.shape[1],
)
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None
......@@ -954,7 +961,10 @@ class FlashInferMultiStepDraftBackend:
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
def common_template(
self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
self,
forward_batch: ForwardBatch,
kv_indices_buffer: torch.Tensor,
call_fn: Callable,
):
num_seqs = forward_batch.batch_size
bs = self.topk * num_seqs
......@@ -1042,17 +1052,15 @@ class FlashInferMultiStepDraftBackend:
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[
forward_batch.batch_size
][0]
decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
def init_forward_metadata_replay_cuda_graph(self, forward_batch):
def init_forward_metadata_replay_cuda_graph(
self, forward_batch: ForwardBatch, bs: int
):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
forward_batch.batch_size,
bs,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,
......@@ -1113,6 +1121,11 @@ def should_use_tensor_core(
return False
# Use as a fast path to override the indptr in flashinfer's plan function
# This is used to remove some host-to-device copy overhead.
global_override_indptr_cpu = None
def fast_decode_plan(
self,
indptr: torch.Tensor,
......@@ -1142,6 +1155,9 @@ def fast_decode_plan(
if logits_soft_cap is None:
logits_soft_cap = 0.0
if self.use_tensor_cores:
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
if self.is_cuda_graph_enabled:
if batch_size != self._fixed_batch_size:
raise ValueError(
......@@ -1154,7 +1170,7 @@ def fast_decode_plan(
raise ValueError(
"The size of indices should be less than or equal to the allocated buffer"
)
# Skip these copies
# Skip these copies because we directly write to them during prepartion
# self._paged_kv_indptr_buf.copy_(indptr)
# self._paged_kv_indices_buf[: len(indices)] = indices
# self._paged_kv_last_page_len_buf.copy_(last_page_len)
......@@ -1162,6 +1178,7 @@ def fast_decode_plan(
self._paged_kv_indptr_buf = indptr
self._paged_kv_indices_buf = indices
self._paged_kv_last_page_len_buf = last_page_len
self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)
# NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
if not q_data_type:
......@@ -1184,27 +1201,55 @@ def fast_decode_plan(
)
self.last_page_len = torch.ones(32768, dtype=torch.int32)
empty_q_data = self.empty_q_data
empty_kv_cache = self.empty_kv_cache
stream = torch.cuda.current_stream()
self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
indptr.to("cpu"),
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
window_left,
logits_soft_cap,
head_dim,
head_dim,
empty_q_data,
empty_kv_cache,
stream.cuda_stream,
indptr_host = (
global_override_indptr_cpu
if global_override_indptr_cpu is not None
else indptr.cpu()
)
if self.use_tensor_cores:
kv_lens_arr_host = get_seq_lens(
indptr_host, self.last_page_len[:batch_size], page_size
)
self._plan_info = self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_host,
indptr_host,
kv_lens_arr_host,
batch_size, # total_num_rows
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
head_dim,
head_dim,
False, # causal
torch.cuda.current_stream().cuda_stream,
)
else:
self._plan_info = self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
indptr_host,
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
window_left,
logits_soft_cap,
head_dim,
head_dim,
self.empty_q_data,
self.empty_kv_cache,
torch.cuda.current_stream().cuda_stream,
)
self._pos_encoding_mode = pos_encoding_mode
self._window_left = window_left
self._logits_soft_cap = logits_soft_cap
......
......@@ -578,10 +578,12 @@ class TritonMultiStepDraftBackend:
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
def init_forward_metadata_replay_cuda_graph(self, forward_batch):
def init_forward_metadata_replay_cuda_graph(
self, forward_batch: ForwardBatch, bs: int
):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
forward_batch.batch_size,
bs,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,
......
......@@ -396,16 +396,10 @@ class CudaGraphRunner:
run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
global global_graph_memory_pool
with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
out = run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
global_graph_memory_pool = graph.pool()
return graph, out
......
......@@ -26,7 +26,12 @@ def build_tree_kernel_efficient_preprocess(
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
parent_list = torch.cat(parents_list[:-1], dim=1)
if len(parents_list) > 1:
parent_list = torch.cat(parents_list[:-1], dim=1)
else:
batch_size = parents_list[0].shape[0]
parent_list = torch.empty(batch_size, 0, device=parents_list[0].device)
return parent_list, top_scores_index, draft_tokens
......
from __future__ import annotations
import bisect
import time
from typing import TYPE_CHECKING, Callable
import torch
......@@ -162,20 +161,11 @@ class EAGLEDraftCudaGraphRunner:
run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
with torch.cuda.graph(
graph, pool=get_global_graph_memory_pool(), stream=stream
):
out = run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
set_global_graph_memory_pool(graph.pool())
return graph, out
......@@ -204,7 +194,7 @@ class EAGLEDraftCudaGraphRunner:
# Attention backend
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
forward_batch
forward_batch, forward_batch.batch_size
)
# Replay
......
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List
from typing import TYPE_CHECKING, List
import torch
import torch.nn.functional as F
......@@ -62,6 +62,7 @@ class EagleDraftInput:
batch.input_ids[pt : pt + extend_len] = torch.concat(
(input_ids[1:], self.verified_id[i].reshape(1))
)
pt += extend_len
def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
......
import logging
import os
import time
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple
import torch
from huggingface_hub import snapshot_download
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
EAGLEDraftCudaGraphRunner,
......@@ -27,7 +26,6 @@ from sglang.srt.speculative.eagle_utils import (
fast_topk,
select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import get_available_gpu_memory
logger = logging.getLogger(__name__)
......@@ -44,16 +42,30 @@ class EAGLEWorker(TpModelWorker):
nccl_port: int,
target_worker: TpModelWorker,
):
# Parse arguments
self.server_args = server_args
self.topk = server_args.speculative_eagle_topk
self.speculative_num_steps = server_args.speculative_num_steps
self.padded_static_len = self.speculative_num_steps + 1
self.enable_nan_detection = server_args.enable_nan_detection
self.gpu_id = gpu_id
self.device = server_args.device
self.target_worker = target_worker
# Override context length with target model's context length
server_args.context_length = target_worker.model_runner.model_config.context_len
os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
# Do not capture cuda graph in `super().__init__()`
# We will capture it later
# It will be captured later.
backup_disable_cuda_graph = server_args.disable_cuda_graph
server_args.disable_cuda_graph = True
# Share the allocator with a target worker.
# Draft and target worker own their own KV cache pools.
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
target_worker.get_memory_pool()
)
# Lossy optimization by using hot tokens
# Load hot token ids
if server_args.speculative_token_map is not None:
self.hot_token_id = load_token_map(server_args.speculative_token_map)
server_args.json_model_override_args = (
......@@ -62,13 +74,7 @@ class EAGLEWorker(TpModelWorker):
else:
self.hot_token_id = None
# We share the allocator with a target worker. Draft/target worker
# owns its own KV cache.
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
target_worker.get_memory_pool()
)
# Init target worker
# Init draft worker
super().__init__(
gpu_id=gpu_id,
tp_rank=tp_rank,
......@@ -79,18 +85,6 @@ class EAGLEWorker(TpModelWorker):
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
self.target_worker = target_worker
# Parse arguments
self.topk = server_args.speculative_eagle_topk
self.speculative_num_steps = server_args.speculative_num_steps
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
self.server_args = server_args
self.use_nan_detection = self.server_args.enable_nan_detection
self.device = self.model_runner.device
self.gpu_id = self.model_runner.gpu_id
# Share the embedding and lm_head
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
......@@ -103,8 +97,12 @@ class EAGLEWorker(TpModelWorker):
backup_disable_cuda_graph
)
self.init_attention_backend()
self.init_cuda_graphs()
def init_attention_backend(self):
# Create multi-step attn backends and cuda graph runners
if server_args.attention_backend == "flashinfer":
if self.server_args.attention_backend == "flashinfer":
from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferMultiStepDraftBackend,
)
......@@ -114,7 +112,7 @@ class EAGLEWorker(TpModelWorker):
self.topk,
self.speculative_num_steps,
)
elif server_args.attention_backend == "triton":
elif self.server_args.attention_backend == "triton":
from sglang.srt.layers.attention.triton_backend import (
TritonMultiStepDraftBackend,
)
......@@ -126,11 +124,9 @@ class EAGLEWorker(TpModelWorker):
)
else:
raise ValueError(
f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
)
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
self.init_cuda_graphs()
def init_cuda_graphs(self):
"""Capture cuda graphs."""
......@@ -356,6 +352,41 @@ class EAGLEWorker(TpModelWorker):
batch.forward_mode = ForwardMode.DECODE
batch.spec_info = res.draft_input
if batch.return_logprob:
# Compute output logprobs using the sampler.
num_tokens_per_req = [
accept + 1 for accept in res.accept_length_per_req_cpu
]
self.target_worker.model_runner.update_output_logprobs(
logits_output,
batch.sampling_info,
batch.top_logprobs_nums,
batch.token_ids_logprobs,
res.verified_id,
# +1 for bonus token.
num_tokens_per_req=num_tokens_per_req,
)
# Add output logprobs to the request.
pt = 0
# NOTE: tolist() of these values are skipped when output is processed
next_token_logprobs = res.logits_output.next_token_logprobs.tolist()
verified_ids = res.verified_id.tolist()
for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
for _ in range(num_tokens):
if req.return_logprob:
token_id = verified_ids[pt]
req.output_token_logprobs_val.append(next_token_logprobs[pt])
req.output_token_logprobs_idx.append(token_id)
if req.top_logprobs_num > 0:
req.output_top_logprobs_val.append(
res.logits_output.next_token_top_logprobs_val[pt]
)
req.output_top_logprobs_idx.append(
res.logits_output.next_token_top_logprobs_idx[pt]
)
pt += 1
return logits_output, res, model_worker_batch
def forward_draft_extend(
......@@ -381,6 +412,7 @@ class EAGLEWorker(TpModelWorker):
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
forward_batch.return_logprob = False
logits_output = self.draft_model_runner.forward(forward_batch)
self._detect_nan_if_needed(logits_output)
assert isinstance(forward_batch.spec_info, EagleDraftInput)
......@@ -393,6 +425,8 @@ class EAGLEWorker(TpModelWorker):
batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
# We don't need logprob for this extend.
original_return_logprob = batch.return_logprob
batch.return_logprob = False
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
......@@ -404,6 +438,7 @@ class EAGLEWorker(TpModelWorker):
# Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
batch.return_logprob = original_return_logprob
batch.forward_mode = ForwardMode.DECODE
batch.seq_lens = seq_lens_backup
......@@ -415,7 +450,7 @@ class EAGLEWorker(TpModelWorker):
draft_input.hidden_states = logits_output.hidden_states
def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
if self.use_nan_detection:
if self.enable_nan_detection:
logits = logits_output.next_token_logits
if torch.any(torch.isnan(logits)):
logger.warning("Detected errors during sampling! NaN in the logits.")
......
......@@ -165,7 +165,7 @@ class TestBenchServing(unittest.TestCase):
f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n'
f'accept_length : {res["accept_length"]:.2f} \n'
)
self.assertLess(res["median_e2e_latency_ms"], 1100)
self.assertLess(res["median_e2e_latency_ms"], 900)
self.assertGreater(res["accept_length"], 2.99)
def test_moe_offline_throughput_default(self):
......
import json
import multiprocessing as mp
import os
import random
import threading
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from types import SimpleNamespace
from typing import List, Optional
import numpy as np
import requests
import torch
......@@ -21,6 +25,7 @@ from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
run_logprob_check,
)
torch_dtype = torch.float16
......@@ -260,11 +265,132 @@ class TestEAGLEServer(unittest.TestCase):
server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 2.9)
self.assertGreater(avg_spec_accept_length, 3.5)
# Wait a little bit so that the memory check happens.
time.sleep(4)
def test_logprob_start_len(self):
logprob_start_len = 4
new_tokens = 4
prompts = [
"I have a very good idea on",
"Today is a sunndy day and",
]
response = requests.post(
self.base_url + "/generate",
json={
"text": prompts,
"sampling_params": {
"temperature": 0,
"max_new_tokens": new_tokens,
},
"return_logprob": True,
"top_logprobs_num": 5,
"logprob_start_len": logprob_start_len,
},
)
response_json = response.json()
print(json.dumps(response_json, indent=2))
for res in response_json:
self.assertEqual(
res["meta_info"]["prompt_tokens"],
logprob_start_len + len(res["meta_info"]["input_token_logprobs"]),
)
self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
def test_logprob_match(self):
"""Test the output logprobs are close to the input logprobs if we run a prefill again."""
def run_generate(
prompt, return_logprob=False, max_new_tokens=512, logprob_start_len=-1
):
if isinstance(prompt, str):
prompt_kwargs = {"text": prompt}
else:
prompt_kwargs = {"input_ids": prompt}
response = requests.post(
self.base_url + "/generate",
json={
**prompt_kwargs,
"sampling_params": {
"temperature": 1.0,
"max_new_tokens": max_new_tokens,
"ignore_eos": True,
},
"return_logprob": return_logprob,
"return_text_in_logprobs": True,
"logprob_start_len": logprob_start_len,
},
)
return response.json()
prompt = "I have a very good idea on how to"
gen = run_generate(prompt, return_logprob=True, logprob_start_len=0)
output_logprobs = np.array(
[x[0] for x in gen["meta_info"]["output_token_logprobs"]]
)
num_prompts_tokens = gen["meta_info"]["prompt_tokens"]
input_tokens = [x[1] for x in gen["meta_info"]["input_token_logprobs"]]
output_tokens = [x[1] for x in gen["meta_info"]["output_token_logprobs"]]
new_prompt = input_tokens + output_tokens
score = run_generate(
new_prompt, return_logprob=True, logprob_start_len=0, max_new_tokens=0
)
output_logprobs_score = np.array(
[
x[0]
for x in score["meta_info"]["input_token_logprobs"][num_prompts_tokens:]
]
)
print(f"{output_logprobs[-10:]=}")
print(f"{output_logprobs_score[-10:]=}")
diff = np.abs(output_logprobs - output_logprobs_score)
max_diff = np.max(diff)
self.assertLess(max_diff, 0.25)
def test_logprob_mixed(self):
args = []
temperature = 0
# input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
# Llama 2 context length seems to be only 2k, so we can only test small length.
for input_len in [200, 500, 1000, 2000]:
for output_len in [4, 8]:
for logprob_start_len in [0, 100, 300, 800, 1998]:
for return_logprob in [True, False]:
for top_logprobs_num in [0, 5]:
if logprob_start_len >= input_len:
continue
args.append(
(
input_len,
output_len,
temperature,
logprob_start_len,
return_logprob,
top_logprobs_num,
)
)
random.shuffle(args)
func = partial(run_logprob_check, self)
with ThreadPoolExecutor(8) as executor:
list(executor.map(func, args))
class TestEAGLERetract(TestEAGLEServer):
@classmethod
......
......@@ -143,11 +143,11 @@ class TestGPTQModelDynamic(unittest.TestCase):
print(f"result = `{result}`")
assert "paris" in result["text"].lower()
self.assertIn("paris", result["text"].lower())
throughput = max_tokens / (tok - tic)
print(f"Throughput: {throughput} tokens/s")
assert throughput >= 140
self.assertGreaterEqual(throughput, 140)
def test_gptq_module(self):
check_quant_method(self.MODEL_PATH, use_marlin_kernel=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment