"docs/vscode:/vscode.git/clone" did not exist on "cf03592743abbb0b06ba609ebba3847a8ada4a47"
Unverified Commit aee4f523, authored by Lianmin Zheng and committed by GitHub

Fix logit processor bugs (#427)

parent 7023f413
@@ -297,7 +297,6 @@ curl http://localhost:30000/generate \
Learn more about the argument format [here](docs/sampling_params.md).

### OpenAI Compatible API
In addition, the server supports an experimental OpenAI-compatible API.

```python

@@ -386,7 +385,6 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/model_support.md).

## Benchmark And Performance
- Llama-7B on NVIDIA A10G, FP16, Tensor Parallelism=1
![llama_7b](assets/llama_7b.jpg)

@@ -410,7 +408,4 @@ https://github.com/sgl-project/sglang/issues/157
}
```
-[![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-md.svg)](https://huggingface.co/papers/2312.07104)
We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql).
"""Some Public API Definitions""" """Some Public API Definitions"""
import os
import re import re
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
...@@ -31,6 +32,7 @@ def function( ...@@ -31,6 +32,7 @@ def function(
def Runtime(*args, **kwargs): def Runtime(*args, **kwargs):
# Avoid importing unnecessary dependency # Avoid importing unnecessary dependency
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
from sglang.srt.server import Runtime from sglang.srt.server import Runtime
return Runtime(*args, **kwargs) return Runtime(*args, **kwargs)
......
@@ -14,7 +14,7 @@ except ImportError as e:

class Anthropic(BaseBackend):
-    def __init__(self, model_name):
+    def __init__(self, model_name, *args, **kwargs):
        super().__init__()

        if isinstance(anthropic, Exception):
@@ -22,6 +22,7 @@ class Anthropic(BaseBackend):

        self.model_name = model_name
        self.chat_template = get_chat_template("claude")
+        self.client = anthropic.Anthropic(*args, **kwargs)

    def get_chat_template(self):
        return self.chat_template
@@ -41,7 +42,7 @@ class Anthropic(BaseBackend):
        else:
            system = ""

-        ret = anthropic.Anthropic().messages.create(
+        ret = self.client.messages.create(
            model=self.model_name,
            system=system,
            messages=messages,
@@ -66,11 +67,11 @@ class Anthropic(BaseBackend):
        else:
            system = ""

-        with anthropic.Anthropic().messages.stream(
+        with self.client.messages.stream(
            model=self.model_name,
            system=system,
            messages=messages,
            **sampling_params.to_anthropic_kwargs(),
        ) as stream:
            for text in stream.text_stream:
                yield text, {}
\ No newline at end of file
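The net effect of the hunks above: the Anthropic SDK client is now built once in `__init__` (forwarding any extra constructor arguments such as an API key) and reused by both the blocking and the streaming paths, instead of instantiating `anthropic.Anthropic()` on every call. A condensed sketch of the resulting pattern, with the sglang plumbing (chat templates, sampling params) left out and the constructor kwargs and `max_tokens` shown purely as illustrative assumptions:

```python
# Sketch only: client built once, reused for both call paths.
# api_key and max_tokens are illustrative, not taken from the diff.
import anthropic


class AnthropicClientSketch:
    def __init__(self, model_name, *args, **kwargs):
        self.model_name = model_name
        self.client = anthropic.Anthropic(*args, **kwargs)  # e.g. api_key="..."

    def generate(self, system, messages):
        ret = self.client.messages.create(
            model=self.model_name,
            system=system,
            messages=messages,
            max_tokens=256,
        )
        return ret.content[0].text

    def generate_stream(self, system, messages):
        with self.client.messages.stream(
            model=self.model_name,
            system=system,
            messages=messages,
            max_tokens=256,
        ) as stream:
            for text in stream.text_stream:
                yield text, {}
```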
@@ -228,7 +228,7 @@ class OpenAI(BaseBackend):
            prompt_tokens.append(ret_token)

        decision = choices[np.argmax(scores)]
-        return decision, scores, scores
+        return decision, scores, None, None


def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
...
@@ -220,7 +220,6 @@ class RuntimeEndpoint(BaseBackend):
            "sampling_params": {"max_new_tokens": 0},
            "return_logprob": True,
            "logprob_start_len": max(prompt_len - 2, 0),
-            "return_text_in_logprobs": True,
        }
        self._add_images(s, data)
        res = http_request(
...
@@ -42,26 +42,29 @@ class LogitsProcessor(nn.Module):
            for i in range(all_logprobs.shape[0]):
                k = input_metadata.top_logprobs_nums[i]
                t = all_logprobs[i].topk(k)
-                v_cpu = t.values.cpu().tolist()
-                p_cpu = t.indices.cpu().tolist()
+                v_cpu = t.values.tolist()
+                p_cpu = t.indices.tolist()
                decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
            return None, decode_top_logprobs
        else:
            prefill_top_logprobs, decode_top_logprobs = [], []
            pt = 0
            # NOTE: the GPU-CPU overhead can be reduced
-            extend_seq_lens_cpu = input_metadata.extend_seq_lens
-            for i in range(len(input_metadata.extend_seq_lens)):
+            extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
+            for i in range(len(extend_seq_lens_cpu)):
                if extend_seq_lens_cpu[i] == 0:
+                    prefill_top_logprobs.append([])
+                    decode_top_logprobs.append([])
                    continue
                k = input_metadata.top_logprobs_nums[i]
                t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
-                vs_cpu = t.values.cpu().tolist()
-                ps_cpu = t.indices.cpu().tolist()
+                vs_cpu = t.values.tolist()
+                ps_cpu = t.indices.tolist()
                prefill_top_logprobs.append(
                    [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
                )
                decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
+                pt += extend_seq_lens_cpu[i]
            return prefill_top_logprobs, decode_top_logprobs

    def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
@@ -99,20 +102,24 @@ class LogitsProcessor(nn.Module):
            all_logits = all_logits[:, : self.config.vocab_size]

        all_logprobs = all_logits.float()
-        all_logits = None
+        del all_logits
        all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

-        prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
-            all_logprobs, input_metadata
-        )
+        return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
+        if return_top_logprob:
+            prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
+                all_logprobs, input_metadata
+            )
+        else:
+            prefill_top_logprobs = decode_top_logprobs = None

        if input_metadata.forward_mode == ForwardMode.DECODE:
            last_logprobs = all_logprobs
            return last_logits, (
                None,
                None,
+                decode_top_logprobs,
                None,
-                decode_top_logprobs,
                last_logprobs,
            )
        else:
@@ -131,9 +138,9 @@ class LogitsProcessor(nn.Module):
            )

            return last_logits, (
                prefill_token_logprobs,
-                normalized_prompt_logprobs,
                prefill_top_logprobs,
                decode_top_logprobs,
+                normalized_prompt_logprobs,
                last_logprobs,
            )
...
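With these hunks, `LogitsProcessor.forward` returns its extra outputs in the order `(prefill_token_logprobs, prefill_top_logprobs, decode_top_logprobs, normalized_prompt_logprobs, last_logprobs)` in both branches, and the top-logprob lists are computed only when some request asked for them. A standalone sketch of what the decode-path loop produces (shapes and `k` values are arbitrary examples):

```python
# For each request, take the per-request top-k over the vocabulary and keep it
# as a list of (logprob, token_id) pairs, mirroring the decode branch above.
import torch


def decode_top_logprobs_sketch(all_logprobs: torch.Tensor, top_logprobs_nums):
    decode_top_logprobs = []
    for i in range(all_logprobs.shape[0]):
        k = top_logprobs_nums[i]
        t = all_logprobs[i].topk(k)
        decode_top_logprobs.append(list(zip(t.values.tolist(), t.indices.tolist())))
    return decode_top_logprobs


all_logprobs = torch.log_softmax(torch.randn(2, 32000), dim=-1)  # [batch, vocab]
print(decode_top_logprobs_sketch(all_logprobs, [3, 5]))
```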
@@ -25,7 +25,6 @@ class GenerateReqInput:
    return_text_in_logprobs: bool = False
    # Whether to stream output
    stream: bool = False
-    # TODO: make all parameters a Union[List[T], T] to allow for batched requests

    def post_init(self):
        is_single = isinstance(self.text, str)
...
from dataclasses import dataclass
-from enum import Enum, auto
+from enum import IntEnum, auto
from typing import List

import numpy as np

@@ -9,15 +9,15 @@ from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool


-class ForwardMode(Enum):
+class ForwardMode(IntEnum):
    PREFILL = auto()
    EXTEND = auto()
    DECODE = auto()


-class FinishReason(Enum):
-    LENGTH = auto()
+class FinishReason(IntEnum):
    EOS_TOKEN = auto()
+    LENGTH = auto()
    STOP_STR = auto()
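Switching the two enums from `Enum` to `IntEnum` makes their members behave like plain integers, which matters when the values are compared against or serialized as ints (for instance when they cross a process or RPC boundary). A minimal illustration, not taken from the repository:

```python
# IntEnum members compare equal to ints and serialize as ints; plain Enum
# members do not. auto() numbers the members starting from 1.
from enum import Enum, IntEnum, auto


class PlainMode(Enum):
    PREFILL = auto()
    EXTEND = auto()
    DECODE = auto()


class IntMode(IntEnum):
    PREFILL = auto()
    EXTEND = auto()
    DECODE = auto()


assert PlainMode.DECODE != 3   # a plain Enum member is never equal to an int
assert IntMode.DECODE == 3     # an IntEnum member behaves like the int 3
assert int(IntMode.PREFILL) == 1
```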
@@ -31,6 +31,7 @@ class Req:
        # Since jump forward may retokenize the prompt with partial outputs,
        # we maintain the original prompt length to report the correct usage.
        self.prompt_tokens = len(input_ids)

        # The number of decoded tokens for token usage report. Note that
        # this does not include the jump forward tokens.
        self.completion_tokens_wo_jump_forward = 0
@@ -41,12 +42,11 @@ class Req:
        self.image_offset = 0
        self.pad_value = None

+        # Sampling parameters
        self.sampling_params = None
-        self.return_logprob = False
-        self.logprob_start_len = 0
-        self.top_logprobs_num = 0
        self.stream = False

+        # Check finish
        self.tokenizer = None
        self.finished = False
        self.finish_reason = None
@@ -56,13 +56,17 @@ class Req:
        self.prefix_indices = []
        self.last_node = None

+        # Logprobs
+        self.return_logprob = False
+        self.logprob_start_len = 0
+        self.top_logprobs_num = 0
+        self.normalized_prompt_logprob = None
        self.prefill_token_logprobs = None
        self.decode_token_logprobs = None
-        self.normalized_prompt_logprob = None
        self.prefill_top_logprobs = None
        self.decode_top_logprobs = None

-        # For constrained decoding
+        # Constrained decoding
        self.regex_fsm = None
        self.regex_fsm_state = 0
        self.jump_forward_map = None
@@ -165,8 +169,8 @@ class Batch:
    out_cache_cont_end: torch.Tensor = None

    # for processing logprobs
-    top_logprobs_nums: List[int] = None
    return_logprob: bool = False
+    top_logprobs_nums: List[int] = None

    # for multimodal
    pixel_values: List[torch.Tensor] = None
@@ -321,8 +325,8 @@ class Batch:
        )

        retracted_reqs = []
-        seq_lens_np = self.seq_lens.cpu().numpy()
-        req_pool_indices_np = self.req_pool_indices.cpu().numpy()
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
        while self.token_to_kv_pool.available_size() < len(self.reqs):
            idx = sorted_indices.pop()
            req = self.reqs[idx]
@@ -338,8 +342,8 @@ class Batch:
            # TODO: apply more fine-grained retraction
            token_indices = self.req_to_token_pool.req_to_token[
-                req_pool_indices_np[idx]
-            ][: seq_lens_np[idx]]
+                req_pool_indices_cpu[idx]
+            ][: seq_lens_cpu[idx]]
            self.token_to_kv_pool.dec_refs(token_indices)

        self.filter_batch(sorted_indices)
@@ -363,7 +367,7 @@ class Batch:
                # insert the old request into tree_cache
                token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
                if req_pool_indices_cpu is None:
-                    req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
+                    req_pool_indices_cpu = self.req_pool_indices.tolist()
                req_pool_idx = req_pool_indices_cpu[i]
                indices = self.req_to_token_pool.req_to_token[
                    req_pool_idx, : len(token_ids_in_memory)
...
@@ -36,7 +36,9 @@ from sglang.srt.utils import (
    set_random_seed,
)

logger = logging.getLogger("model_rpc")
+vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)

@@ -54,9 +56,6 @@ class ModelRpcServer:
        self.tp_size = server_args.tp_size
        self.schedule_heuristic = server_args.schedule_heuristic
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
-        vllm_default_logger.setLevel(
-            level=getattr(logging, server_args.log_level.upper())
-        )

        # Init model and tokenizer
        self.model_config = ModelConfig(
@@ -65,7 +64,7 @@ class ModelRpcServer:
            context_length=server_args.context_length,
        )

-        # for model end global settings
+        # For model end global settings
        server_args_dict = {
            "enable_flashinfer": server_args.enable_flashinfer,
            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
@@ -164,7 +163,7 @@ class ModelRpcServer:
            logger.info("Cache flushed successfully!")
        else:
            warnings.warn(
-                "Cache not flushed because there are pending requests. "
+                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.forward_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )
@@ -386,12 +385,12 @@ class ModelRpcServer:
                f"#running_req: {running_req}. "
                f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
            )
-            logger.debug(
-                f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
-                f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
-                f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
-                f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            )
+            #logger.debug(
+            #    f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
+            #    f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
+            #    f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
+            #    f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
+            #)

        new_batch = Batch.init_new(
            can_run_list,
@@ -408,47 +407,41 @@ class ModelRpcServer:
            self.model_config.vocab_size, self.int_token_logit_bias
        )

-        prefill_token_logprobs = None
        if batch.extend_num_tokens != 0:
            # Forward
            logits, (
                prefill_token_logprobs,
-                normalized_prompt_logprobs,
                prefill_top_logprobs,
                decode_top_logprobs,
+                normalized_prompt_logprobs,
                last_logprobs,
            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
            if prefill_token_logprobs is not None:
-                prefill_token_logprobs = prefill_token_logprobs.cpu().tolist()
-                normalized_prompt_logprobs = normalized_prompt_logprobs.cpu().tolist()
+                prefill_token_logprobs = prefill_token_logprobs.tolist()
+                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()

            next_token_ids, _ = batch.sample(logits)
-            next_token_ids = next_token_ids.cpu().tolist()
+
+            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
+            if last_logprobs is not None:
+                last_token_logprobs = (
+                    last_logprobs[torch.arange(len(batch.reqs)), next_token_ids].tolist()
+                )
+
+            next_token_ids = next_token_ids.tolist()
        else:
            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            (
-                logits,
-                prefill_token_logprobs,
-                normalized_prompt_logprobs,
-                last_logprobs,
-            ) = (None,) * 4
-
-        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        reqs = batch.reqs
-        last_token_logprobs = None
-        if last_logprobs is not None:
-            last_token_logprobs = (
-                last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
-            )

        # Check finish condition
        pt = 0
-        for i, req in enumerate(reqs):
+        for i, req in enumerate(batch.reqs):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids = [next_token_ids[i]]
            req.check_finished()

-            if prefill_token_logprobs is not None:
+            if req.return_logprob:
+                req.normalized_prompt_logprob = normalized_prompt_logprobs[i]

                # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
                req.prefill_token_logprobs = list(
                    zip(
@@ -463,12 +456,14 @@ class ModelRpcServer:
                req.decode_token_logprobs = [
                    (last_token_logprobs[i], next_token_ids[i])
                ]

+            if req.top_logprobs_num > 0:
                req.prefill_top_logprobs = prefill_top_logprobs[i]
                if req.logprob_start_len == 0:
                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
                req.decode_top_logprobs = [decode_top_logprobs[i]]
-                req.normalized_prompt_logprob = normalized_prompt_logprobs[i]

            pt += req.extend_input_len

        self.handle_finished_requests(batch)
@@ -520,29 +515,29 @@
        logits, (
            _,
            _,
+            decode_top_logprobs,
            _,
-            decode_top_logprobs,
            last_logprobs,
        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
        next_token_ids, _ = batch.sample(logits)
-        next_token_ids = next_token_ids.cpu().tolist()
+        next_token_ids = next_token_ids.tolist()

        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        reqs = batch.reqs
-        new_token_logprobs = None
        if last_logprobs is not None:
            new_token_logprobs = last_logprobs[
-                torch.arange(len(reqs)), next_token_ids
+                torch.arange(len(batch.reqs)), next_token_ids
            ].tolist()

        # Check finish condition
-        for i, (req, next_token_id) in enumerate(zip(reqs, next_token_ids)):
+        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

-            if new_token_logprobs is not None:
+            if req.return_logprob:
                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))

+            if req.top_logprobs_num > 0:
                req.decode_top_logprobs.append(decode_top_logprobs[i])

        self.handle_finished_requests(batch)
@@ -590,8 +585,7 @@ class ModelRpcServer:
                    + len(req.output_ids)
                    - req.prompt_tokens,
                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                    "finish_reason": str(req.finish_reason),
-                    "hit_stop_str": req.hit_stop_str,
+                    "finish_reason": str(req.finish_reason),  # FIXME: convert to the correct string
                }
                if req.return_logprob:
                    (
@@ -628,7 +622,7 @@ class ModelRpcServer:
        # Remove finished reqs
        if finished_indices:
            # Update radix cache
-            req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
+            req_pool_indices_cpu = batch.req_pool_indices.tolist()
            for i in finished_indices:
                req = batch.reqs[i]
                req_pool_idx = req_pool_indices_cpu[i]
...
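The "only batch transfer the selected logprobs" comments in the hunks above come down to a single gather: rather than copying the whole `[batch, vocab]` log-probability matrix to the host, index out one value per request and move only that vector. A toy version with random data:

```python
# Gather the log-probability of the token that was actually sampled for each
# request, then transfer just those batch_size floats to the host.
import torch

batch_size, vocab_size = 4, 32000
last_logprobs = torch.log_softmax(torch.randn(batch_size, vocab_size), dim=-1)
next_token_ids = torch.randint(0, vocab_size, (batch_size,))

last_token_logprobs = last_logprobs[
    torch.arange(batch_size), next_token_ids
].tolist()  # batch_size values instead of batch_size * vocab_size
print(last_token_logprobs)
```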
@@ -29,7 +29,7 @@ QUANTIZATION_CONFIG_MAPPING = {
logger = logging.getLogger("model_runner")

# for server args in model endpoints
-global_server_args_dict: dict = None
+global_server_args_dict = {}


@lru_cache()
@@ -86,8 +86,8 @@ class InputMetadata:
    out_cache_cont_end: torch.Tensor = None
    other_kv_index: torch.Tensor = None

-    top_logprobs_nums: List[int] = None
    return_logprob: bool = False
+    top_logprobs_nums: List[int] = None

    # for flashinfer
    qo_indptr: torch.Tensor = None
@@ -107,18 +107,20 @@ class InputMetadata:
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        self.kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )
+        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
        self.kv_indices = torch.cat(
            [
                self.req_to_token_pool.req_to_token[
-                    self.req_pool_indices[i].item(), : self.seq_lens[i].item()
+                    req_pool_indices_cpu[i], : seq_lens_cpu[i]
                ]
                for i in range(self.batch_size)
            ],
            dim=0,
        ).contiguous()
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )

        workspace_buffer = torch.empty(
            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
@@ -195,15 +197,15 @@ class InputMetadata:
                req_pool_indices[0], seq_lens[0] - 1
            ].item()
        else:
-            seq_lens_np = seq_lens.cpu().numpy()
-            prefix_lens_np = prefix_lens.cpu().numpy()
-            position_ids_offsets_np = position_ids_offsets.cpu().numpy()
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            prefix_lens_cpu = prefix_lens.cpu().numpy()
+            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
            positions = torch.tensor(
                np.concatenate(
                    [
                        np.arange(
-                            prefix_lens_np[i] + position_ids_offsets_np[i],
-                            seq_lens_np[i] + position_ids_offsets_np[i],
+                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
                        )
                        for i in range(batch_size)
                    ],
@@ -229,9 +231,9 @@ class InputMetadata:
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
-            top_logprobs_nums=top_logprobs_nums,
-            return_logprob=return_logprob,
            other_kv_index=other_kv_index,
+            return_logprob=return_logprob,
+            top_logprobs_nums=top_logprobs_nums,
        )

        if forward_mode == ForwardMode.EXTEND:
...
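The renamed `*_cpu` variables reflect the same pattern as elsewhere in this commit: calling `.item()` or indexing a CUDA tensor element by element inside a Python loop forces a device-to-host synchronization per element, while a single `.cpu().numpy()` pays that cost once. A small sketch of the before/after access pattern, with CPU tensors standing in for the CUDA ones:

```python
# Illustrative only: the speed difference appears when the tensors live on the
# GPU, because each .item() then triggers its own device-to-host copy.
import torch

req_pool_indices = torch.arange(128)
seq_lens = torch.randint(1, 2048, (128,))

# Before: one transfer per loop iteration (when the tensors are on the GPU).
rows_before = [
    (req_pool_indices[i].item(), seq_lens[i].item()) for i in range(128)
]

# After: one transfer per tensor, then plain NumPy indexing on the host.
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
seq_lens_cpu = seq_lens.cpu().numpy()
rows_after = [
    (req_pool_indices_cpu[i], seq_lens_cpu[i]) for i in range(128)
]
```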
@@ -185,7 +185,10 @@ class TokenizerManager:
            while True:
                await event.wait()
-                yield state.out_list[-1]
+                yield self.convert_logprob_style(state.out_list[-1],
+                                                 obj.return_logprob,
+                                                 obj.top_logprobs_num,
+                                                 obj.return_text_in_logprobs)
                state.out_list = []
                if state.finished:
                    del self.rid_to_state[rid]
@@ -231,16 +234,16 @@ class TokenizerManager:
                rid = obj.rid[i]
                state = self.rid_to_state[rid]
                await state.event.wait()
-                output_list.append(state.out_list[-1])
+                output_list.append(
+                    self.convert_logprob_style(state.out_list[-1],
+                                               obj.return_logprob[i],
+                                               obj.top_logprobs_num[i],
+                                               obj.return_text_in_logprobs))
                assert state.finished
                del self.rid_to_state[rid]
            yield output_list

-    async def detokenize(self, obj: DetokenizeReqInput):
-        token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
-        return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
-
    async def flush_cache(self):
        flush_cache_req = FlushCacheReq()
        self.send_to_router.send_pyobj(flush_cache_req)
@@ -267,3 +270,37 @@ class TokenizerManager:
                state.event.set()
            else:
                raise ValueError(f"Invalid object: {recv_obj}")

+    def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs):
+        if return_logprob:
+            ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
+                ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
+            )
+            ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
+                ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
+            )
+            if top_logprobs_num > 0:
+                ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                )
+                ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                )
+        return ret
+
+    def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
+        if not decode_to_text:
+            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
+
+        token_ids = [tid for _, tid in token_logprobs]
+        token_texts = self.tokenizer.batch_decode(token_ids)
+        return [
+            (logprob, token_id, token_text)
+            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
+        ]
+
+    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
+        for i, t in enumerate(top_logprobs):
+            if t:
+                top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
+        return top_logprobs
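With these helpers every `(logprob, token_id)` pair in `meta_info` gains a third element: `None` when `return_text_in_logprobs` is off, or the decoded token text when it is on. A sketch of the transformation using a standalone Hugging Face tokenizer (`gpt2` and the token ids are placeholders for the server's own tokenizer and real output):

```python
# Mirrors detokenize_logprob_tokens: decode each token id and attach the text.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in for the server tokenizer

token_logprobs = [(-0.12, 464), (-1.87, 3290)]     # (logprob, token_id) pairs
token_ids = [tid for _, tid in token_logprobs]
token_texts = tokenizer.batch_decode(token_ids)    # one decoded string per id

converted = [
    (logprob, token_id, text)
    for (logprob, token_id), text in zip(token_logprobs, token_texts)
]
print(converted)  # each entry is now (logprob, token_id, token_text)
```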
"""pydantic models for OpenAI API protocol"""
import time import time
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
......
@@ -10,7 +10,7 @@ import threading
import time
from typing import List, Optional, Union

-# Fix a Python bug
+# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

import aiohttp
@@ -53,10 +53,10 @@ from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
-    enable_show_time_cost,
    allocate_init_ports,
-    jsonify_pydantic_model,
    assert_pkg_version,
+    enable_show_time_cost,
+    jsonify_pydantic_model,
    get_exception_traceback,
    API_KEY_HEADER_NAME,
    APIKeyValidatorMiddleware
@@ -99,12 +99,6 @@ async def flush_cache():
    )


-async def stream_generator(obj: GenerateReqInput):
-    async for out in tokenizer_manager.generate_request(obj):
-        await handle_token_logprobs_results(obj, out)
-        yield out
-
-
@app.post("/generate")
async def generate_request(obj: GenerateReqInput):
    obj.post_init()
@@ -112,69 +106,16 @@ async def generate_request(obj: GenerateReqInput):
    if obj.stream:

        async def stream_results():
-            async for out in stream_generator(obj):
+            async for out in tokenizer_manager.generate_request(obj):
                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(stream_results(), media_type="text/event-stream")

    ret = await tokenizer_manager.generate_request(obj).__anext__()
-    await handle_token_logprobs_results(obj, ret)
    return ret


-async def detokenize_logprob_tokens(token_logprobs, decode_to_text):
-    if not decode_to_text:
-        return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
-
-    token_ids = [tid for _, tid in token_logprobs]
-    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
-    return [
-        (logprob, token_id, token_text)
-        for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
-    ]
-
-
-async def detokenize_top_logprobs_tokens(top_logprobs, decode_to_text):
-    for i, t in enumerate(top_logprobs):
-        if top_logprobs[i] is not None:
-            top_logprobs[i] = await detokenize_logprob_tokens(t, decode_to_text)
-    return top_logprobs
-
-
-async def handle_token_logprobs_results(obj: GenerateReqInput, ret):
-    """Handle the token logprobs results, convert token ids to text if needed.
-
-    Args:
-        obj (GenerateReqInput): The request object.
-        ret (Union[Dict, List[Dict]]): The response object.
-    """
-    # NOTE: This is because the multiple requests in one http request.
-    async def convert_style(r, return_text):
-        r["meta_info"]["prefill_token_logprobs"] = await detokenize_logprob_tokens(
-            r["meta_info"]["prefill_token_logprobs"], return_text
-        )
-        r["meta_info"]["decode_token_logprobs"] = await detokenize_logprob_tokens(
-            r["meta_info"]["decode_token_logprobs"], return_text
-        )
-        r["meta_info"]["prefill_top_logprobs"] = await detokenize_top_logprobs_tokens(
-            r["meta_info"]["prefill_top_logprobs"], return_text
-        )
-        r["meta_info"]["decode_top_logprobs"] = await detokenize_top_logprobs_tokens(
-            r["meta_info"]["decode_top_logprobs"], return_text
-        )
-
-    if isinstance(obj.text, str):
-        if obj.return_logprob:
-            await convert_style(ret, obj.return_text_in_logprobs)
-    else:
-        for i, r in enumerate(ret):
-            if obj.return_logprob[i]:
-                await convert_style(r, obj.return_text_in_logprobs)
-
-
@app.post("/v1/completions")
async def v1_completions(raw_request: Request):
    request_json = await raw_request.json()
@@ -203,10 +144,10 @@ async def v1_completions(raw_request: Request):
    if adapted_request.stream:

-        async def gnerate_stream_resp():
+        async def generate_stream_resp():
            stream_buffer = ""
            n_prev_token = 0
-            async for content in stream_generator(adapted_request):
+            async for content in tokenizer_manager.generate_request(adapted_request):
                text = content["text"]
                prompt_tokens = content["meta_info"]["prompt_tokens"]
                completion_tokens = content["meta_info"]["completion_tokens"]
@@ -266,7 +207,7 @@ async def v1_completions(raw_request: Request):
                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
            yield "data: [DONE]\n\n"

-        return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream")
+        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")

    # Non-streaming response.
    ret = await generate_request(adapted_request)
@@ -384,7 +325,7 @@ async def v1_chat_completions(raw_request: Request):
        is_first = True
        stream_buffer = ""

-        async for content in stream_generator(adapted_request):
+        async for content in tokenizer_manager.generate_request(adapted_request):
            if is_first:
                # First chunk with role
                is_first = False
...
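Since the id-to-text conversion now happens inside `TokenizerManager`, the HTTP handlers simply relay whatever `generate_request` yields as `data: {...}` chunks. A minimal client for the streaming `/generate` endpoint; the host, port, and sampling fields follow the README example, and the field names are taken from `GenerateReqInput` above rather than verified against a running server:

```python
# Illustrative streaming client for /generate; assumes a server started with
# `python -m sglang.launch_server ... --port 30000`.
import json
import requests

with requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 32, "temperature": 0},
        "stream": True,
        "return_logprob": True,
        "return_text_in_logprobs": True,
    },
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            break
        out = json.loads(payload)
        # out["meta_info"] carries the logprob fields when return_logprob is set.
        print(out["text"])
```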
@@ -241,7 +241,7 @@ class ServerArgs:
    def print_mode_args(self):
        return (
            f"enable_flashinfer={self.enable_flashinfer}, "
-            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
+            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
            f"disable_radix_cache={self.disable_radix_cache}, "
            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
            f"disable_disk_cache={self.disable_disk_cache}, "
...
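The one-character change above matters because adjacent f-string literals are concatenated implicitly, so the missing `", "` made two settings run together in the printed mode arguments:

```python
# Implicit concatenation of adjacent (f-)string literals:
a, b = True, False
print(f"attention_reduce_in_fp32={a}" f"disable_radix_cache={b}")
# -> attention_reduce_in_fp32=Truedisable_radix_cache=False
print(f"attention_reduce_in_fp32={a}, " f"disable_radix_cache={b}")
# -> attention_reduce_in_fp32=True, disable_radix_cache=False
```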
"""Common utilities."""
import base64 import base64
import os import os
import random import random
...@@ -13,6 +15,7 @@ import numpy as np ...@@ -13,6 +15,7 @@ import numpy as np
import pydantic import pydantic
import requests import requests
import torch import torch
from fastapi.responses import JSONResponse
from packaging import version as pkg_version from packaging import version as pkg_version
from pydantic import BaseModel from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
...@@ -303,6 +306,7 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware): ...@@ -303,6 +306,7 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
response = await call_next(request) response = await call_next(request)
return response return response
# FIXME: Remove this once we drop support for pydantic 1.x # FIXME: Remove this once we drop support for pydantic 1.x
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1 IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
...@@ -310,4 +314,4 @@ IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1 ...@@ -310,4 +314,4 @@ IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
def jsonify_pydantic_model(obj: BaseModel): def jsonify_pydantic_model(obj: BaseModel):
if IS_PYDANTIC_1: if IS_PYDANTIC_1:
return obj.json(ensure_ascii=False) return obj.json(ensure_ascii=False)
return obj.model_dump_json() return obj.model_dump_json()
\ No newline at end of file
@@ -296,7 +296,7 @@ def test_parallel_encoding(check_answer=True):
def test_image_qa():
    @sgl.function
    def image_qa(s, question):
-        s += sgl.user(sgl.image("test_image.png") + question)
+        s += sgl.user(sgl.image("example_image.png") + question)
        s += sgl.assistant(sgl.gen("answer"))

    state = image_qa.run(
...
@@ -28,7 +28,7 @@ class TestOpenAIBackend(unittest.TestCase):
        if cls.backend is None:
            cls.backend = OpenAI("gpt-3.5-turbo-instruct")
            cls.chat_backend = OpenAI("gpt-3.5-turbo")
-            cls.chat_vision_backend = OpenAI("gpt-4-vision-preview")
+            cls.chat_vision_backend = OpenAI("gpt-4-turbo")

    def test_few_shot_qa(self):
        set_default_backend(self.backend)
@@ -88,14 +88,3 @@ if __name__ == "__main__":
    # t = TestOpenAIBackend()
    # t.setUp()
    # t.test_few_shot_qa()
-    # t.test_mt_bench()
-    # t.test_select()
-    # t.test_decode_int()
-    # t.test_decode_json()
-    # t.test_expert_answer()
-    # t.test_tool_use()
-    # t.test_react()
-    # t.test_parallel_decoding()
-    # t.test_parallel_encoding()
-    # t.test_image_qa()
-    # t.test_stream()
from sglang import OpenAI, function, gen, set_default_backend
@function()
def gen_character_default(s):
s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
@function(api_num_spec_tokens=512)
def gen_character_spec(s):
s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
@function(api_num_spec_tokens=512)
def gen_character_no_stop(s):
s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + gen("name") + "\nBirthday:" + gen("birthday")
s += "\nJob:" + gen("job") + "\nWelcome.\n"
@function(api_num_spec_tokens=512)
def gen_character_multi_stop(s):
s += "Construct a character within the following format:\n"
s += (
"Name: Steve Jobs.###Birthday: February 24, 1955.###Job: Apple CEO.\nWelcome.\n"
)
s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + gen("name", stop=["\n", "###"])
s += "###Birthday:" + gen("birthday", stop=["\n", "###"])
s += "###Job:" + gen("job", stop=["\n", "###"]) + "\nWelcome.\n"
set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
state = gen_character_default.run()
print(state.text())
print("=" * 60)
state = gen_character_no_stop.run()
print("name###", state["name"])
print("birthday###:", state["birthday"])
print("job###", state["job"])
print("=" * 60)
state = gen_character_multi_stop.run()
print(state.text())
print("=" * 60)
state = gen_character_spec.run()
print(state.text())
print("name###", state["name"])
print("birthday###", state["birthday"])
print("job###", state["job"])