Unverified Commit 9ba1f097 authored by Lianmin Zheng, committed by GitHub

[Fix] Fix logprob and normalized_logprob (#1428)

parent 282681b8
@@ -54,7 +54,7 @@ jobs:
      timeout-minutes: 20
      run: |
        cd test/srt
-       python3 run_suite.py --suite minimal --range-begin 0 --range-end 8
+       python3 run_suite.py --suite minimal --range-begin 0 --range-end 7

  unit-test-backend-part-2:
    if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -73,7 +73,26 @@ jobs:
      timeout-minutes: 20
      run: |
        cd test/srt
-       python3 run_suite.py --suite minimal --range-begin 8
+       python3 run_suite.py --suite minimal --range-begin 7 --range-end 14
+
+  unit-test-backend-part-3:
+    if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
+    runs-on: 1-gpu-runner
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+      - name: Install dependencies
+        run: |
+          pip install --upgrade pip
+          pip install -e "python[dev]"
+          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
+      - name: Run test
+        timeout-minutes: 20
+        run: |
+          cd test/srt
+          python3 run_suite.py --suite minimal --range-begin 14

  performance-test-1-gpu-part-1:
    if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -217,7 +236,7 @@ jobs:
  finish:
    needs: [
-     unit-test-frontend, unit-test-backend-part-1, unit-test-backend-part-2,
+     unit-test-frontend, unit-test-backend-part-1, unit-test-backend-part-2, unit-test-backend-part-3,
      performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
      accuracy-test-1-gpu, accuracy-test-2-gpu
    ]
...
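The backend unit tests are now sharded into three jobs with `--range-begin`/`--range-end`. A hedged sketch of how such range-based sharding can be implemented (hypothetical; the real `run_suite.py` may slice its list differently):

```python
import argparse

# Hypothetical sketch of range-based test sharding, not the actual run_suite.py.
suite = ["test_a.py", "test_b.py", "test_c.py"]  # placeholder file list

parser = argparse.ArgumentParser()
parser.add_argument("--range-begin", type=int, default=0)
parser.add_argument("--range-end", type=int, default=None)
args = parser.parse_args()

# Each CI job runs a contiguous slice of the suite, e.g. [0, 7), [7, 14), [14, end).
selected = suite[args.range_begin : args.range_end]
print(selected)
```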
@@ -91,7 +91,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
# Node 1
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 1
```
### Supported Models
**Generative Models**
...
@@ -164,6 +164,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
        req.prefix_indices = []
        req.sampling_params = sampling_params
        req.fill_ids = req.origin_input_ids
+       req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        reqs.append(req)
    return input_ids, reqs
@@ -178,6 +179,7 @@ def prepare_extend_inputs_for_correctness_test(
        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
            i, : bench_args.cut_len
        ]
+       req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
    return reqs
@@ -194,6 +196,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
        req.prefix_indices = []
        req.sampling_params = sampling_params
        req.fill_ids = req.origin_input_ids
+       req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        reqs.append(req)
    return reqs
...
@@ -239,9 +239,12 @@ class RuntimeEndpoint(BaseBackend):
        # Compute logprob
        data = {
            "text": [s.text_ + c for c in choices],
-           "sampling_params": {"max_new_tokens": 0},
+           "sampling_params": {
+               "max_new_tokens": 0,
+               "temperature": 0,
+           },
            "return_logprob": True,
-           "logprob_start_len": max(prompt_len - 2, 0),
+           "logprob_start_len": max(prompt_len - 2, 0),  # for token healing
        }
        obj = self._generate_http_request(s, data)
...
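For context, a hedged sketch of how a client could rank candidate choices using the logprobs that such a `/generate` call returns. The `normalized_prompt_logprob` field and the exact response layout are assumptions here; `prompt_token_len` stands in for the tokenized prefix length used above:

```python
import requests

# Hypothetical client-side ranking sketch; response field names are assumptions.
def rank_choices(base_url, prefix, choices, prompt_token_len):
    data = {
        "text": [prefix + c for c in choices],
        "sampling_params": {"max_new_tokens": 0, "temperature": 0},
        "return_logprob": True,
        # Mirrors the token-healing offset in the hunk above.
        "logprob_start_len": max(prompt_token_len - 2, 0),
    }
    res = requests.post(base_url + "/generate", json=data).json()
    # Pick the choice whose full prompt is most likely under the model.
    scores = [r["meta_info"]["normalized_prompt_logprob"] for r in res]
    return max(range(len(choices)), key=lambda i: scores[i])
```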
@@ -9,7 +9,7 @@ import uuid
import warnings
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional

import tqdm
...
@@ -56,6 +56,7 @@ class AttentionBackend(ABC):
        raise NotImplementedError()

    def get_cuda_graph_seq_len_fill_value(self):
+       """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
        raise NotImplementedError()

    def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
@@ -66,9 +67,11 @@ class AttentionBackend(ABC):
            return self.forward_extend(q, k, v, layer, input_metadata)

    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+       """Run a forward for decode."""
        raise NotImplementedError()

    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+       """Run a forward for extend."""
        raise NotImplementedError()
@@ -299,6 +302,7 @@ class FlashInferAttnBackend(AttentionBackend):
            )
        if total_num_tokens >= global_config.layer_sync_threshold:
+           # TODO: Revisit this. Why is this synchronize needed?
            torch.cuda.synchronize()

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
...
@@ -37,7 +37,7 @@ class LogitsProcessorOutput:
    # The normalized logprobs of prompts. shape: [#seq]
    normalized_prompt_logprobs: torch.Tensor
    # The logprobs of input tokens. shape: [#token, vocab_size]
    input_token_logprobs: torch.Tensor
    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
@@ -49,25 +49,39 @@ class LogitsProcessorOutput:
@dataclasses.dataclass
class LogitsMetadata:
    forward_mode: ForwardMode
-   top_logprobs_nums: Optional[List[int]]
    return_logprob: bool = False
+   return_top_logprob: bool = False

    extend_seq_lens: Optional[torch.Tensor] = None
-   extend_start_loc: Optional[torch.Tensor] = None
-   extend_seq_lens_cpu: List[int] = None
-   logprob_start_lens_cpu: List[int] = None
+   extend_seq_lens_cpu: Optional[List[int]] = None
+   top_logprobs_nums: Optional[List[int]] = None
+   extend_logprob_start_lens_cpu: Optional[List[int]] = None
+   extend_logprob_pruned_lens_cpu: Optional[List[int]] = None

    @classmethod
    def from_input_metadata(cls, input_metadata: InputMetadata):
+       return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
+       if input_metadata.forward_mode.is_extend():
+           extend_logprob_pruned_lens_cpu = [
+               extend_len - start_len
+               for extend_len, start_len in zip(
+                   input_metadata.extend_seq_lens,
+                   input_metadata.extend_logprob_start_lens_cpu,
+               )
+           ]
+       else:
+           extend_logprob_pruned_lens_cpu = None
        return cls(
            forward_mode=input_metadata.forward_mode,
-           extend_seq_lens=input_metadata.extend_seq_lens,
-           extend_start_loc=input_metadata.extend_start_loc,
-           return_logprob=input_metadata.return_logprob,
            top_logprobs_nums=input_metadata.top_logprobs_nums,
+           return_logprob=input_metadata.return_logprob,
+           return_top_logprob=return_top_logprob,
+           extend_seq_lens=input_metadata.extend_seq_lens,
            extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
-           logprob_start_lens_cpu=input_metadata.logprob_start_lens_cpu,
+           extend_logprob_start_lens_cpu=input_metadata.extend_logprob_start_lens_cpu,
+           extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
        )
@@ -82,57 +96,49 @@ class LogitsProcessor(nn.Module):
    def _get_normalized_prompt_logprobs(
        self,
        input_token_logprobs: torch.Tensor,
-       cum_start_len0: torch.Tensor,
-       cum_start_len1: torch.Tensor,
        logits_metadata: LogitsMetadata,
    ):
        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
+       pruned_lens = torch.tensor(
+           logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
+       )

-       start = logits_metadata.extend_start_loc.clone() - cum_start_len0
-       end = start + logits_metadata.extend_seq_lens - 2 - cum_start_len1
-       start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
-       end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
+       start = torch.zeros_like(pruned_lens)
+       start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
+       end = torch.clamp(
+           start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
+       )
        sum_logp = (
            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
        )
-       normalized_prompt_logprobs = sum_logp / (
-           (logits_metadata.extend_seq_lens - 1).clamp(min=1)
-       )
+       normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
        return normalized_prompt_logprobs
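A minimal standalone sketch of the cumulative-sum trick used above, on CPU tensors with made-up values (the real method runs on CUDA and takes the pruned lengths from `LogitsMetadata`):

```python
import torch

# Per-token logprobs for two pruned prompts, concatenated: lengths 3 and 4.
input_token_logprobs = torch.tensor([-1.0, -2.0, -3.0, -0.5, -0.5, -0.5, -0.5])
pruned_lens = torch.tensor([3, 4])

cumsum = torch.cumsum(input_token_logprobs, dim=0)
start = torch.zeros_like(pruned_lens)
start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)                            # [0, 3]
end = torch.clamp(start + pruned_lens - 2, min=0, max=cumsum.shape[0] - 1)   # [1, 5]

# Sum of logprobs over each prompt, excluding the last position of each segment.
sum_logp = cumsum[end] - cumsum[start] + input_token_logprobs[start]
normalized = sum_logp / (pruned_lens - 1).clamp(min=1)
print(normalized)  # per-sequence average prompt logprob: [-1.5, -0.5]
```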
    @staticmethod
    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
+       max_k = max(logits_metadata.top_logprobs_nums)
+       ret = all_logprobs.topk(max_k, dim=1)
+       values = ret.values.tolist()
+       indices = ret.indices.tolist()
        if logits_metadata.forward_mode.is_decode():
            output_top_logprobs = []
-           max_k = max(logits_metadata.top_logprobs_nums)
-           ret = all_logprobs.topk(max_k, dim=1)
-           values = ret.values.tolist()
-           indices = ret.indices.tolist()
            for i, k in enumerate(logits_metadata.top_logprobs_nums):
                output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
            return None, output_top_logprobs
        else:
+           # TODO: vectorize the code below
            input_top_logprobs, output_top_logprobs = [], []
-           pt = 0
-           extend_seq_lens_cpu = logits_metadata.extend_seq_lens_cpu
-           max_k = max(logits_metadata.top_logprobs_nums)
-           ret = all_logprobs.topk(max_k, dim=1)
-           values = ret.values.tolist()
-           indices = ret.indices.tolist()
-           for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
-               start_len = logits_metadata.logprob_start_lens_cpu[i]
-               pruned_len = extend_seq_len - start_len
-               if extend_seq_len == 0:
+           pt = 0
+           for k, pruned_len in zip(
+               logits_metadata.top_logprobs_nums,
+               logits_metadata.extend_logprob_pruned_lens_cpu,
+           ):
+               if pruned_len <= 0:
                    input_top_logprobs.append([])
                    output_top_logprobs.append([])
                    continue
-               k = logits_metadata.top_logprobs_nums[i]
                input_top_logprobs.append(
                    [
                        list(zip(values[pt + j][:k], indices[pt + j][:k]))
@@ -167,10 +173,7 @@ class LogitsProcessor(nn.Module):
            last_index = None
            last_hidden = hidden_states
        else:
-           last_index = (
-               torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
-               - 1
-           )
+           last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
            last_hidden = hidden_states[last_index]

        last_logits = torch.matmul(last_hidden, weight.T)
@@ -194,21 +197,15 @@ class LogitsProcessor(nn.Module):
                output_top_logprobs=None,
            )
        else:
-           # When logprob is requested, compute the logits for all tokens.
-           if logits_metadata.forward_mode.is_decode():
-               last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
-               # Get the logprob of top-k tokens
-               return_top_logprob = any(
-                   x > 0 for x in logits_metadata.top_logprobs_nums
-               )
-               if return_top_logprob:
+           last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
+
+           if logits_metadata.forward_mode.is_decode():
+               if logits_metadata.return_top_logprob:
                    output_top_logprobs = self.get_top_logprobs(
                        last_logprobs, logits_metadata
                    )[1]
                else:
                    output_top_logprobs = None
                return LogitsProcessorOutput(
                    next_token_logits=last_logits,
                    next_token_logprobs=last_logprobs,
@@ -218,22 +215,18 @@ class LogitsProcessor(nn.Module):
                    output_top_logprobs=output_top_logprobs,
                )
            else:
+               # Slice the requested tokens to compute logprob
                pt, states, pruned_input_ids = 0, [], []
-               for i, extend_len in enumerate(logits_metadata.extend_seq_lens_cpu):
-                   start_len = logits_metadata.logprob_start_lens_cpu[i]
+               for start_len, extend_len in zip(
+                   logits_metadata.extend_logprob_start_lens_cpu,
+                   logits_metadata.extend_seq_lens_cpu,
+               ):
                    states.append(hidden_states[pt + start_len : pt + extend_len])
                    pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
                    pt += extend_len

+               # Compute the logits and logprobs for all required tokens
                states = torch.cat(states, dim=0)
-               pruned_input_ids = torch.cat(pruned_input_ids, dim=0)
-               cum_start_len1 = torch.tensor(
-                   logits_metadata.logprob_start_lens_cpu, device="cuda"
-               ).cumsum(0)
-               cum_start_len0 = torch.zeros_like(cum_start_len1)
-               cum_start_len0[1:] = cum_start_len1[:-1]
                all_logits = torch.matmul(states, weight.T)
                if self.do_tensor_parallel_all_gather:
                    all_logits = tensor_model_parallel_all_gather(all_logits)
@@ -249,35 +242,29 @@ class LogitsProcessor(nn.Module):
                all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

                # Get the logprob of top-k tokens
-               return_top_logprob = any(
-                   x > 0 for x in logits_metadata.top_logprobs_nums
-               )
-               if return_top_logprob:
+               if logits_metadata.return_top_logprob:
                    input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
                        all_logprobs, logits_metadata
                    )
                else:
                    input_top_logprobs = output_top_logprobs = None

-               last_logprobs = all_logprobs[last_index - cum_start_len1]
-               # Compute the logprobs and normalized logprobs for the prefill tokens.
-               # Note that we pad a zero at the end of each sequence for easy computation.
+               # Compute the normalized logprobs for the requested tokens.
+               # Note that we pad a zero at the end for easy batching.
                input_token_logprobs = all_logprobs[
                    torch.arange(all_logprobs.shape[0], device="cuda"),
-                   torch.cat([pruned_input_ids[1:], torch.tensor([0], device="cuda")]),
+                   torch.cat(
+                       [
+                           torch.cat(pruned_input_ids)[1:],
+                           torch.tensor([0], device="cuda"),
+                       ]
+                   ),
                ]
                normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
                    input_token_logprobs,
-                   cum_start_len0,
-                   cum_start_len1,
                    logits_metadata,
                )
+               # Remove the last token logprob for the prefill tokens.
+               input_token_logprobs = input_token_logprobs[:-1]

                return LogitsProcessorOutput(
                    next_token_logits=last_logits,
                    next_token_logprobs=last_logprobs,
...
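A small sketch of the gather pattern used for `input_token_logprobs` above: position i stores the logprob assigned to the *next* pruned token, with a dummy index padded at the end of the batch (assumed shapes, CPU tensors):

```python
import torch

vocab_size = 5
all_logprobs = torch.log_softmax(torch.randn(4, vocab_size), dim=-1)  # 4 pruned positions
pruned_input_ids = [torch.tensor([2, 3]), torch.tensor([1, 4])]       # two requests

targets = torch.cat([torch.cat(pruned_input_ids)[1:], torch.tensor([0])])
input_token_logprobs = all_logprobs[torch.arange(all_logprobs.shape[0]), targets]
# input_token_logprobs[i] = log P(token at position i+1 | prefix up to i);
# the final entry (paired with the padded 0) is never consumed downstream.
```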
@@ -20,7 +20,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
import copy
import uuid
-from dataclasses import dataclass, field
+from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
@@ -43,6 +43,7 @@ class GenerateReqInput:
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
+   # By default, this value is "-1", which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
...
@@ -19,7 +19,7 @@ limitations under the License.
import logging
from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import List, Optional, Tuple, Union

import torch
@@ -53,7 +53,7 @@ class BaseFinishReason:
        self.is_error = is_error

    def to_json(self):
-       raise NotImplementedError("Subclasses must implement this method")
+       raise NotImplementedError()

class FINISH_MATCHED_TOKEN(BaseFinishReason):
@@ -105,7 +105,13 @@ class FINISH_ABORT(BaseFinishReason):
class Req:
    """Store all information of a request."""

-   def __init__(self, rid, origin_input_text, origin_input_ids, lora_path=None):
+   def __init__(
+       self,
+       rid: str,
+       origin_input_text: str,
+       origin_input_ids: Tuple[int],
+       lora_path: Optional[str] = None,
+   ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
@@ -118,6 +124,10 @@ class Req:
        # Memory info
        self.req_pool_idx = None

+       # Check finish
+       self.tokenizer = None
+       self.finished_reason = None

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- | surr_ids |
@@ -136,7 +146,7 @@ class Req:
        # this does not include the jump forward tokens.
        self.completion_tokens_wo_jump_forward = 0

-       # For vision input
+       # For vision inputs
        self.pixel_values = None
        self.image_sizes = None
        self.image_offsets = None
@@ -144,31 +154,35 @@ class Req:
        self.modalities = None

        # Prefix info
-       self.extend_input_len = 0
        self.prefix_indices = []
+       self.extend_input_len = 0
        self.last_node = None

        # Sampling parameters
        self.sampling_params = None
        self.stream = False

-       # Check finish
-       self.tokenizer = None
-       self.finished_reason = None

-       # Logprobs
+       # Logprobs (arguments)
        self.return_logprob = False
-       self.embedding = None
        self.logprob_start_len = 0
        self.top_logprobs_num = 0

+       # Logprobs (return value)
        self.normalized_prompt_logprob = None
        self.input_token_logprobs = None
        self.input_top_logprobs = None
        self.output_token_logprobs = []
        self.output_top_logprobs = []

+       # Logprobs (internal values)
        # The tokens are prefilled but need to be considered as decode tokens
        # and should be updated for the decode logprobs
        self.last_update_decode_tokens = 0
+       # The relative logprob_start_len in an extend batch
+       self.extend_logprob_start_len = 0

+       # Embedding
+       self.embedding = None

        # Constrained decoding
        self.regex_fsm: RegexGuide = None
@@ -363,9 +377,13 @@ class ScheduleBatch:
    return_logprob: bool = False
    top_logprobs_nums: List[int] = None

+   # Stream
+   has_stream: bool = False

    @classmethod
    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
        return_logprob = any(req.return_logprob for req in reqs)
+       has_stream = any(req.stream for req in reqs)

        return cls(
            reqs=reqs,
@@ -373,18 +391,15 @@ class ScheduleBatch:
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            return_logprob=return_logprob,
+           has_stream=has_stream,
        )

    def batch_size(self):
-       return len(self.reqs) if self.reqs else 0
+       return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

-   def has_stream(self) -> bool:
-       # Return whether batch has at least 1 streaming request
-       return any(r.stream for r in self.reqs)

    def alloc_req_slots(self, num_reqs):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
@@ -427,8 +442,8 @@ class ScheduleBatch:
        for i, req in enumerate(reqs):
            req.req_pool_idx = req_pool_indices_cpu[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
-           ext_len = seq_len - pre_len
            seq_lens.append(seq_len)
+           assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.req_to_token[req.req_pool_idx][
@@ -436,9 +451,19 @@ class ScheduleBatch:
                ] = req.prefix_indices

            self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
-               out_cache_loc[pt : pt + ext_len]
+               out_cache_loc[pt : pt + req.extend_input_len]
            )
-           pt += ext_len

+           # Compute the relative logprob_start_len in an extend batch
+           if req.logprob_start_len >= pre_len:
+               extend_logprob_start_len = min(
+                   req.logprob_start_len - pre_len, req.extend_input_len - 1
+               )
+           else:
+               extend_logprob_start_len = req.extend_input_len - 1
+           req.extend_logprob_start_len = extend_logprob_start_len

+           pt += req.extend_input_len

        # Set fields
        with torch.device("cuda"):
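A hedged sketch of the relative start-length computation added above, written as a pure function so the mapping is easier to see (names mirror the fields in the hunk):

```python
def compute_extend_logprob_start_len(
    logprob_start_len: int, prefix_len: int, extend_input_len: int
) -> int:
    """Map an absolute logprob_start_len in the full prompt to an offset
    inside the current extend chunk (sketch of the logic above)."""
    if logprob_start_len >= prefix_len:
        # Start falls inside this chunk; cap it so at least one position remains.
        return min(logprob_start_len - prefix_len, extend_input_len - 1)
    # The requested start is already covered by the cached prefix.
    return extend_input_len - 1

# Example: 6 tokens cached, extending 4 more, logprob_start_len=7 -> offset 1 in the chunk.
assert compute_extend_logprob_start_len(7, 6, 4) == 1
```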
@@ -451,21 +476,13 @@ class ScheduleBatch:
        self.out_cache_loc = out_cache_loc
        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
+       self.extend_lens_cpu = [r.extend_input_len for r in reqs]
+       self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]

        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
-       self.running_bs = running_batch.batch_size()
+       running_bs = running_batch.batch_size()

-       # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
-       prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
-       prefix_lens_cpu.extend(
-           [
-               len(r.origin_input_ids) + len(r.output_ids) - 1
-               for r in running_batch.reqs
-           ]
-       )

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
@@ -473,12 +490,22 @@ class ScheduleBatch:
        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
-       extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
+       extend_num_tokens = self.extend_num_tokens + running_bs

        self.merge(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc
        self.extend_num_tokens = extend_num_tokens
-       self.prefix_lens_cpu = prefix_lens_cpu

+       # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
+       self.prefix_lens_cpu.extend(
+           [
+               len(r.origin_input_ids) + len(r.output_ids) - 1
+               for r in running_batch.reqs
+           ]
+       )
+       self.extend_lens_cpu.extend([1] * running_bs)
+       self.extend_logprob_start_lens_cpu.extend([0] * running_bs)

    def check_decode_mem(self):
        bs = self.batch_size()
@@ -685,6 +712,7 @@ class ScheduleBatch:
        self.out_cache_loc = None
        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
+       self.has_stream = any(req.stream for req in self.reqs)

        self.sampling_info.filter(unfinished_indices, new_indices)
@@ -695,7 +723,6 @@ class ScheduleBatch:
        self.sampling_info.merge(other.sampling_info)

        self.reqs.extend(other.reqs)
        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
@@ -706,3 +733,4 @@ class ScheduleBatch:
        self.out_cache_loc = None
        self.top_logprobs_nums.extend(other.top_logprobs_nums)
        self.return_logprob = any(req.return_logprob for req in self.reqs)
+       self.has_stream = any(req.stream for req in self.reqs)
@@ -197,8 +197,6 @@ class TokenizerManager:
                if not_use_index
                else obj.logprob_start_len[index]
            )
-           if return_logprob and logprob_start_len == -1:
-               logprob_start_len = len(input_ids) - 1
            top_logprobs_num = (
                obj.top_logprobs_num
                if not_use_index
@@ -251,8 +249,6 @@ class TokenizerManager:
        # Send to the controller
        if self.is_generation:
-           if return_logprob and logprob_start_len == -1:
-               logprob_start_len = len(input_ids) - 1
            tokenized_obj = TokenizedGenerateReqInput(
                rid,
                input_text,
@@ -349,8 +345,6 @@ class TokenizerManager:
            sampling_params = self._get_sampling_params(obj.sampling_params[index])

            if self.is_generation:
-               if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
-                   obj.logprob_start_len[index] = len(input_ids) - 1
                pixel_values, image_hashes, image_sizes = (
                    await self._get_pixel_values(obj.image_data[index])
                )
...
@@ -278,7 +278,7 @@ class ModelTpServer:
                    self.running_batch = None
                    break

-               if self.out_pyobjs and self.running_batch.has_stream():
+               if self.out_pyobjs and self.running_batch.has_stream:
                    break
            else:
                self.check_memory()
@@ -360,9 +360,13 @@ class ModelTpServer:
            # Only when pixel values is not None we have modalities
            req.modalities = recv_req.modalites
        req.return_logprob = recv_req.return_logprob
-       req.logprob_start_len = recv_req.logprob_start_len
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream

+       req.logprob_start_len = recv_req.logprob_start_len
+       if req.logprob_start_len == -1:
+           # By default, only return the logprobs for output tokens
+           req.logprob_start_len = len(recv_req.input_ids) - 1

        # Init regex FSM
        if (
@@ -384,7 +388,7 @@ class ModelTpServer:
        # Truncate prompts that are too long
        if len(req.origin_input_ids) >= self.max_req_input_len:
-           logger.warn(
+           logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
@@ -583,7 +587,7 @@ class ModelTpServer:
                next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

            # Check finish conditions
-           pt = 0
+           logprob_pt = 0
            for i, req in enumerate(batch.reqs):
                if req is not self.current_inflight_req:
                    # Inflight reqs' prefill is not finished
@@ -607,10 +611,9 @@ class ModelTpServer:
                    self.req_to_token_pool.free(req.req_pool_idx)

                if req.return_logprob:
-                   self.add_logprob_return_values(
-                       i, req, pt, next_token_ids, logits_output
+                   logprob_pt += self.add_logprob_return_values(
+                       i, req, logprob_pt, next_token_ids, logits_output
                    )
-                   pt += req.extend_input_len
        else:
            assert batch.extend_num_tokens != 0
            logits_output = self.model_runner.forward(batch)
@@ -638,48 +641,63 @@ class ModelTpServer:
    def add_logprob_return_values(
        self,
-       i,
+       i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
+       """Attach logprobs to the return values."""
+       req.output_token_logprobs.append(
+           (output.next_token_logprobs[i], next_token_ids[i])
+       )

+       # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+       num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

        if req.input_token_logprobs is None:
-           # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-           req.input_token_logprobs = list(
-               zip(
-                   output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
-                   req.fill_ids[-req.extend_input_len + 1 :],
-               )
-           )
-           if req.logprob_start_len == 0:
+           input_token_logprobs = output.input_token_logprobs[
+               pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
+           ]
+           input_token_ids = req.fill_ids[
+               len(req.fill_ids)
+               - num_input_logprobs
+               + 1 : len(req.fill_ids)
+               - req.last_update_decode_tokens
+           ]
+           req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))
+
+           if (
+               req.logprob_start_len == 0
+           ):  # The first token does not have logprob, pad it.
                req.input_token_logprobs = [
                    (None, req.fill_ids[0])
                ] + req.input_token_logprobs

        if req.last_update_decode_tokens != 0:
+           # Some decode tokens are re-computed in an extend batch
            req.output_token_logprobs.extend(
                list(
                    zip(
                        output.input_token_logprobs[
                            pt
-                           + req.extend_input_len
-                           - 1
+                           + num_input_logprobs
                            - req.last_update_decode_tokens : pt
-                           + req.extend_input_len
+                           + num_input_logprobs
                            - 1
                        ],
-                       req.fill_ids[-req.last_update_decode_tokens + 1 :],
+                       req.fill_ids[
+                           len(req.fill_ids)
+                           - req.last_update_decode_tokens : len(req.fill_ids)
+                       ],
                    )
                )
            )

-       req.output_token_logprobs.append(
-           (output.next_token_logprobs[i], next_token_ids[i])
-       )

        if req.top_logprobs_num > 0:
            if req.input_top_logprobs is None:
                req.input_top_logprobs = output.input_top_logprobs[i]
@@ -688,10 +706,12 @@ class ModelTpServer:
            if req.last_update_decode_tokens != 0:
                req.output_top_logprobs.extend(
-                   output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                   output.input_top_logprobs[i][-req.last_update_decode_tokens :]
                )
            req.output_top_logprobs.append(output.output_top_logprobs[i])

+       return num_input_logprobs

    def forward_decode_batch(self, batch: ScheduleBatch):
        # Check if decode out of memory
        if not batch.check_decode_mem():
...
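A simplified sketch of how the returned `num_input_logprobs` drives the flat-pointer bookkeeping in the caller, where `logprob_pt` advances by the returned value per request (assumed stand-in data; the real code slices `output.input_token_logprobs`):

```python
# Hypothetical flat buffer of per-token logprobs for a prefill batch of 3 requests.
flat_logprobs = list(range(12))          # stand-in values
num_input_logprobs_per_req = [3, 5, 4]   # extend_input_len - extend_logprob_start_len per request

logprob_pt = 0
per_request = []
for n in num_input_logprobs_per_req:
    # The last entry of each segment crosses a request boundary and is skipped,
    # mirroring the `pt : pt + num_input_logprobs - 1` slice above.
    per_request.append(flat_logprobs[logprob_pt : logprob_pt + n - 1])
    logprob_pt += n                      # same role as `logprob_pt += add_logprob_return_values(...)`

assert per_request == [[0, 1], [3, 4, 5, 6], [8, 9, 10]]
```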
@@ -193,7 +193,7 @@ class CudaGraphRunner:
            attn_backend=self.model_runner.attn_backend,
            out_cache_loc=out_cache_loc,
            return_logprob=False,
-           top_logprobs_nums=0,
+           top_logprobs_nums=[0] * bs,
            positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
        )

        return forward(input_ids, input_metadata.positions, input_metadata)
...
@@ -81,7 +81,7 @@ class InputMetadata:
    return_logprob: bool = False
    top_logprobs_nums: List[int] = None
    extend_seq_lens_cpu: List[int] = None
-   logprob_start_lens_cpu: List[int] = None
+   extend_logprob_start_lens_cpu: List[int] = None

    # For multimodal
    pixel_values: List[torch.Tensor] = None
@@ -138,27 +138,13 @@ class InputMetadata:
        self.positions = self.positions.to(torch.int64)

    def compute_extend_infos(self, batch: ScheduleBatch):
-       extend_lens_cpu = [
-           len(r.fill_ids) - batch.prefix_lens_cpu[i] for i, r in enumerate(batch.reqs)
-       ]
-       self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
+       self.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
        self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
-       self.extend_start_loc = torch.zeros_like(self.seq_lens)
+       self.extend_start_loc = torch.zeros_like(self.extend_seq_lens)
        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-       self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
-       self.extend_seq_lens_cpu = extend_lens_cpu
-       self.logprob_start_lens_cpu = [
-           (
-               min(
-                   req.logprob_start_len - batch.prefix_lens_cpu[i],
-                   extend_lens_cpu[i] - 1,
-               )
-               if req.logprob_start_len >= batch.prefix_lens_cpu[i]
-               else extend_lens_cpu[i] - 1  # Fake extend, actually decode
-           )
-           for i, req in enumerate(batch.reqs)
-       ]
+       self.extend_no_prefix = all(x == 0 for x in batch.prefix_lens_cpu)
+       self.extend_seq_lens_cpu = batch.extend_lens_cpu
+       self.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu

    @classmethod
    def from_schedule_batch(
...
@@ -22,7 +22,7 @@ import os
import time
import uuid
from http import HTTPStatus
-from typing import Dict, List, Optional
+from typing import Dict, List

from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
@@ -472,7 +472,7 @@ def v1_generate_request(
    first_prompt_type = type(all_requests[0].prompt)
    for request in all_requests:
        assert (
-           type(request.prompt) == first_prompt_type
+           type(request.prompt) is first_prompt_type
        ), "All prompts must be of the same type in file input settings"
        if len(all_requests) > 1 and request.n > 1:
            raise ValueError(
@@ -887,7 +887,7 @@ def v1_chat_generate_request(
        input_ids.append(prompt_ids)
        return_logprobs.append(request.logprobs)
        logprob_start_lens.append(-1)
-       top_logprobs_nums.append(request.top_logprobs)
+       top_logprobs_nums.append(request.top_logprobs or 0)

        sampling_params = {
            "temperature": request.temperature,
...
@@ -86,24 +86,24 @@ class SamplingBatchInfo:
    @classmethod
    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
-       device = "cuda"
        reqs = batch.reqs
        ret = cls(vocab_size=vocab_size)

-       ret.temperatures = torch.tensor(
-           [r.sampling_params.temperature for r in reqs],
-           dtype=torch.float,
-           device=device,
-       ).view(-1, 1)
-       ret.top_ps = torch.tensor(
-           [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-       )
-       ret.top_ks = torch.tensor(
-           [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-       )
-       ret.min_ps = torch.tensor(
-           [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
-       )
+       with torch.device("cuda"):
+           ret.temperatures = torch.tensor(
+               [r.sampling_params.temperature for r in reqs],
+               dtype=torch.float,
+           ).view(-1, 1)
+           ret.top_ps = torch.tensor(
+               [r.sampling_params.top_p for r in reqs], dtype=torch.float
+           )
+           ret.top_ks = torch.tensor(
+               [r.sampling_params.top_k for r in reqs], dtype=torch.int
+           )
+           ret.min_ps = torch.tensor(
+               [r.sampling_params.min_p for r in reqs], dtype=torch.float
+           )

        ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)

        # Each penalizer will do nothing if it evaluates itself as not required by looking at
@@ -116,7 +116,7 @@ class SamplingBatchInfo:
        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
            vocab_size=vocab_size,
            batch=batch,
-           device=device,
+           device="cuda",
            Penalizers={
                penaltylib.BatchedFrequencyPenalizer,
                penaltylib.BatchedMinNewTokensPenalizer,
...
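The refactor above relies on `torch.device` acting as a context manager that sets the default device for tensor construction (available since PyTorch 2.0), so the per-call `device=` arguments can be dropped. A minimal sketch:

```python
import torch

# Inside the context, factory calls like torch.tensor(...) allocate on the chosen device.
device = "cuda" if torch.cuda.is_available() else "cpu"
with torch.device(device):
    temperatures = torch.tensor([0.7, 1.0], dtype=torch.float).view(-1, 1)
    top_ks = torch.tensor([1, 40], dtype=torch.int)

assert temperatures.device.type == device
```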
@@ -11,16 +11,18 @@ suites = {
        "test_chunked_prefill.py",
        "test_embedding_openai_server.py",
        "test_eval_accuracy_mini.py",
+       "test_json_constrained.py",
        "test_large_max_new_tokens.py",
        "test_openai_server.py",
-       "test_json_constrained.py",
+       "test_pytorch_sampling_backend.py",
+       "test_server_args.py",
        "test_skip_tokenizer_init.py",
+       "test_srt_endpoint.py",
        "test_torch_compile.py",
+       "test_torchao.py",
        "test_triton_attn_backend.py",
-       "test_pytorch_sampling_backend.py",
        "test_update_weights.py",
        "test_vision_openai_server.py",
-       "test_server_args.py",
    ],
    "sampling/penaltylib": glob.glob(
        "sampling/penaltylib/**/test_*.py", recursive=True
...
@@ -33,13 +33,13 @@ class TestChunkedPrefill(unittest.TestCase):
            base_url=base_url,
            model=model,
            eval_name="mmlu",
-           num_examples=32,
+           num_examples=64,
            num_threads=32,
        )

        try:
            metrics = run_eval(args)
-           assert metrics["score"] >= 0.6
+           assert metrics["score"] >= 0.65
        finally:
            kill_child_process(process.pid)
...
@@ -17,7 +17,6 @@ class TestJSONConstrained(unittest.TestCase):
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
-       cls.api_key = "sk-123456"
        cls.json_schema = json.dumps(
            {
                "type": "object",
@@ -28,16 +27,13 @@ class TestJSONConstrained(unittest.TestCase):
                "required": ["name", "population"],
            }
        )
-       cls.process = popen_launch_server(
-           cls.model, cls.base_url, timeout=300, api_key=cls.api_key
-       )
+       cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
-       headers = {"Authorization": f"Bearer {self.api_key}"}
        response = requests.post(
            self.base_url + "/generate",
            json={
@@ -54,7 +50,6 @@ class TestJSONConstrained(unittest.TestCase):
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
-           headers=headers,
        )
        print(json.dumps(response.json()))
        print("=" * 100)
@@ -69,7 +64,7 @@ class TestJSONConstrained(unittest.TestCase):
        self.run_decode()

    def test_json_openai(self):
-       client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")
+       client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
        response = client.chat.completions.create(
            model=self.model,
...
@@ -75,11 +75,11 @@ class TestOpenAIServer(unittest.TestCase):
            assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
            ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])

-           # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map
+           # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
            # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
            assert ret_num_top_logprobs > 0

-       assert response.choices[0].logprobs.token_logprobs[0] != None
+       assert response.choices[0].logprobs.token_logprobs[0]

        assert response.id
        assert response.created
@@ -143,7 +143,7 @@ class TestOpenAIServer(unittest.TestCase):
                    ret_num_top_logprobs = len(
                        response.choices[0].logprobs.top_logprobs[0]
                    )
-                   # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map
+                   # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
                    # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
                    assert ret_num_top_logprobs > 0
@@ -479,6 +479,22 @@ class TestOpenAIServer(unittest.TestCase):
        assert isinstance(js_obj["name"], str)
        assert isinstance(js_obj["population"], int)

+   def test_penalty(self):
+       client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+       response = client.chat.completions.create(
+           model=self.model,
+           messages=[
+               {"role": "system", "content": "You are a helpful AI assistant"},
+               {"role": "user", "content": "Introduce the capital of France."},
+           ],
+           temperature=0,
+           max_tokens=32,
+           frequency_penalty=1.0,
+       )
+       text = response.choices[0].message.content
+       assert isinstance(text, str)

if __name__ == "__main__":
    unittest.main()
"""
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
"""
import json import json
import unittest import unittest
...@@ -39,7 +43,7 @@ class TestSRTEndpoint(unittest.TestCase): ...@@ -39,7 +43,7 @@ class TestSRTEndpoint(unittest.TestCase):
"text": "The capital of France is", "text": "The capital of France is",
"sampling_params": { "sampling_params": {
"temperature": 0 if n == 1 else 0.5, "temperature": 0 if n == 1 else 0.5,
"max_new_tokens": 32, "max_new_tokens": 16,
"n": n, "n": n,
}, },
"stream": stream, "stream": stream,
...@@ -56,7 +60,8 @@ class TestSRTEndpoint(unittest.TestCase): ...@@ -56,7 +60,8 @@ class TestSRTEndpoint(unittest.TestCase):
for line in response.iter_lines(): for line in response.iter_lines():
if line.startswith(b"data: ") and line[6:] != b"[DONE]": if line.startswith(b"data: ") and line[6:] != b"[DONE]":
response_json.append(json.loads(line[6:])) response_json.append(json.loads(line[6:]))
print(json.dumps(response_json))
print(json.dumps(response_json, indent=2))
print("=" * 100) print("=" * 100)
def test_simple_decode(self): def test_simple_decode(self):
...@@ -69,13 +74,50 @@ class TestSRTEndpoint(unittest.TestCase): ...@@ -69,13 +74,50 @@ class TestSRTEndpoint(unittest.TestCase):
self.run_decode(n=3, stream=True) self.run_decode(n=3, stream=True)
def test_logprob(self): def test_logprob(self):
for top_logprobs_num in [0, 3]: self.run_decode(
for return_text in [True, False]: return_logprob=True,
self.run_decode( top_logprobs_num=5,
return_logprob=True, return_text=True,
top_logprobs_num=top_logprobs_num, )
return_text=return_text,
) def test_logprob_start_len(self):
logprob_start_len = 4
new_tokens = 4
prompts = [
"I have a very good idea on",
"Today is a sunndy day and",
]
response = requests.post(
self.base_url + "/generate",
json={
"text": prompts,
"sampling_params": {
"temperature": 0,
"max_new_tokens": new_tokens,
},
"return_logprob": True,
"top_logprobs_num": 5,
"return_text_in_logprobs": True,
"logprob_start_len": logprob_start_len,
},
)
response_json = response.json()
print(json.dumps(response_json, indent=2))
for i, res in enumerate(response_json):
assert res["meta_info"]["prompt_tokens"] == logprob_start_len + 1 + len(
res["meta_info"]["input_token_logprobs"]
)
assert prompts[i].endswith(
"".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
)
assert res["meta_info"]["completion_tokens"] == new_tokens
assert len(res["meta_info"]["output_token_logprobs"]) == new_tokens
res["text"] == "".join(
[x[-1] for x in res["meta_info"]["output_token_logprobs"]]
)
if __name__ == "__main__": if __name__ == "__main__":
......