Unverified commit aee4f523, authored by Lianmin Zheng and committed by GitHub

Fix logit processor bugs (#427)

parent 7023f413
......@@ -297,7 +297,6 @@ curl http://localhost:30000/generate \
Learn more about the argument format [here](docs/sampling_params.md).
### OpenAI Compatible API
In addition, the server supports an experimental OpenAI-compatible API.
```python
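# Illustrative sketch only: the README's own snippet is elided in this diff view.
# Because the server exposes an experimental OpenAI-compatible API, the official
# `openai` client can be pointed at it by overriding the base URL.
import openai

client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
response = client.completions.create(
    model="default",
    prompt="The capital of France is",
    max_tokens=16,
)
print(response.choices[0].text)
```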
......@@ -386,7 +385,6 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/model_support.md).
## Benchmark And Performance
- Llama-7B on NVIDIA A10G, FP16, Tensor Parallelism=1
![llama_7b](assets/llama_7b.jpg)
......@@ -410,7 +408,4 @@ https://github.com/sgl-project/sglang/issues/157
}
```
[![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-md.svg)](https://huggingface.co/papers/2312.07104)
We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql).
"""Some Public API Definitions"""
import os
import re
from typing import Callable, List, Optional, Union
......@@ -31,6 +32,7 @@ def function(
def Runtime(*args, **kwargs):
# Avoid importing unnecessary dependency
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
from sglang.srt.server import Runtime
return Runtime(*args, **kwargs)
......
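A hedged usage sketch of the lazy `Runtime` wrapper above: it keeps the heavy server dependencies out of the top-level import and only pulls them in when a local runtime is actually created. The model path below is illustrative.

```python
import sglang as sgl

# Hypothetical example: the wrapper defers `from sglang.srt.server import Runtime`
# until this call, so `import sglang` alone stays lightweight.
runtime = sgl.Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
sgl.set_default_backend(runtime)
# ... run sgl programs against the local runtime ...
runtime.shutdown()
```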
......@@ -14,7 +14,7 @@ except ImportError as e:
class Anthropic(BaseBackend):
def __init__(self, model_name):
def __init__(self, model_name, *args, **kwargs):
super().__init__()
if isinstance(anthropic, Exception):
......@@ -22,6 +22,7 @@ class Anthropic(BaseBackend):
self.model_name = model_name
self.chat_template = get_chat_template("claude")
self.client = anthropic.Anthropic(*args, **kwargs)
def get_chat_template(self):
return self.chat_template
......@@ -41,7 +42,7 @@ class Anthropic(BaseBackend):
else:
system = ""
ret = anthropic.Anthropic().messages.create(
ret = self.client.messages.create(
model=self.model_name,
system=system,
messages=messages,
......@@ -66,11 +67,11 @@ class Anthropic(BaseBackend):
else:
system = ""
with anthropic.Anthropic().messages.stream(
with self.client.messages.stream(
model=self.model_name,
system=system,
messages=messages,
**sampling_params.to_anthropic_kwargs(),
) as stream:
for text in stream.text_stream:
yield text, {}
yield text, {}
\ No newline at end of file
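A minimal sketch of what the patched backend enables, assuming the import path `sglang.backend.anthropic` and an Anthropic-SDK model name: extra positional and keyword arguments now reach the single `anthropic.Anthropic` client created in `__init__`, which both `generate` and `generate_stream` reuse instead of constructing a fresh client per call.

```python
from sglang.backend.anthropic import Anthropic

# Hypothetical construction: client options such as api_key or max_retries are
# forwarded to anthropic.Anthropic(...) exactly once and reused afterwards.
backend = Anthropic(
    "claude-3-opus-20240229",  # model name is illustrative
    api_key="sk-ant-...",
    max_retries=2,
)
```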
......@@ -228,7 +228,7 @@ class OpenAI(BaseBackend):
prompt_tokens.append(ret_token)
decision = choices[np.argmax(scores)]
return decision, scores, scores
return decision, scores, None, None
def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
......
......@@ -220,7 +220,6 @@ class RuntimeEndpoint(BaseBackend):
"sampling_params": {"max_new_tokens": 0},
"return_logprob": True,
"logprob_start_len": max(prompt_len - 2, 0),
"return_text_in_logprobs": True,
}
self._add_images(s, data)
res = http_request(
......
......@@ -42,26 +42,29 @@ class LogitsProcessor(nn.Module):
for i in range(all_logprobs.shape[0]):
k = input_metadata.top_logprobs_nums[i]
t = all_logprobs[i].topk(k)
v_cpu = t.values.cpu().tolist()
p_cpu = t.indices.cpu().tolist()
v_cpu = t.values.tolist()
p_cpu = t.indices.tolist()
decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
return None, decode_top_logprobs
else:
prefill_top_logprobs, decode_top_logprobs = [], []
pt = 0
# NOTE: the GPU-CPU overhead can be reduced
extend_seq_lens_cpu = input_metadata.extend_seq_lens
for i in range(len(input_metadata.extend_seq_lens)):
extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
for i in range(len(extend_seq_lens_cpu)):
if extend_seq_lens_cpu[i] == 0:
prefill_top_logprobs.append([])
decode_top_logprobs.append([])
continue
k = input_metadata.top_logprobs_nums[i]
t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
vs_cpu = t.values.cpu().tolist()
ps_cpu = t.indices.cpu().tolist()
vs_cpu = t.values.tolist()
ps_cpu = t.indices.tolist()
prefill_top_logprobs.append(
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
)
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
pt += extend_seq_lens_cpu[i]
return prefill_top_logprobs, decode_top_logprobs
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
......@@ -99,20 +102,24 @@ class LogitsProcessor(nn.Module):
all_logits = all_logits[:, : self.config.vocab_size]
all_logprobs = all_logits.float()
all_logits = None
del all_logits
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
all_logprobs, input_metadata
)
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
if return_top_logprob:
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
all_logprobs, input_metadata
)
else:
prefill_top_logprobs = decode_top_logprobs = None
if input_metadata.forward_mode == ForwardMode.DECODE:
last_logprobs = all_logprobs
return last_logits, (
None,
None,
decode_top_logprobs,
None,
decode_top_logprobs,
last_logprobs,
)
else:
......@@ -131,9 +138,9 @@ class LogitsProcessor(nn.Module):
)
return last_logits, (
prefill_token_logprobs,
normalized_prompt_logprobs,
prefill_top_logprobs,
decode_top_logprobs,
normalized_prompt_logprobs,
last_logprobs,
)
......
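A self-contained sketch of the gating pattern introduced above: top-k logprobs are only computed (and copied to the host) when at least one request actually asked for them. The helper name and tensor shapes are illustrative, not the module's API.

```python
import torch

def top_logprobs_if_requested(all_logprobs: torch.Tensor, top_logprobs_nums):
    """Return per-row top-k (logprob, token_id) pairs, or None if nobody asked."""
    if not any(k > 0 for k in top_logprobs_nums):
        return None  # skip both the topk kernel and the GPU->CPU copy
    out = []
    for i, k in enumerate(top_logprobs_nums):
        if k == 0:
            out.append([])
            continue
        t = all_logprobs[i].topk(k)
        # .tolist() performs the host copy implicitly, matching the patched code
        out.append(list(zip(t.values.tolist(), t.indices.tolist())))
    return out

logprobs = torch.log_softmax(torch.randn(2, 32), dim=-1)
print(top_logprobs_if_requested(logprobs, [3, 0]))
```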
......@@ -25,7 +25,6 @@ class GenerateReqInput:
return_text_in_logprobs: bool = False
# Whether to stream output
stream: bool = False
# TODO: make all parameters a Union[List[T], T] to allow for batched requests
def post_init(self):
is_single = isinstance(self.text, str)
......
from dataclasses import dataclass
from enum import Enum, auto
from enum import IntEnum, auto
from typing import List
import numpy as np
......@@ -9,15 +9,15 @@ from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
class ForwardMode(Enum):
class ForwardMode(IntEnum):
PREFILL = auto()
EXTEND = auto()
DECODE = auto()
class FinishReason(Enum):
LENGTH = auto()
class FinishReason(IntEnum):
EOS_TOKEN = auto()
LENGTH = auto()
STOP_STR = auto()
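The switch from `Enum` to `IntEnum` is shown above without commentary; one plausible reason (an assumption, not stated in the diff) is that `IntEnum` members behave as plain integers, which keeps comparisons and serialization cheap when these values cross process boundaries. A tiny sketch:

```python
from enum import IntEnum, auto

class ForwardMode(IntEnum):
    PREFILL = auto()
    EXTEND = auto()
    DECODE = auto()

# IntEnum members compare equal to raw ints and round-trip through int-based
# serialization without a custom encoder.
assert ForwardMode.DECODE == 3
assert ForwardMode(int(ForwardMode.EXTEND)) is ForwardMode.EXTEND
```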
......@@ -31,6 +31,7 @@ class Req:
# Since jump forward may retokenize the prompt with partial outputs,
# we maintain the original prompt length to report the correct usage.
self.prompt_tokens = len(input_ids)
# The number of decoded tokens for token usage report. Note that
# this does not include the jump forward tokens.
self.completion_tokens_wo_jump_forward = 0
......@@ -41,12 +42,11 @@ class Req:
self.image_offset = 0
self.pad_value = None
# Sampling parameters
self.sampling_params = None
self.return_logprob = False
self.logprob_start_len = 0
self.top_logprobs_num = 0
self.stream = False
# Check finish
self.tokenizer = None
self.finished = False
self.finish_reason = None
......@@ -56,13 +56,17 @@ class Req:
self.prefix_indices = []
self.last_node = None
# Logprobs
self.return_logprob = False
self.logprob_start_len = 0
self.top_logprobs_num = 0
self.normalized_prompt_logprob = None
self.prefill_token_logprobs = None
self.decode_token_logprobs = None
self.normalized_prompt_logprob = None
self.prefill_top_logprobs = None
self.decode_top_logprobs = None
# For constrained decoding
# Constrained decoding
self.regex_fsm = None
self.regex_fsm_state = 0
self.jump_forward_map = None
......@@ -165,8 +169,8 @@ class Batch:
out_cache_cont_end: torch.Tensor = None
# for processing logprobs
top_logprobs_nums: List[int] = None
return_logprob: bool = False
top_logprobs_nums: List[int] = None
# for multimodal
pixel_values: List[torch.Tensor] = None
......@@ -321,8 +325,8 @@ class Batch:
)
retracted_reqs = []
seq_lens_np = self.seq_lens.cpu().numpy()
req_pool_indices_np = self.req_pool_indices.cpu().numpy()
seq_lens_cpu = self.seq_lens.cpu().numpy()
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
while self.token_to_kv_pool.available_size() < len(self.reqs):
idx = sorted_indices.pop()
req = self.reqs[idx]
......@@ -338,8 +342,8 @@ class Batch:
# TODO: apply more fine-grained retraction
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_np[idx]
][: seq_lens_np[idx]]
req_pool_indices_cpu[idx]
][: seq_lens_cpu[idx]]
self.token_to_kv_pool.dec_refs(token_indices)
self.filter_batch(sorted_indices)
......@@ -363,7 +367,7 @@ class Batch:
# insert the old request into tree_cache
token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
if req_pool_indices_cpu is None:
req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
req_pool_indices_cpu = self.req_pool_indices.tolist()
req_pool_idx = req_pool_indices_cpu[i]
indices = self.req_to_token_pool.req_to_token[
req_pool_idx, : len(token_ids_in_memory)
......
......@@ -36,7 +36,9 @@ from sglang.srt.utils import (
set_random_seed,
)
logger = logging.getLogger("model_rpc")
vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)
......@@ -54,9 +56,6 @@ class ModelRpcServer:
self.tp_size = server_args.tp_size
self.schedule_heuristic = server_args.schedule_heuristic
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
vllm_default_logger.setLevel(
level=getattr(logging, server_args.log_level.upper())
)
# Init model and tokenizer
self.model_config = ModelConfig(
......@@ -65,7 +64,7 @@ class ModelRpcServer:
context_length=server_args.context_length,
)
# for model end global settings
# For model end global settings
server_args_dict = {
"enable_flashinfer": server_args.enable_flashinfer,
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
......@@ -164,7 +163,7 @@ class ModelRpcServer:
logger.info("Cache flushed successfully!")
else:
warnings.warn(
"Cache not flushed because there are pending requests. "
f"Cache not flushed because there are pending requests. "
f"#queue-req: {len(self.forward_queue)}, "
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
)
......@@ -386,12 +385,12 @@ class ModelRpcServer:
f"#running_req: {running_req}. "
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
)
logger.debug(
f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
)
#logger.debug(
# f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
# f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
# f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
# f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
#)
new_batch = Batch.init_new(
can_run_list,
......@@ -408,47 +407,41 @@ class ModelRpcServer:
self.model_config.vocab_size, self.int_token_logit_bias
)
prefill_token_logprobs = None
if batch.extend_num_tokens != 0:
# Forward
logits, (
prefill_token_logprobs,
normalized_prompt_logprobs,
prefill_top_logprobs,
decode_top_logprobs,
normalized_prompt_logprobs,
last_logprobs,
) = self.model_runner.forward(batch, ForwardMode.EXTEND)
if prefill_token_logprobs is not None:
prefill_token_logprobs = prefill_token_logprobs.cpu().tolist()
normalized_prompt_logprobs = normalized_prompt_logprobs.cpu().tolist()
prefill_token_logprobs = prefill_token_logprobs.tolist()
normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
next_token_ids, _ = batch.sample(logits)
next_token_ids = next_token_ids.cpu().tolist()
# Only transfer the selected logprobs of the next token to CPU to reduce overhead.
if last_logprobs is not None:
last_token_logprobs = (
last_logprobs[torch.arange(len(batch.reqs)), next_token_ids].tolist()
)
next_token_ids = next_token_ids.tolist()
else:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
(
logits,
prefill_token_logprobs,
normalized_prompt_logprobs,
last_logprobs,
) = (None,) * 4
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
reqs = batch.reqs
last_token_logprobs = None
if last_logprobs is not None:
last_token_logprobs = (
last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
)
# Check finish condition
pt = 0
for i, req in enumerate(reqs):
for i, req in enumerate(batch.reqs):
req.completion_tokens_wo_jump_forward += 1
req.output_ids = [next_token_ids[i]]
req.check_finished()
if prefill_token_logprobs is not None:
if req.return_logprob:
req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
req.prefill_token_logprobs = list(
zip(
......@@ -463,12 +456,14 @@ class ModelRpcServer:
req.decode_token_logprobs = [
(last_token_logprobs[i], next_token_ids[i])
]
if req.top_logprobs_num > 0:
req.prefill_top_logprobs = prefill_top_logprobs[i]
if req.logprob_start_len == 0:
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
req.decode_top_logprobs = [decode_top_logprobs[i]]
req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
pt += req.extend_input_len
pt += req.extend_input_len
self.handle_finished_requests(batch)
......@@ -520,29 +515,29 @@ class ModelRpcServer:
logits, (
_,
_,
decode_top_logprobs,
_,
decode_top_logprobs,
last_logprobs,
) = self.model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids, _ = batch.sample(logits)
next_token_ids = next_token_ids.cpu().tolist()
next_token_ids = next_token_ids.tolist()
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
reqs = batch.reqs
new_token_logprobs = None
if last_logprobs is not None:
new_token_logprobs = last_logprobs[
torch.arange(len(reqs)), next_token_ids
torch.arange(len(batch.reqs)), next_token_ids
].tolist()
# Check finish condition
for i, (req, next_token_id) in enumerate(zip(reqs, next_token_ids)):
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_id)
req.check_finished()
if new_token_logprobs is not None:
if req.return_logprob:
req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
if req.top_logprobs_num > 0:
req.decode_top_logprobs.append(decode_top_logprobs[i])
self.handle_finished_requests(batch)
......@@ -590,8 +585,7 @@ class ModelRpcServer:
+ len(req.output_ids)
- req.prompt_tokens,
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"finish_reason": str(req.finish_reason),
"hit_stop_str": req.hit_stop_str,
"finish_reason": str(req.finish_reason), # FIXME: convert to the correct string
}
if req.return_logprob:
(
......@@ -628,7 +622,7 @@ class ModelRpcServer:
# Remove finished reqs
if finished_indices:
# Update radix cache
req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
req_pool_indices_cpu = batch.req_pool_indices.tolist()
for i in finished_indices:
req = batch.reqs[i]
req_pool_idx = req_pool_indices_cpu[i]
......
......@@ -29,7 +29,7 @@ QUANTIZATION_CONFIG_MAPPING = {
logger = logging.getLogger("model_runner")
# for server args in model endpoints
global_server_args_dict: dict = None
global_server_args_dict = {}
@lru_cache()
......@@ -86,8 +86,8 @@ class InputMetadata:
out_cache_cont_end: torch.Tensor = None
other_kv_index: torch.Tensor = None
top_logprobs_nums: List[int] = None
return_logprob: bool = False
top_logprobs_nums: List[int] = None
# for flashinfer
qo_indptr: torch.Tensor = None
......@@ -107,18 +107,20 @@ class InputMetadata:
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
self.kv_last_page_len = torch.ones(
(self.batch_size,), dtype=torch.int32, device="cuda"
)
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
seq_lens_cpu = self.seq_lens.cpu().numpy()
self.kv_indices = torch.cat(
[
self.req_to_token_pool.req_to_token[
self.req_pool_indices[i].item(), : self.seq_lens[i].item()
req_pool_indices_cpu[i], : seq_lens_cpu[i]
]
for i in range(self.batch_size)
],
dim=0,
).contiguous()
self.kv_last_page_len = torch.ones(
(self.batch_size,), dtype=torch.int32, device="cuda"
)
workspace_buffer = torch.empty(
32 * 1024 * 1024, dtype=torch.int8, device="cuda"
......@@ -195,15 +197,15 @@ class InputMetadata:
req_pool_indices[0], seq_lens[0] - 1
].item()
else:
seq_lens_np = seq_lens.cpu().numpy()
prefix_lens_np = prefix_lens.cpu().numpy()
position_ids_offsets_np = position_ids_offsets.cpu().numpy()
seq_lens_cpu = seq_lens.cpu().numpy()
prefix_lens_cpu = prefix_lens.cpu().numpy()
position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
positions = torch.tensor(
np.concatenate(
[
np.arange(
prefix_lens_np[i] + position_ids_offsets_np[i],
seq_lens_np[i] + position_ids_offsets_np[i],
prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
seq_lens_cpu[i] + position_ids_offsets_cpu[i],
)
for i in range(batch_size)
],
......@@ -229,9 +231,9 @@ class InputMetadata:
out_cache_loc=out_cache_loc,
out_cache_cont_start=out_cache_cont_start,
out_cache_cont_end=out_cache_cont_end,
top_logprobs_nums=top_logprobs_nums,
return_logprob=return_logprob,
other_kv_index=other_kv_index,
return_logprob=return_logprob,
top_logprobs_nums=top_logprobs_nums,
)
if forward_mode == ForwardMode.EXTEND:
......
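Several hunks above replace per-element `.item()` calls (one device synchronization each) with a single bulk `.cpu().numpy()` transfer followed by plain indexing. A standalone sketch of the pattern, using CPU tensors as stand-ins for the GPU tensors in the server:

```python
import torch

# Stand-ins for req_pool_indices / seq_lens, which live on the GPU in the server.
req_pool_indices = torch.tensor([0, 1, 2, 3])
seq_lens = torch.tensor([5, 3, 7, 2])

# Before: one synchronizing .item() call per loop iteration.
slow = [(int(req_pool_indices[i].item()), int(seq_lens[i].item())) for i in range(4)]

# After (the pattern adopted in this commit): one bulk transfer, then cheap
# numpy indexing inside the Python loop.
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
seq_lens_cpu = seq_lens.cpu().numpy()
fast = [(int(req_pool_indices_cpu[i]), int(seq_lens_cpu[i])) for i in range(4)]

assert slow == fast
```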
......@@ -185,7 +185,10 @@ class TokenizerManager:
while True:
await event.wait()
yield state.out_list[-1]
yield self.convert_logprob_style(state.out_list[-1],
obj.return_logprob,
obj.top_logprobs_num,
obj.return_text_in_logprobs)
state.out_list = []
if state.finished:
del self.rid_to_state[rid]
......@@ -231,16 +234,16 @@ class TokenizerManager:
rid = obj.rid[i]
state = self.rid_to_state[rid]
await state.event.wait()
output_list.append(state.out_list[-1])
output_list.append(
self.convert_logprob_style(state.out_list[-1],
obj.return_logprob[i],
obj.top_logprobs_num[i],
obj.return_text_in_logprobs))
assert state.finished
del self.rid_to_state[rid]
yield output_list
async def detokenize(self, obj: DetokenizeReqInput):
token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
async def flush_cache(self):
flush_cache_req = FlushCacheReq()
self.send_to_router.send_pyobj(flush_cache_req)
......@@ -267,3 +270,37 @@ class TokenizerManager:
state.event.set()
else:
raise ValueError(f"Invalid object: {recv_obj}")
def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs):
if return_logprob:
ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
)
ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
)
if top_logprobs_num > 0:
ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
)
ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
)
return ret
def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
if not decode_to_text:
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
token_ids = [tid for _, tid in token_logprobs]
token_texts = self.tokenizer.batch_decode(token_ids)
return [
(logprob, token_id, token_text)
for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
]
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
for i, t in enumerate(top_logprobs):
if t:
top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
return top_logprobs
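For reference, a hedged sketch of the triple format that `detokenize_logprob_tokens` produces: each `(logprob, token_id)` pair gains a third slot holding either `None` or the decoded token text, depending on `return_text_in_logprobs`. The values and the tokenizer call below are illustrative.

```python
# Illustrative only: entries arrive from the router as (logprob, token_id) pairs.
token_logprobs = [(-0.12, 450), (-1.37, 3186)]

# return_text_in_logprobs=False -> pad with None
without_text = [(lp, tid, None) for lp, tid in token_logprobs]

# return_text_in_logprobs=True -> decode the ids in one batched call, e.g.:
#   texts = tokenizer.batch_decode([tid for _, tid in token_logprobs])
#   with_text = [(lp, tid, txt) for (lp, tid), txt in zip(token_logprobs, texts)]
print(without_text)  # [(-0.12, 450, None), (-1.37, 3186, None)]
```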
"""pydantic models for OpenAI API protocol"""
import time
from typing import Dict, List, Optional, Union
......
......@@ -10,7 +10,7 @@ import threading
import time
from typing import List, Optional, Union
# Fix a Python bug
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
import aiohttp
......@@ -53,10 +53,10 @@ from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
enable_show_time_cost,
allocate_init_ports,
jsonify_pydantic_model,
assert_pkg_version,
enable_show_time_cost,
jsonify_pydantic_model,
get_exception_traceback,
API_KEY_HEADER_NAME,
APIKeyValidatorMiddleware
......@@ -99,12 +99,6 @@ async def flush_cache():
)
async def stream_generator(obj: GenerateReqInput):
async for out in tokenizer_manager.generate_request(obj):
await handle_token_logprobs_results(obj, out)
yield out
@app.post("/generate")
async def generate_request(obj: GenerateReqInput):
obj.post_init()
......@@ -112,69 +106,16 @@ async def generate_request(obj: GenerateReqInput):
if obj.stream:
async def stream_results():
async for out in stream_generator(obj):
async for out in tokenizer_manager.generate_request(obj):
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(stream_results(), media_type="text/event-stream")
ret = await tokenizer_manager.generate_request(obj).__anext__()
await handle_token_logprobs_results(obj, ret)
return ret
async def detokenize_logprob_tokens(token_logprobs, decode_to_text):
if not decode_to_text:
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
token_ids = [tid for _, tid in token_logprobs]
token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
return [
(logprob, token_id, token_text)
for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
]
async def detokenize_top_logprobs_tokens(top_logprobs, decode_to_text):
for i, t in enumerate(top_logprobs):
if top_logprobs[i] is not None:
top_logprobs[i] = await detokenize_logprob_tokens(t, decode_to_text)
return top_logprobs
async def handle_token_logprobs_results(obj: GenerateReqInput, ret):
"""Handle the token logprobs results, convert token ids to text if needed.
Args:
obj (GenerateReqInput): The request object.
ret (Union[Dict, List[Dict]]): The response object.
"""
# NOTE: This is because the multiple requests in one http request.
async def convert_style(r, return_text):
r["meta_info"]["prefill_token_logprobs"] = await detokenize_logprob_tokens(
r["meta_info"]["prefill_token_logprobs"], return_text
)
r["meta_info"]["decode_token_logprobs"] = await detokenize_logprob_tokens(
r["meta_info"]["decode_token_logprobs"], return_text
)
r["meta_info"]["prefill_top_logprobs"] = await detokenize_top_logprobs_tokens(
r["meta_info"]["prefill_top_logprobs"], return_text
)
r["meta_info"]["decode_top_logprobs"] = await detokenize_top_logprobs_tokens(
r["meta_info"]["decode_top_logprobs"], return_text
)
if isinstance(obj.text, str):
if obj.return_logprob:
await convert_style(ret, obj.return_text_in_logprobs)
else:
for i, r in enumerate(ret):
if obj.return_logprob[i]:
await convert_style(r, obj.return_text_in_logprobs)
@app.post("/v1/completions")
async def v1_completions(raw_request: Request):
request_json = await raw_request.json()
......@@ -203,10 +144,10 @@ async def v1_completions(raw_request: Request):
if adapted_request.stream:
async def gnerate_stream_resp():
async def generate_stream_resp():
stream_buffer = ""
n_prev_token = 0
async for content in stream_generator(adapted_request):
async for content in tokenizer_manager.generate_request(adapted_request):
text = content["text"]
prompt_tokens = content["meta_info"]["prompt_tokens"]
completion_tokens = content["meta_info"]["completion_tokens"]
......@@ -266,7 +207,7 @@ async def v1_completions(raw_request: Request):
yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream")
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
# Non-streaming response.
ret = await generate_request(adapted_request)
......@@ -384,7 +325,7 @@ async def v1_chat_completions(raw_request: Request):
is_first = True
stream_buffer = ""
async for content in stream_generator(adapted_request):
async for content in tokenizer_manager.generate_request(adapted_request):
if is_first:
# First chunk with role
is_first = False
......
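With the logprob conversion moved into `TokenizerManager`, the HTTP endpoints now stream whatever `generate_request` yields. A client-side sketch of consuming the `/generate` SSE stream with logprobs enabled; the field names follow the request dataclass shown earlier, and the port is illustrative.

```python
import json
import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16},
        "stream": True,
        "return_logprob": True,
        "return_text_in_logprobs": True,
    },
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if line and line.startswith("data: ") and line != "data: [DONE]":
        chunk = json.loads(line[len("data: "):])
        # decode_token_logprobs entries are (logprob, token_id, token_text) triples
        print(chunk["text"], chunk["meta_info"]["decode_token_logprobs"][-1])
```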
......@@ -241,7 +241,7 @@ class ServerArgs:
def print_mode_args(self):
return (
f"enable_flashinfer={self.enable_flashinfer}, "
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
f"disable_radix_cache={self.disable_radix_cache}, "
f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
f"disable_disk_cache={self.disable_disk_cache}, "
......
"""Common utilities."""
import base64
import os
import random
......@@ -13,6 +15,7 @@ import numpy as np
import pydantic
import requests
import torch
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware
......@@ -303,6 +306,7 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
response = await call_next(request)
return response
# FIXME: Remove this once we drop support for pydantic 1.x
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
......@@ -310,4 +314,4 @@ IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
def jsonify_pydantic_model(obj: BaseModel):
if IS_PYDANTIC_1:
return obj.json(ensure_ascii=False)
return obj.model_dump_json()
return obj.model_dump_json()
\ No newline at end of file
......@@ -296,7 +296,7 @@ def test_parallel_encoding(check_answer=True):
def test_image_qa():
@sgl.function
def image_qa(s, question):
s += sgl.user(sgl.image("test_image.png") + question)
s += sgl.user(sgl.image("example_image.png") + question)
s += sgl.assistant(sgl.gen("answer"))
state = image_qa.run(
......
......@@ -28,7 +28,7 @@ class TestOpenAIBackend(unittest.TestCase):
if cls.backend is None:
cls.backend = OpenAI("gpt-3.5-turbo-instruct")
cls.chat_backend = OpenAI("gpt-3.5-turbo")
cls.chat_vision_backend = OpenAI("gpt-4-vision-preview")
cls.chat_vision_backend = OpenAI("gpt-4-turbo")
def test_few_shot_qa(self):
set_default_backend(self.backend)
......@@ -88,14 +88,3 @@ if __name__ == "__main__":
# t = TestOpenAIBackend()
# t.setUp()
# t.test_few_shot_qa()
# t.test_mt_bench()
# t.test_select()
# t.test_decode_int()
# t.test_decode_json()
# t.test_expert_answer()
# t.test_tool_use()
# t.test_react()
# t.test_parallel_decoding()
# t.test_parallel_encoding()
# t.test_image_qa()
# t.test_stream()
from sglang import OpenAI, function, gen, set_default_backend
@function()
def gen_character_default(s):
s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
@function(api_num_spec_tokens=512)
def gen_character_spec(s):
s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
@function(api_num_spec_tokens=512)
def gen_character_no_stop(s):
s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + gen("name") + "\nBirthday:" + gen("birthday")
s += "\nJob:" + gen("job") + "\nWelcome.\n"
@function(api_num_spec_tokens=512)
def gen_character_multi_stop(s):
s += "Construct a character within the following format:\n"
s += (
"Name: Steve Jobs.###Birthday: February 24, 1955.###Job: Apple CEO.\nWelcome.\n"
)
s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + gen("name", stop=["\n", "###"])
s += "###Birthday:" + gen("birthday", stop=["\n", "###"])
s += "###Job:" + gen("job", stop=["\n", "###"]) + "\nWelcome.\n"
set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
state = gen_character_default.run()
print(state.text())
print("=" * 60)
state = gen_character_no_stop.run()
print("name###", state["name"])
print("birthday###:", state["birthday"])
print("job###", state["job"])
print("=" * 60)
state = gen_character_multi_stop.run()
print(state.text())
print("=" * 60)
state = gen_character_spec.run()
print(state.text())
print("name###", state["name"])
print("birthday###", state["birthday"])
print("job###", state["job"])