Unverified Commit 9ba1f097 authored by Lianmin Zheng, committed by GitHub

[Fix] Fix logprob and normalized_logprob (#1428)

parent 282681b8
......@@ -54,7 +54,7 @@ jobs:
timeout-minutes: 20
run: |
cd test/srt
python3 run_suite.py --suite minimal --range-begin 0 --range-end 8
python3 run_suite.py --suite minimal --range-begin 0 --range-end 7
unit-test-backend-part-2:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
......@@ -73,7 +73,26 @@ jobs:
timeout-minutes: 20
run: |
cd test/srt
python3 run_suite.py --suite minimal --range-begin 8
python3 run_suite.py --suite minimal --range-begin 7 --range-end 14
unit-test-backend-part-3:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
runs-on: 1-gpu-runner
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Install dependencies
run: |
pip install --upgrade pip
pip install -e "python[dev]"
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
- name: Run test
timeout-minutes: 20
run: |
cd test/srt
python3 run_suite.py --suite minimal --range-begin 14
performance-test-1-gpu-part-1:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
......@@ -217,7 +236,7 @@ jobs:
finish:
needs: [
unit-test-frontend, unit-test-backend-part-1, unit-test-backend-part-2,
unit-test-frontend, unit-test-backend-part-1, unit-test-backend-part-2, unit-test-backend-part-3,
performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
accuracy-test-1-gpu, accuracy-test-2-gpu
]
......
......@@ -91,7 +91,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
# Node 1
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 1
```
### Supported Models
**Generative Models**
......
......@@ -164,6 +164,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
req.prefix_indices = []
req.sampling_params = sampling_params
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
reqs.append(req)
return input_ids, reqs
......@@ -178,6 +179,7 @@ def prepare_extend_inputs_for_correctness_test(
req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
i, : bench_args.cut_len
]
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
return reqs
......@@ -194,6 +196,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
req.prefix_indices = []
req.sampling_params = sampling_params
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
reqs.append(req)
return reqs
......
......@@ -239,9 +239,12 @@ class RuntimeEndpoint(BaseBackend):
# Compute logprob
data = {
"text": [s.text_ + c for c in choices],
"sampling_params": {"max_new_tokens": 0},
"sampling_params": {
"max_new_tokens": 0,
"temperature": 0,
},
"return_logprob": True,
"logprob_start_len": max(prompt_len - 2, 0),
"logprob_start_len": max(prompt_len - 2, 0), # for token healing
}
obj = self._generate_http_request(s, data)
......
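As a rough client-side sketch of the request this backend builds (the endpoint URL and helper below are illustrative assumptions, not part of this diff): each candidate is appended to the current text, scored with zero new tokens at temperature 0, and the logprobs start two tokens before the prompt end so the boundary token affected by token healing is re-scored.

```python
# Hypothetical sketch; payload keys mirror the diff above, URL/prompt are assumed.
import requests

def score_choices(base_url: str, prompt: str, prompt_len: int, choices: list):
    data = {
        "text": [prompt + c for c in choices],
        "sampling_params": {
            "max_new_tokens": 0,  # score only, generate nothing
            "temperature": 0,
        },
        "return_logprob": True,
        # Start two tokens early so the prompt/choice boundary token,
        # which token healing may merge, is included in the logprobs.
        "logprob_start_len": max(prompt_len - 2, 0),
    }
    return requests.post(f"{base_url}/generate", json=data).json()
```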
......@@ -9,7 +9,7 @@ import uuid
import warnings
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional
import tqdm
......
......@@ -56,6 +56,7 @@ class AttentionBackend(ABC):
raise NotImplementedError()
def get_cuda_graph_seq_len_fill_value(self):
"""Get the fill value for padded seq lens. Typically, it is 0 or 1."""
raise NotImplementedError()
def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
......@@ -66,9 +67,11 @@ class AttentionBackend(ABC):
return self.forward_extend(q, k, v, layer, input_metadata)
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
"""Run a forward for decode."""
raise NotImplementedError()
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
"""Run a forward for extend."""
raise NotImplementedError()
......@@ -299,6 +302,7 @@ class FlashInferAttnBackend(AttentionBackend):
)
if total_num_tokens >= global_config.layer_sync_threshold:
# TODO: Revisit this. Why is this synchronize needed?
torch.cuda.synchronize()
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
......
......@@ -37,7 +37,7 @@ class LogitsProcessorOutput:
# The normalized logprobs of prompts. shape: [#seq]
normalized_prompt_logprobs: torch.Tensor
# The logprobs of input tokens. shape: [#token, vocab_size]
input_token_logprobs: torch.Tensor
# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
......@@ -49,25 +49,39 @@ class LogitsProcessorOutput:
@dataclasses.dataclass
class LogitsMetadata:
forward_mode: ForwardMode
top_logprobs_nums: Optional[List[int]]
return_logprob: bool = False
return_top_logprob: bool = False
extend_seq_lens: Optional[torch.Tensor] = None
extend_start_loc: Optional[torch.Tensor] = None
top_logprobs_nums: Optional[List[int]] = None
extend_seq_lens_cpu: Optional[List[int]] = None
extend_seq_lens_cpu: List[int] = None
logprob_start_lens_cpu: List[int] = None
extend_logprob_start_lens_cpu: Optional[List[int]] = None
extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
@classmethod
def from_input_metadata(cls, input_metadata: InputMetadata):
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
if input_metadata.forward_mode.is_extend():
extend_logprob_pruned_lens_cpu = [
extend_len - start_len
for extend_len, start_len in zip(
input_metadata.extend_seq_lens,
input_metadata.extend_logprob_start_lens_cpu,
)
]
else:
extend_logprob_pruned_lens_cpu = None
return cls(
forward_mode=input_metadata.forward_mode,
extend_seq_lens=input_metadata.extend_seq_lens,
extend_start_loc=input_metadata.extend_start_loc,
return_logprob=input_metadata.return_logprob,
top_logprobs_nums=input_metadata.top_logprobs_nums,
return_logprob=input_metadata.return_logprob,
return_top_logprob=return_top_logprob,
extend_seq_lens=input_metadata.extend_seq_lens,
extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
logprob_start_lens_cpu=input_metadata.logprob_start_lens_cpu,
extend_logprob_start_lens_cpu=input_metadata.extend_logprob_start_lens_cpu,
extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
)
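A minimal sketch of the pruned-length bookkeeping introduced above, with made-up lengths: only `extend_len - start_len` positions per request keep their prompt logprobs, and these pruned lengths drive the slicing further down.

```python
# Toy extend batch (values invented for illustration).
extend_seq_lens = [5, 8, 3]             # tokens computed for each request
extend_logprob_start_lens = [0, 6, 2]   # relative start of requested logprobs

extend_logprob_pruned_lens = [
    extend_len - start_len
    for extend_len, start_len in zip(extend_seq_lens, extend_logprob_start_lens)
]
assert extend_logprob_pruned_lens == [5, 2, 1]  # positions that get prompt logprobs
```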
......@@ -82,57 +96,49 @@ class LogitsProcessor(nn.Module):
def _get_normalized_prompt_logprobs(
self,
input_token_logprobs: torch.Tensor,
cum_start_len0: torch.Tensor,
cum_start_len1: torch.Tensor,
logits_metadata: LogitsMetadata,
):
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
pruned_lens = torch.tensor(
logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
)
start = logits_metadata.extend_start_loc.clone() - cum_start_len0
end = start + logits_metadata.extend_seq_lens - 2 - cum_start_len1
start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
start = torch.zeros_like(pruned_lens)
start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
end = torch.clamp(
start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
)
sum_logp = (
logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
)
normalized_prompt_logprobs = sum_logp / (
(logits_metadata.extend_seq_lens - 1).clamp(min=1)
)
normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
return normalized_prompt_logprobs
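A worked sketch of the cumulative-sum indexing above, on made-up numbers: each sequence's token logprobs (excluding its last position, which scores a token outside the pruned span) are summed with one cumsum and divided by `pruned_len - 1`.

```python
import torch

# Per-position logprobs of the "next" token, two sequences concatenated
# (pruned lengths 3 and 2). The -9.9 entries sit at each sequence's last
# position and score a token outside the span, so they are excluded.
input_token_logprobs = torch.tensor([-0.5, -1.0, -9.9, -0.2, -9.9])
pruned_lens = torch.tensor([3, 2])

logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
start = torch.zeros_like(pruned_lens)
start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)                    # [0, 3]
end = torch.clamp(start + pruned_lens - 2, min=0,
                  max=logprobs_cumsum.shape[0] - 1)                  # [1, 3]

sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
normalized = sum_logp / (pruned_lens - 1).clamp(min=1)
print(normalized)  # tensor([-0.7500, -0.2000]): means of [-0.5, -1.0] and [-0.2]
```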
@staticmethod
def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
max_k = max(logits_metadata.top_logprobs_nums)
ret = all_logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()
if logits_metadata.forward_mode.is_decode():
output_top_logprobs = []
max_k = max(logits_metadata.top_logprobs_nums)
ret = all_logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()
for i, k in enumerate(logits_metadata.top_logprobs_nums):
output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
return None, output_top_logprobs
else:
# TODO: vectorize the code below
input_top_logprobs, output_top_logprobs = [], []
pt = 0
extend_seq_lens_cpu = logits_metadata.extend_seq_lens_cpu
max_k = max(logits_metadata.top_logprobs_nums)
ret = all_logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
start_len = logits_metadata.logprob_start_lens_cpu[i]
pruned_len = extend_seq_len - start_len
if extend_seq_len == 0:
pt = 0
for k, pruned_len in zip(
logits_metadata.top_logprobs_nums,
logits_metadata.extend_logprob_pruned_lens_cpu,
):
if pruned_len <= 0:
input_top_logprobs.append([])
output_top_logprobs.append([])
continue
k = logits_metadata.top_logprobs_nums[i]
input_top_logprobs.append(
[
list(zip(values[pt + j][:k], indices[pt + j][:k]))
......@@ -167,10 +173,7 @@ class LogitsProcessor(nn.Module):
last_index = None
last_hidden = hidden_states
else:
last_index = (
torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
- 1
)
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
last_hidden = hidden_states[last_index]
last_logits = torch.matmul(last_hidden, weight.T)
......@@ -194,21 +197,15 @@ class LogitsProcessor(nn.Module):
output_top_logprobs=None,
)
else:
# When logprob is requested, compute the logits for all tokens.
if logits_metadata.forward_mode.is_decode():
last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
# Get the logprob of top-k tokens
return_top_logprob = any(
x > 0 for x in logits_metadata.top_logprobs_nums
)
if return_top_logprob:
if logits_metadata.forward_mode.is_decode():
if logits_metadata.return_top_logprob:
output_top_logprobs = self.get_top_logprobs(
last_logprobs, logits_metadata
)[1]
else:
output_top_logprobs = None
return LogitsProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,
......@@ -218,22 +215,18 @@ class LogitsProcessor(nn.Module):
output_top_logprobs=output_top_logprobs,
)
else:
# Slice the requested tokens to compute logprob
pt, states, pruned_input_ids = 0, [], []
for i, extend_len in enumerate(logits_metadata.extend_seq_lens_cpu):
start_len = logits_metadata.logprob_start_lens_cpu[i]
for start_len, extend_len in zip(
logits_metadata.extend_logprob_start_lens_cpu,
logits_metadata.extend_seq_lens_cpu,
):
states.append(hidden_states[pt + start_len : pt + extend_len])
pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
pt += extend_len
# Compute the logits and logprobs for all required tokens
states = torch.cat(states, dim=0)
pruned_input_ids = torch.cat(pruned_input_ids, dim=0)
cum_start_len1 = torch.tensor(
logits_metadata.logprob_start_lens_cpu, device="cuda"
).cumsum(0)
cum_start_len0 = torch.zeros_like(cum_start_len1)
cum_start_len0[1:] = cum_start_len1[:-1]
all_logits = torch.matmul(states, weight.T)
if self.do_tensor_parallel_all_gather:
all_logits = tensor_model_parallel_all_gather(all_logits)
......@@ -249,35 +242,29 @@ class LogitsProcessor(nn.Module):
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
# Get the logprob of top-k tokens
return_top_logprob = any(
x > 0 for x in logits_metadata.top_logprobs_nums
)
if return_top_logprob:
if logits_metadata.return_top_logprob:
input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
all_logprobs, logits_metadata
)
else:
input_top_logprobs = output_top_logprobs = None
last_logprobs = all_logprobs[last_index - cum_start_len1]
# Compute the logprobs and normalized logprobs for the prefill tokens.
# Note that we pad a zero at the end of each sequence for easy computation.
# Compute the normalized logprobs for the requested tokens.
# Note that we pad a zero at the end for easy batching.
input_token_logprobs = all_logprobs[
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([pruned_input_ids[1:], torch.tensor([0], device="cuda")]),
torch.cat(
[
torch.cat(pruned_input_ids)[1:],
torch.tensor([0], device="cuda"),
]
),
]
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
input_token_logprobs,
cum_start_len0,
cum_start_len1,
logits_metadata,
)
# Remove the last token logprob for the prefill tokens.
input_token_logprobs = input_token_logprobs[:-1]
return LogitsProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,
......
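For reference, a single-sequence sketch of the shifted gather above (batching and the normalization step omitted): position i is scored against the id of token i + 1, a dummy id 0 is appended so the last row has something to index, and that padded entry is dropped at the end.

```python
import torch

vocab_size = 4
pruned_input_ids = torch.tensor([2, 1, 3])  # token ids kept for prompt logprobs
all_logprobs = torch.log_softmax(torch.randn(3, vocab_size), dim=-1)

# Position i scores token i + 1; append a dummy id 0 for the last row.
shifted_ids = torch.cat([pruned_input_ids[1:], torch.tensor([0])])
input_token_logprobs = all_logprobs[torch.arange(3), shifted_ids]

# The final entry scored the dummy id, so it is removed afterwards.
input_token_logprobs = input_token_logprobs[:-1]
assert input_token_logprobs.shape[0] == pruned_input_ids.shape[0] - 1
```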
......@@ -20,7 +20,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
import copy
import uuid
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from sglang.srt.managers.schedule_batch import BaseFinishReason
......@@ -43,6 +43,7 @@ class GenerateReqInput:
# Whether to return logprobs.
return_logprob: Optional[Union[List[bool], bool]] = None
# If return logprobs, the start location in the prompt for returning logprobs.
# By default, this value is "-1", which means it will only return logprobs for output tokens.
logprob_start_len: Optional[Union[List[int], int]] = None
# If return logprobs, the number of top logprobs to return at each position.
top_logprobs_num: Optional[Union[List[int], int]] = None
......
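For context, a hedged example of what a client request looks like with these fields (the prompt text is made up; the keys match the tests further down in this diff): leaving logprob_start_len at -1 returns logprobs only for generated tokens, while a non-negative value additionally returns prompt logprobs from that token index onward.

```python
# Example /generate request body (prompt text invented for illustration).
request_body = {
    "text": "The capital of France is",
    "sampling_params": {"temperature": 0, "max_new_tokens": 8},
    "return_logprob": True,
    "top_logprobs_num": 5,
    # -1 (default): only output-token logprobs.
    # >= 0: also return prompt logprobs starting at this token index.
    "logprob_start_len": -1,
}
```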
......@@ -19,7 +19,7 @@ limitations under the License.
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Union
from typing import List, Optional, Tuple, Union
import torch
......@@ -53,7 +53,7 @@ class BaseFinishReason:
self.is_error = is_error
def to_json(self):
raise NotImplementedError("Subclasses must implement this method")
raise NotImplementedError()
class FINISH_MATCHED_TOKEN(BaseFinishReason):
......@@ -105,7 +105,13 @@ class FINISH_ABORT(BaseFinishReason):
class Req:
"""Store all inforamtion of a request."""
def __init__(self, rid, origin_input_text, origin_input_ids, lora_path=None):
def __init__(
self,
rid: str,
origin_input_text: str,
origin_input_ids: Tuple[int],
lora_path: Optional[str] = None,
):
# Input and output info
self.rid = rid
self.origin_input_text = origin_input_text
......@@ -118,6 +124,10 @@ class Req:
# Memory info
self.req_pool_idx = None
# Check finish
self.tokenizer = None
self.finished_reason = None
# For incremental decoding
# ----- | --------- read_ids -------|
# ----- | surr_ids |
......@@ -136,7 +146,7 @@ class Req:
# this does not include the jump forward tokens.
self.completion_tokens_wo_jump_forward = 0
# For vision input
# For vision inputs
self.pixel_values = None
self.image_sizes = None
self.image_offsets = None
......@@ -144,31 +154,35 @@ class Req:
self.modalities = None
# Prefix info
self.extend_input_len = 0
self.prefix_indices = []
self.extend_input_len = 0
self.last_node = None
# Sampling parameters
self.sampling_params = None
self.stream = False
# Check finish
self.tokenizer = None
self.finished_reason = None
# Logprobs
# Logprobs (arguments)
self.return_logprob = False
self.embedding = None
self.logprob_start_len = 0
self.top_logprobs_num = 0
# Logprobs (return value)
self.normalized_prompt_logprob = None
self.input_token_logprobs = None
self.input_top_logprobs = None
self.output_token_logprobs = []
self.output_top_logprobs = []
# Logprobs (internal values)
# These tokens are prefilled but need to be considered as decode tokens
# and should be updated for the decode logprobs
self.last_update_decode_tokens = 0
# The relative logprob_start_len in an extend batch
self.extend_logprob_start_len = 0
# Embedding
self.embedding = None
# Constrained decoding
self.regex_fsm: RegexGuide = None
......@@ -363,9 +377,13 @@ class ScheduleBatch:
return_logprob: bool = False
top_logprobs_nums: List[int] = None
# Stream
has_stream: bool = False
@classmethod
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
return_logprob = any(req.return_logprob for req in reqs)
has_stream = any(req.stream for req in reqs)
return cls(
reqs=reqs,
......@@ -373,18 +391,15 @@ class ScheduleBatch:
token_to_kv_pool=token_to_kv_pool,
tree_cache=tree_cache,
return_logprob=return_logprob,
has_stream=has_stream,
)
def batch_size(self):
return len(self.reqs) if self.reqs else 0
return len(self.reqs)
def is_empty(self):
return len(self.reqs) == 0
def has_stream(self) -> bool:
# Return whether batch has at least 1 streaming request
return any(r.stream for r in self.reqs)
def alloc_req_slots(self, num_reqs):
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
if req_pool_indices is None:
......@@ -427,8 +442,8 @@ class ScheduleBatch:
for i, req in enumerate(reqs):
req.req_pool_idx = req_pool_indices_cpu[i]
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
ext_len = seq_len - pre_len
seq_lens.append(seq_len)
assert seq_len - pre_len == req.extend_input_len
if pre_len > 0:
self.req_to_token_pool.req_to_token[req.req_pool_idx][
......@@ -436,9 +451,19 @@ class ScheduleBatch:
] = req.prefix_indices
self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
out_cache_loc[pt : pt + ext_len]
out_cache_loc[pt : pt + req.extend_input_len]
)
pt += ext_len
# Compute the relative logprob_start_len in an extend batch
if req.logprob_start_len >= pre_len:
extend_logprob_start_len = min(
req.logprob_start_len - pre_len, req.extend_input_len - 1
)
else:
extend_logprob_start_len = req.extend_input_len - 1
req.extend_logprob_start_len = extend_logprob_start_len
pt += req.extend_input_len
# Set fields
with torch.device("cuda"):
......@@ -451,21 +476,13 @@ class ScheduleBatch:
self.out_cache_loc = out_cache_loc
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
self.extend_lens_cpu = [r.extend_input_len for r in reqs]
self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
def mix_with_running(self, running_batch: "ScheduleBatch"):
self.forward_mode = ForwardMode.MIXED
self.running_bs = running_batch.batch_size()
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
prefix_lens_cpu.extend(
[
len(r.origin_input_ids) + len(r.output_ids) - 1
for r in running_batch.reqs
]
)
running_bs = running_batch.batch_size()
for req in running_batch.reqs:
req.fill_ids = req.origin_input_ids + req.output_ids
......@@ -473,12 +490,22 @@ class ScheduleBatch:
input_ids = torch.cat([self.input_ids, running_batch.input_ids])
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
extend_num_tokens = self.extend_num_tokens + running_bs
self.merge(running_batch)
self.input_ids = input_ids
self.out_cache_loc = out_cache_loc
self.extend_num_tokens = extend_num_tokens
self.prefix_lens_cpu = prefix_lens_cpu
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
self.prefix_lens_cpu.extend(
[
len(r.origin_input_ids) + len(r.output_ids) - 1
for r in running_batch.reqs
]
)
self.extend_lens_cpu.extend([1] * running_bs)
self.extend_logprob_start_lens_cpu.extend([0] * running_bs)
def check_decode_mem(self):
bs = self.batch_size()
......@@ -685,6 +712,7 @@ class ScheduleBatch:
self.out_cache_loc = None
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
self.return_logprob = any(req.return_logprob for req in self.reqs)
self.has_stream = any(req.stream for req in self.reqs)
self.sampling_info.filter(unfinished_indices, new_indices)
......@@ -695,7 +723,6 @@ class ScheduleBatch:
self.sampling_info.merge(other.sampling_info)
self.reqs.extend(other.reqs)
self.req_pool_indices = torch.concat(
[self.req_pool_indices, other.req_pool_indices]
)
......@@ -706,3 +733,4 @@ class ScheduleBatch:
self.out_cache_loc = None
self.top_logprobs_nums.extend(other.top_logprobs_nums)
self.return_logprob = any(req.return_logprob for req in self.reqs)
self.has_stream = any(req.stream for req in self.reqs)
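A plain-Python restatement of the relative start computation added in prepare_for_extend above, with invented numbers: when the requested logprob_start_len falls inside the cached prefix, the extend chunk behaves like a decode continuation and only its last position keeps a logprob.

```python
def relative_logprob_start_len(logprob_start_len: int,
                               prefix_len: int,
                               extend_input_len: int) -> int:
    """Mirror of the extend_logprob_start_len logic above, on plain ints."""
    if logprob_start_len >= prefix_len:
        # Clamp so the last extended position always yields at least one logprob.
        return min(logprob_start_len - prefix_len, extend_input_len - 1)
    # The requested start is already covered by the KV-cache prefix.
    return extend_input_len - 1

assert relative_logprob_start_len(logprob_start_len=6, prefix_len=4, extend_input_len=10) == 2
assert relative_logprob_start_len(logprob_start_len=2, prefix_len=4, extend_input_len=10) == 9
```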
......@@ -197,8 +197,6 @@ class TokenizerManager:
if not_use_index
else obj.logprob_start_len[index]
)
if return_logprob and logprob_start_len == -1:
logprob_start_len = len(input_ids) - 1
top_logprobs_num = (
obj.top_logprobs_num
if not_use_index
......@@ -251,8 +249,6 @@ class TokenizerManager:
# Send to the controller
if self.is_generation:
if return_logprob and logprob_start_len == -1:
logprob_start_len = len(input_ids) - 1
tokenized_obj = TokenizedGenerateReqInput(
rid,
input_text,
......@@ -349,8 +345,6 @@ class TokenizerManager:
sampling_params = self._get_sampling_params(obj.sampling_params[index])
if self.is_generation:
if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
obj.logprob_start_len[index] = len(input_ids) - 1
pixel_values, image_hashes, image_sizes = (
await self._get_pixel_values(obj.image_data[index])
)
......
......@@ -278,7 +278,7 @@ class ModelTpServer:
self.running_batch = None
break
if self.out_pyobjs and self.running_batch.has_stream():
if self.out_pyobjs and self.running_batch.has_stream:
break
else:
self.check_memory()
......@@ -360,9 +360,13 @@ class ModelTpServer:
# Only when pixel values are not None do we have modalities
req.modalities = recv_req.modalites
req.return_logprob = recv_req.return_logprob
req.logprob_start_len = recv_req.logprob_start_len
req.top_logprobs_num = recv_req.top_logprobs_num
req.stream = recv_req.stream
req.logprob_start_len = recv_req.logprob_start_len
if req.logprob_start_len == -1:
# By default, only return the logprobs for output tokens
req.logprob_start_len = len(recv_req.input_ids) - 1
# Init regex FSM
if (
......@@ -384,7 +388,7 @@ class ModelTpServer:
# Truncate prompts that are too long
if len(req.origin_input_ids) >= self.max_req_input_len:
logger.warn(
logger.warning(
"Request length is longer than the KV cache pool size or "
"the max context length. Truncated!!!"
)
......@@ -583,7 +587,7 @@ class ModelTpServer:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
# Check finish conditions
pt = 0
logprob_pt = 0
for i, req in enumerate(batch.reqs):
if req is not self.current_inflight_req:
# Inflight reqs' prefill is not finished
......@@ -607,10 +611,9 @@ class ModelTpServer:
self.req_to_token_pool.free(req.req_pool_idx)
if req.return_logprob:
self.add_logprob_return_values(
i, req, pt, next_token_ids, logits_output
logprob_pt += self.add_logprob_return_values(
i, req, logprob_pt, next_token_ids, logits_output
)
pt += req.extend_input_len
else:
assert batch.extend_num_tokens != 0
logits_output = self.model_runner.forward(batch)
......@@ -638,48 +641,63 @@ class ModelTpServer:
def add_logprob_return_values(
self,
i,
i: int,
req: Req,
pt: int,
next_token_ids: List[int],
output: LogitsProcessorOutput,
):
"""Attach logprobs to the return values."""
req.output_token_logprobs.append(
(output.next_token_logprobs[i], next_token_ids[i])
)
# If logprob_start_len > 0, the first logprob_start_len prompt tokens are ignored.
num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
if req.normalized_prompt_logprob is None:
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
if req.input_token_logprobs is None:
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
req.input_token_logprobs = list(
zip(
output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
req.fill_ids[-req.extend_input_len + 1 :],
)
)
if req.logprob_start_len == 0:
input_token_logprobs = output.input_token_logprobs[
pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
]
input_token_ids = req.fill_ids[
len(req.fill_ids)
- num_input_logprobs
+ 1 : len(req.fill_ids)
- req.last_update_decode_tokens
]
req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))
if (
req.logprob_start_len == 0
): # The first token does not have logprob, pad it.
req.input_token_logprobs = [
(None, req.fill_ids[0])
] + req.input_token_logprobs
if req.last_update_decode_tokens != 0:
# Some decode tokens are re-computed in an extend batch
req.output_token_logprobs.extend(
list(
zip(
output.input_token_logprobs[
pt
+ req.extend_input_len
+ num_input_logprobs
- 1
- req.last_update_decode_tokens : pt
+ req.extend_input_len
+ num_input_logprobs
- 1
],
req.fill_ids[-req.last_update_decode_tokens + 1 :],
req.fill_ids[
len(req.fill_ids)
- req.last_update_decode_tokens : len(req.fill_ids)
],
)
)
)
req.output_token_logprobs.append(
(output.next_token_logprobs[i], next_token_ids[i])
)
if req.top_logprobs_num > 0:
if req.input_top_logprobs is None:
req.input_top_logprobs = output.input_top_logprobs[i]
......@@ -688,10 +706,12 @@ class ModelTpServer:
if req.last_update_decode_tokens != 0:
req.output_top_logprobs.extend(
output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
output.input_top_logprobs[i][-req.last_update_decode_tokens :]
)
req.output_top_logprobs.append(output.output_top_logprobs[i])
return num_input_logprobs
def forward_decode_batch(self, batch: ScheduleBatch):
# Check if decode out of memory
if not batch.check_decode_mem():
......
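The new return value of add_logprob_return_values lets the caller advance its pointer into the batched prompt logprobs by exactly the positions this request consumed; a simplified sketch of that pointer bookkeeping (function name and structure invented):

```python
# Hypothetical, simplified mirror of how logprob_pt advances in the loop above.
def prompt_logprob_slices(extend_input_lens, extend_logprob_start_lens):
    logprob_pt, slices = 0, []
    for extend_len, start_len in zip(extend_input_lens, extend_logprob_start_lens):
        num_input_logprobs = extend_len - start_len
        # This request's prompt logprobs occupy [logprob_pt, logprob_pt + num_input_logprobs).
        slices.append((logprob_pt, logprob_pt + num_input_logprobs))
        logprob_pt += num_input_logprobs
    return slices

assert prompt_logprob_slices([5, 8], [0, 6]) == [(0, 5), (5, 7)]
```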
......@@ -193,7 +193,7 @@ class CudaGraphRunner:
attn_backend=self.model_runner.attn_backend,
out_cache_loc=out_cache_loc,
return_logprob=False,
top_logprobs_nums=0,
top_logprobs_nums=[0] * bs,
positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
)
return forward(input_ids, input_metadata.positions, input_metadata)
......
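The CUDA-graph fix above matters because downstream code treats top_logprobs_nums as a per-request list, e.g. `any(x > 0 for x in ...)` in LogitsMetadata; passing a bare 0 would fail to iterate. A trivial illustration:

```python
bs = 4
top_logprobs_nums = [0] * bs  # one entry per request in the captured batch
return_top_logprob = any(x > 0 for x in top_logprobs_nums)
assert return_top_logprob is False  # the capture path never requests top-k logprobs
```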
......@@ -81,7 +81,7 @@ class InputMetadata:
return_logprob: bool = False
top_logprobs_nums: List[int] = None
extend_seq_lens_cpu: List[int] = None
logprob_start_lens_cpu: List[int] = None
extend_logprob_start_lens_cpu: List[int] = None
# For multimodal
pixel_values: List[torch.Tensor] = None
......@@ -138,27 +138,13 @@ class InputMetadata:
self.positions = self.positions.to(torch.int64)
def compute_extend_infos(self, batch: ScheduleBatch):
extend_lens_cpu = [
len(r.fill_ids) - batch.prefix_lens_cpu[i] for i, r in enumerate(batch.reqs)
]
self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
self.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
self.extend_start_loc = torch.zeros_like(self.seq_lens)
self.extend_start_loc = torch.zeros_like(self.extend_seq_lens)
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
self.extend_seq_lens_cpu = extend_lens_cpu
self.logprob_start_lens_cpu = [
(
min(
req.logprob_start_len - batch.prefix_lens_cpu[i],
extend_lens_cpu[i] - 1,
)
if req.logprob_start_len >= batch.prefix_lens_cpu[i]
else extend_lens_cpu[i] - 1 # Fake extend, actually decode
)
for i, req in enumerate(batch.reqs)
]
self.extend_no_prefix = all(x == 0 for x in batch.prefix_lens_cpu)
self.extend_seq_lens_cpu = batch.extend_lens_cpu
self.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
@classmethod
def from_schedule_batch(
......
......@@ -22,7 +22,7 @@ import os
import time
import uuid
from http import HTTPStatus
from typing import Dict, List, Optional
from typing import Dict, List
from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
......@@ -472,7 +472,7 @@ def v1_generate_request(
first_prompt_type = type(all_requests[0].prompt)
for request in all_requests:
assert (
type(request.prompt) == first_prompt_type
type(request.prompt) is first_prompt_type
), "All prompts must be of the same type in file input settings"
if len(all_requests) > 1 and request.n > 1:
raise ValueError(
......@@ -887,7 +887,7 @@ def v1_chat_generate_request(
input_ids.append(prompt_ids)
return_logprobs.append(request.logprobs)
logprob_start_lens.append(-1)
top_logprobs_nums.append(request.top_logprobs)
top_logprobs_nums.append(request.top_logprobs or 0)
sampling_params = {
"temperature": request.temperature,
......
......@@ -86,24 +86,24 @@ class SamplingBatchInfo:
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
device = "cuda"
reqs = batch.reqs
ret = cls(vocab_size=vocab_size)
ret.temperatures = torch.tensor(
[r.sampling_params.temperature for r in reqs],
dtype=torch.float,
device=device,
).view(-1, 1)
ret.top_ps = torch.tensor(
[r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
)
ret.top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
)
ret.min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
)
with torch.device("cuda"):
ret.temperatures = torch.tensor(
[r.sampling_params.temperature for r in reqs],
dtype=torch.float,
).view(-1, 1)
ret.top_ps = torch.tensor(
[r.sampling_params.top_p for r in reqs], dtype=torch.float
)
ret.top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int
)
ret.min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float
)
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
# Each penalizer will do nothing if it evaluates itself as not required by looking at
......@@ -116,7 +116,7 @@ class SamplingBatchInfo:
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
vocab_size=vocab_size,
batch=batch,
device=device,
device="cuda",
Penalizers={
penaltylib.BatchedFrequencyPenalizer,
penaltylib.BatchedMinNewTokensPenalizer,
......
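The sampling refactor above leans on `torch.device(...)` acting as a context manager that sets the default device for factory calls inside the block (PyTorch 2.0+). A small sketch using CPU so it runs anywhere:

```python
import torch

# torch.device as a context manager changes the default device for
# tensor factories created inside the block (PyTorch >= 2.0).
with torch.device("cpu"):
    temperatures = torch.tensor([0.7, 0.0, 1.0], dtype=torch.float).view(-1, 1)
    top_ks = torch.tensor([40, 1, 0], dtype=torch.int)

assert temperatures.device.type == "cpu"
assert temperatures.shape == (3, 1)
```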
......@@ -11,16 +11,18 @@ suites = {
"test_chunked_prefill.py",
"test_embedding_openai_server.py",
"test_eval_accuracy_mini.py",
"test_json_constrained.py",
"test_large_max_new_tokens.py",
"test_openai_server.py",
"test_json_constrained.py",
"test_pytorch_sampling_backend.py",
"test_server_args.py",
"test_skip_tokenizer_init.py",
"test_srt_endpoint.py",
"test_torch_compile.py",
"test_torchao.py",
"test_triton_attn_backend.py",
"test_pytorch_sampling_backend.py",
"test_update_weights.py",
"test_vision_openai_server.py",
"test_server_args.py",
],
"sampling/penaltylib": glob.glob(
"sampling/penaltylib/**/test_*.py", recursive=True
......
......@@ -33,13 +33,13 @@ class TestChunkedPrefill(unittest.TestCase):
base_url=base_url,
model=model,
eval_name="mmlu",
num_examples=32,
num_examples=64,
num_threads=32,
)
try:
metrics = run_eval(args)
assert metrics["score"] >= 0.6
assert metrics["score"] >= 0.65
finally:
kill_child_process(process.pid)
......
......@@ -17,7 +17,6 @@ class TestJSONConstrained(unittest.TestCase):
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.json_schema = json.dumps(
{
"type": "object",
......@@ -28,16 +27,13 @@ class TestJSONConstrained(unittest.TestCase):
"required": ["name", "population"],
}
)
cls.process = popen_launch_server(
cls.model, cls.base_url, timeout=300, api_key=cls.api_key
)
cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
headers = {"Authorization": f"Bearer {self.api_key}"}
response = requests.post(
self.base_url + "/generate",
json={
......@@ -54,7 +50,6 @@ class TestJSONConstrained(unittest.TestCase):
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
},
headers=headers,
)
print(json.dumps(response.json()))
print("=" * 100)
......@@ -69,7 +64,7 @@ class TestJSONConstrained(unittest.TestCase):
self.run_decode()
def test_json_openai(self):
client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")
client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
response = client.chat.completions.create(
model=self.model,
......
......@@ -75,11 +75,11 @@ class TestOpenAIServer(unittest.TestCase):
assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output ids map to the same output token and collide in the map
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
assert ret_num_top_logprobs > 0
assert response.choices[0].logprobs.token_logprobs[0] != None
assert response.choices[0].logprobs.token_logprobs[0]
assert response.id
assert response.created
......@@ -143,7 +143,7 @@ class TestOpenAIServer(unittest.TestCase):
ret_num_top_logprobs = len(
response.choices[0].logprobs.top_logprobs[0]
)
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output ids map to the same output token and collide in the map
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
assert ret_num_top_logprobs > 0
......@@ -479,6 +479,22 @@ class TestOpenAIServer(unittest.TestCase):
assert isinstance(js_obj["name"], str)
assert isinstance(js_obj["population"], int)
def test_penalty(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "Introduce the capital of France."},
],
temperature=0,
max_tokens=32,
frequency_penalty=1.0,
)
text = response.choices[0].message.content
assert isinstance(text, str)
if __name__ == "__main__":
unittest.main()
......
"""
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
"""
import json
import unittest
......@@ -39,7 +43,7 @@ class TestSRTEndpoint(unittest.TestCase):
"text": "The capital of France is",
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": 32,
"max_new_tokens": 16,
"n": n,
},
"stream": stream,
......@@ -56,7 +60,8 @@ class TestSRTEndpoint(unittest.TestCase):
for line in response.iter_lines():
if line.startswith(b"data: ") and line[6:] != b"[DONE]":
response_json.append(json.loads(line[6:]))
print(json.dumps(response_json))
print(json.dumps(response_json, indent=2))
print("=" * 100)
def test_simple_decode(self):
......@@ -69,13 +74,50 @@ class TestSRTEndpoint(unittest.TestCase):
self.run_decode(n=3, stream=True)
def test_logprob(self):
for top_logprobs_num in [0, 3]:
for return_text in [True, False]:
self.run_decode(
return_logprob=True,
top_logprobs_num=top_logprobs_num,
return_text=return_text,
)
self.run_decode(
return_logprob=True,
top_logprobs_num=5,
return_text=True,
)
def test_logprob_start_len(self):
logprob_start_len = 4
new_tokens = 4
prompts = [
"I have a very good idea on",
"Today is a sunndy day and",
]
response = requests.post(
self.base_url + "/generate",
json={
"text": prompts,
"sampling_params": {
"temperature": 0,
"max_new_tokens": new_tokens,
},
"return_logprob": True,
"top_logprobs_num": 5,
"return_text_in_logprobs": True,
"logprob_start_len": logprob_start_len,
},
)
response_json = response.json()
print(json.dumps(response_json, indent=2))
for i, res in enumerate(response_json):
assert res["meta_info"]["prompt_tokens"] == logprob_start_len + 1 + len(
res["meta_info"]["input_token_logprobs"]
)
assert prompts[i].endswith(
"".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
)
assert res["meta_info"]["completion_tokens"] == new_tokens
assert len(res["meta_info"]["output_token_logprobs"]) == new_tokens
res["text"] == "".join(
[x[-1] for x in res["meta_info"]["output_token_logprobs"]]
)
if __name__ == "__main__":
......