Unverified commit a7334aee authored by Cody Yu, committed by GitHub

Support decode token logprobs (#130)

parent ee1df26a
@@ -14,18 +14,25 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
 
     def forward(self, input_ids, hidden_states, weight, input_metadata):
+        last_index = None
+        # Compute the last index (the first decode token) of each request
+        # if we are in prefill or extend mode.
+        if input_metadata.forward_mode != ForwardMode.DECODE:
+            last_index = (
+                torch.cumsum(
+                    input_metadata.seq_lens - input_metadata.prefix_lens,
+                    dim=0,
+                    dtype=torch.long,
+                )
+                - 1
+            )
+
         if not input_metadata.return_logprob:
+            # When logprob is not requested, only compute the last logits.
             if input_metadata.forward_mode == ForwardMode.DECODE:
                 last_hidden = hidden_states
             else:
-                last_index = (
-                    torch.cumsum(
-                        input_metadata.seq_lens - input_metadata.prefix_lens,
-                        dim=0,
-                        dtype=torch.long,
-                    )
-                    - 1
-                )
                 last_hidden = hidden_states[last_index]
             hidden_states = None
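A note on the `last_index` computation above: in prefill/extend mode the hidden states of all requests are packed along the token dimension, so the row of each request's final token is `cumsum(seq_lens - prefix_lens) - 1`. A minimal standalone sketch with made-up lengths (not part of the commit):

import torch

seq_lens = torch.tensor([5, 3, 4])     # hypothetical total tokens per request
prefix_lens = torch.tensor([2, 0, 1])  # tokens already cached, skipped this pass
extend_lens = seq_lens - prefix_lens   # tokens actually fed forward: [3, 3, 3]

# hidden_states has sum(extend_lens) rows, one per newly fed token, so the
# final token of request i sits at cumsum(extend_lens)[i] - 1.
last_index = torch.cumsum(extend_lens, dim=0, dtype=torch.long) - 1
print(last_index)  # tensor([2, 5, 8])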
@@ -33,41 +40,42 @@ class LogitsProcessor(nn.Module):
             if self.tp_size > 1:
                 last_logits = tensor_model_parallel_all_gather(last_logits)
             last_logits = last_logits[:, : self.config.vocab_size]
-            return last_logits, (None, None)
+            return last_logits, (None, None, None)
         else:
-            assert input_metadata.forward_mode != ForwardMode.DECODE
-            last_index = (
-                torch.cumsum(
-                    input_metadata.seq_lens - input_metadata.prefix_lens,
-                    dim=0,
-                    dtype=torch.long,
-                )
-                - 1
-            )
+            # When logprob is requested, compute the logits for all tokens.
             logits = torch.matmul(hidden_states, weight.T)
             if self.tp_size > 1:
                 logits = tensor_model_parallel_all_gather(logits)
             logits = logits[:, : self.config.vocab_size]
             all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
-            logprobs = all_logprobs[
-                torch.arange(all_logprobs.shape[0], device="cuda"),
-                torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
-            ]
-            logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+            if input_metadata.forward_mode == ForwardMode.DECODE:
+                last_logits = logits
+                last_logprobs = all_logprobs
+                prefill_logprobs = normalized_logprobs = None
+            else:
+                # Compute the logprobs for the last token of each request.
+                last_logits = logits[last_index]
+                last_logprobs = all_logprobs[last_index]
+
+                # Compute the logprobs and normalized logprobs for the prefill tokens.
+                # Note that we pad a zero at the end of each sequence for easy computation.
+                prefill_logprobs = all_logprobs[
+                    torch.arange(all_logprobs.shape[0], device="cuda"),
+                    torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+                ]
+                logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
 
-            start = input_metadata.extend_start_loc.clone()
-            end = start + input_metadata.extend_seq_lens - 2
-            start.clamp_(min=0, max=logprobs.shape[0] - 1)
-            end.clamp_(min=0, max=logprobs.shape[0] - 1)
-            sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
-            normalized_logprobs = sum_logp / (
-                (input_metadata.extend_seq_lens - 1).clamp(min=1)
-            )
+                start = input_metadata.extend_start_loc.clone()
+                end = start + input_metadata.extend_seq_lens - 2
+                start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
+                end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
+                sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
+                normalized_logprobs = sum_logp / (
+                    (input_metadata.extend_seq_lens - 1).clamp(min=1)
+                )
 
-            last_logits = logits[last_index]
-            return last_logits, (logprobs, normalized_logprobs)
+            return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs)
 
 
 if __name__ == "__main__":
...
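The normalized prompt logprob above is computed without a per-request Python loop: one cumulative sum over the packed logprob vector, then a difference of prefix sums per request. A small sketch with invented values (variable names mirror the code above; the numbers are not from the commit):

import torch

# Two requests packed together with extend lengths 3 and 4; the last entry of
# each request is a shifted placeholder and is excluded from the sum.
prefill_logprobs = torch.tensor([-0.5, -1.0, -2.0, -0.1, -0.2, -0.3, -0.4])
extend_seq_lens = torch.tensor([3, 4])
extend_start_loc = torch.tensor([0, 3])  # row where each request starts

logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
start = extend_start_loc.clone()
end = start + extend_seq_lens - 2  # last usable prompt position per request
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
normalized = sum_logp / (extend_seq_lens - 1).clamp(min=1)
print(sum_logp)    # tensor([-1.5000, -0.6000])
print(normalized)  # tensor([-0.7500, -0.2000])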
@@ -99,3 +99,7 @@ class BatchStrOut:
 @dataclass
 class FlushCacheReq:
     pass
+
+
+@dataclass
+class DetokenizeReqInput:
+    input_ids: List[int]
@@ -48,6 +48,7 @@ class Req:
         self.last_node = None
 
         self.logprob = None
+        self.token_logprob = None
         self.normalized_logprob = None
 
         # For constrained decoding
...
@@ -388,24 +388,28 @@ class ModelRpcServer(rpyc.Service):
             self.model_config.vocab_size, self.int_token_logit_bias
         )
 
+        logprobs = None
         if batch.extend_num_tokens != 0:
             # Forward
-            logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
-                batch, ForwardMode.EXTEND, batch.return_logprob
-            )
-            # print("extend logits", logits)
-            if logprobs is not None:
-                logprobs = logprobs.cpu().tolist()
+            logits, (prefill_logprobs, normalized_logprobs, last_logprobs) = (
+                self.model_runner.forward(batch, ForwardMode.EXTEND, batch.return_logprob)
+            )
+            if prefill_logprobs is not None:
+                logprobs = prefill_logprobs.cpu().tolist()
                 normalized_logprobs = normalized_logprobs.cpu().tolist()
 
-            next_token_ids, next_token_probs = batch.sample(logits)
+            next_token_ids, _ = batch.sample(logits)
             next_token_ids = next_token_ids.cpu().tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            logprobs = normalized_logprobs = None
+            logits = logprobs = normalized_logprobs = last_logprobs = None
 
-        # Check finish condition
+        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
         reqs = batch.reqs
+        if last_logprobs is not None:
+            last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
+
+        # Check finish condition
         pt = 0
         for i, req in enumerate(reqs):
             req.output_ids = [next_token_ids[i]]
@@ -414,6 +418,10 @@ class ModelRpcServer(rpyc.Service):
             if logprobs is not None:
                 req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
                 req.normalized_logprob = normalized_logprobs[i]
+
+                token_ids = req.input_ids + [next_token_ids[i]]
+                token_logprobs = [None] + req.logprob + [last_logprobs[i]]
+                req.token_logprob = list(zip(token_ids, token_logprobs))
             pt += req.extend_input_len
 
         self.handle_finished_requests(batch)
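The zip above gives each request a list of (token_id, logprob) pairs covering the prompt plus the first decoded token; the first prompt token gets None because nothing precedes it. A toy example with made-up ids and values:

input_ids = [101, 2054, 2003]   # hypothetical prompt token ids
next_token_id = 1996            # first decoded token for this request
prompt_logprobs = [-1.2, -0.7]  # logprobs of input_ids[1:] from the prefill pass
first_decode_logprob = -0.4     # logprob of next_token_id from last_logprobs

token_ids = input_ids + [next_token_id]
token_logprobs = [None] + prompt_logprobs + [first_decode_logprob]
print(list(zip(token_ids, token_logprobs)))
# [(101, None), (2054, -1.2), (2003, -0.7), (1996, -0.4)]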
@@ -463,15 +471,26 @@ class ModelRpcServer(rpyc.Service):
         batch.prepare_for_decode()
 
         # Forward
-        logits = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, next_token_probs = batch.sample(logits)
+        logits, (_, _, last_logprobs) = self.model_runner.forward(
+            batch,
+            ForwardMode.DECODE,
+            batch.return_logprob,
+        )
+        next_token_ids, _ = batch.sample(logits)
         next_token_ids = next_token_ids.cpu().tolist()
 
-        # Check finish condition
+        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
         reqs = batch.reqs
-        for i in range(len(reqs)):
-            reqs[i].output_ids.append(next_token_ids[i])
-            reqs[i].check_finished()
+        if last_logprobs is not None:
+            last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].tolist()
+
+        # Check finish condition
+        for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
+            req.output_ids.append(next_tok_id)
+            req.check_finished()
+
+            if last_logprobs is not None:
+                req.token_logprob.append((next_tok_id, last_logprobs[i]))
 
         self.handle_finished_requests(batch)
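The `last_logprobs[torch.arange(len(reqs)), next_token_ids]` gather keeps the selection on the GPU and moves only one value per request to the CPU. A minimal sketch with random data (shapes and values are assumptions for illustration):

import torch

batch_size, vocab_size = 3, 8
last_logprobs = torch.randn(batch_size, vocab_size)  # stand-in for log-softmax output
next_token_ids = [5, 1, 7]                           # ids chosen by sampling

# Advanced indexing picks row i, column next_token_ids[i]; one host transfer total.
selected = last_logprobs[torch.arange(batch_size), next_token_ids].cpu().tolist()
print(selected)  # three floats, one per request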
@@ -513,6 +532,7 @@ class ModelRpcServer(rpyc.Service):
                 }
                 if req.return_logprob:
                     meta_info["prompt_logprob"] = req.logprob
+                    meta_info["token_logprob"] = req.token_logprob
                     meta_info["normalized_prompt_logprob"] = req.normalized_logprob
                 output_meta_info.append(meta_info)
                 output_finished.append(req.finished)
...
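With return_logprob set, the meta_info of a finished request therefore carries three extra fields. A rough, hand-written illustration of the shape (all values and the placeholder comment are invented; only the three logprob keys come from the code above):

meta_info_example = {
    # ... existing fields such as the request id and token counts ...
    "prompt_logprob": [-2.1, -0.4],  # per-token logprobs of the prompt (first token excluded)
    "token_logprob": [(101, None), (2054, -2.1), (2003, -0.4), (1996, -0.3)],
    "normalized_prompt_logprob": -1.25,  # mean prompt logprob, see the cumsum sketch above
}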
@@ -397,6 +397,7 @@ class ModelRunner:
         out_cache_loc,
         out_cache_cont_start,
         out_cache_cont_end,
+        return_logprob,
     ):
         input_metadata = InputMetadata.create(
             self,
@@ -409,10 +410,9 @@ class ModelRunner:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
+            return_logprob=return_logprob,
         )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)[
-            0
-        ]
+        return self.model.forward(input_ids, input_metadata.positions, input_metadata)
 
     @torch.inference_mode()
     def forward_extend_multi_modal(
@@ -460,8 +460,8 @@ class ModelRunner:
                 "prefix_lens": batch.prefix_lens,
                 "position_ids_offsets": batch.position_ids_offsets,
                 "out_cache_loc": batch.out_cache_loc,
+                "return_logprob": return_logprob,
             }
-            kwargs["return_logprob"] = return_logprob
             return self.forward_extend_multi_modal(**kwargs)
         else:
             kwargs = {
@@ -471,6 +471,7 @@ class ModelRunner:
                 "prefix_lens": batch.prefix_lens,
                 "position_ids_offsets": batch.position_ids_offsets,
                 "out_cache_loc": batch.out_cache_loc,
+                "return_logprob": return_logprob,
             }
 
         if forward_mode == ForwardMode.DECODE:
@@ -478,10 +479,8 @@ class ModelRunner:
             kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
             return self.forward_decode(**kwargs)
         elif forward_mode == ForwardMode.EXTEND:
-            kwargs["return_logprob"] = return_logprob
             return self.forward_extend(**kwargs)
         elif forward_mode == ForwardMode.PREFILL:
-            kwargs["return_logprob"] = return_logprob
             return self.forward_prefill(**kwargs)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
@@ -18,6 +18,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.managers.io_struct import (
     BatchStrOut,
+    DetokenizeReqInput,
     FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
@@ -234,6 +235,10 @@ class TokenizerManager:
 
                 yield output_list
 
+    async def detokenize(self, obj: DetokenizeReqInput):
+        token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
+        return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
+
     async def flush_cache(self):
         flush_cache_req = FlushCacheReq()
         self.send_to_router.send_pyobj(flush_cache_req)
...
@@ -30,7 +30,7 @@ from sglang.srt.conversation import (
 )
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
 from sglang.srt.managers.openai_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -44,6 +44,7 @@ from sglang.srt.managers.openai_protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    LogProbs,
     UsageInfo,
 )
 from sglang.srt.managers.router.manager import start_router_process
@@ -97,6 +98,23 @@ async def stream_generator(obj):
         yield out
 
 
+async def make_openai_style_logprobs(token_logprobs):
+    ret_logprobs = LogProbs()
+
+    # Detokenize
+    token_ids = [tid for tid, _ in token_logprobs]
+    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
+
+    for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
+        ret_logprobs.tokens.append(token_text)
+        ret_logprobs.token_logprobs.append(token_logprob)
+
+        # Not supported yet.
+        ret_logprobs.top_logprobs.append({})
+        ret_logprobs.text_offset.append(-1)
+    return ret_logprobs
+
+
 @app.post("/generate")
 async def generate_request(obj: GenerateReqInput):
     obj.post_init()
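make_openai_style_logprobs maps the (token_id, logprob) pairs onto the fields of the LogProbs model used by the OpenAI-compatible completion responses. A hypothetical payload it could produce (tokens and numbers invented; the field names are the ones appended above):

example = {
    "tokens": ["The", " capital", " of"],
    "token_logprobs": [None, -0.8, -0.3],  # None when the slice starts at the first prompt token
    "top_logprobs": [{}, {}, {}],          # per-token alternatives: not supported yet
    "text_offset": [-1, -1, -1],           # character offsets: not tracked yet
}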
@@ -132,6 +150,7 @@ async def v1_completions(raw_request: Request):
             "presence_penalty": request.presence_penalty,
             "frequency_penalty": request.frequency_penalty,
         },
+        return_logprob=request.logprobs is not None,
         stream=request.stream,
     )
     adapted_request.post_init()
@@ -140,17 +159,34 @@ async def v1_completions(raw_request: Request):
         async def gnerate_stream_resp():
             stream_buffer = ""
+            n_prev_token = 0
             async for content in stream_generator(adapted_request):
                 text = content["text"]
                 prompt_tokens = content["meta_info"]["prompt_tokens"]
                 completion_tokens = content["meta_info"]["completion_tokens"]
+
+                if not stream_buffer:  # The first chunk
+                    if request.echo:
+                        # Prepend prompt in response text.
+                        text = request.prompt + text
+                    else:
+                        # Skip prompt tokens if echo is disabled.
+                        n_prev_token = prompt_tokens
+
+                if request.logprobs is not None:
+                    logprobs = await make_openai_style_logprobs(
+                        content["meta_info"]["token_logprob"][n_prev_token:]
+                    )
+                    n_prev_token = len(content["meta_info"]["token_logprob"])
+                else:
+                    logprobs = None
+
                 delta = text[len(stream_buffer) :]
-                stream_buffer = text
+                stream_buffer = content["text"]
                 choice_data = CompletionResponseStreamChoice(
                     index=0,
                     text=delta,
-                    logprobs=None,
+                    logprobs=logprobs,
                     finish_reason=None,
                 )
                 chunk = CompletionStreamResponse(
@@ -172,15 +208,28 @@ async def v1_completions(raw_request: Request):
     # Non-streaming response.
     ret = await generate_request(adapted_request)
+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    text = ret["text"]
+
+    token_logprob_pos = prompt_tokens
+    if request.echo:
+        token_logprob_pos = 0
+        text = request.prompt + text
+    else:
+        token_logprob_pos = prompt_tokens
+
+    logprobs = (
+        await make_openai_style_logprobs(ret["meta_info"]["token_logprob"][token_logprob_pos:])
+        if request.logprobs is not None
+        else None
+    )
+
     choice_data = CompletionResponseChoice(
         index=0,
-        text=ret["text"],
-        logprobs=None,
+        text=text,
+        logprobs=logprobs,
         finish_reason=None,  # TODO(comaniac): Add finish reason.
     )
-    prompt_tokens = ret["meta_info"]["prompt_tokens"]
-    completion_tokens = ret["meta_info"]["completion_tokens"]
     response = CompletionResponse(
         id=ret["meta_info"]["id"],
         model=request.model,
@@ -216,7 +265,9 @@ async def v1_chat_completions(raw_request: Request):
             if not isinstance(m.content, str):
                 raise HTTPException(
                     status_code=503,
-                    detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.",
+                    detail="Structured content requests not supported with "
+                    "HuggingFace Chat Templates. "
+                    "Make sure the server specifies a sglang chat template.",
                 )
         prompt = tokenizer_manager.tokenizer.apply_chat_template(
             request.messages, tokenize=False, add_generation_prompt=True
...
@@ -9,18 +9,10 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
 """
 
 import argparse
-import time
 
 import requests
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="http://127.0.0.1")
-    parser.add_argument("--port", type=int, default=30000)
-    args = parser.parse_args()
-
-    url = f"{args.host}:{args.port}"
 
+def test_decode(url, return_logprob):
     response = requests.post(
         url + "/generate",
         json={
@@ -29,8 +21,19 @@ if __name__ == "__main__":
                 "temperature": 0,
                 "max_new_tokens": 32,
             },
-            # "return_logprob": True,
-            # "logprob_start_len": 0,
+            "return_logprob": return_logprob,
+            "logprob_start_len": 0,
         },
     )
     print(response.json())
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="http://127.0.0.1")
+    parser.add_argument("--port", type=int, default=30000)
+    args = parser.parse_args()
+
+    url = f"{args.host}:{args.port}"
+    test_decode(url, False)
+    test_decode(url, True)
@@ -9,27 +9,20 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
 
 import argparse
 import json
-import time
 
 import requests
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="http://127.0.0.1")
-    parser.add_argument("--port", type=int, default=30000)
-    args = parser.parse_args()
-
-    url = f"{args.host}:{args.port}"
 
+def test_decode_stream(url, return_logprob):
     response = requests.post(
         url + "/generate",
         json={
             "text": "The capital of France is",
             "sampling_params": {
                 "temperature": 0,
-                "max_new_tokens": 512,
+                "max_new_tokens": 128,
             },
             "stream": True,
+            "return_logprob": return_logprob,
         },
         stream=True,
     )
@@ -41,7 +34,29 @@ if __name__ == "__main__":
         if chunk == "data: [DONE]":
             break
         data = json.loads(chunk[5:].strip("\n"))
-        output = data["text"].strip()
-        print(output[prev:], end="", flush=True)
-        prev = len(output)
+
+        if return_logprob:
+            assert data["meta_info"]["prompt_logprob"] is not None
+            assert data["meta_info"]["token_logprob"] is not None
+            assert data["meta_info"]["normalized_prompt_logprob"] is not None
+
+            if prev == 0:  # Skip prompt logprobs
+                prev = data["meta_info"]["prompt_tokens"]
+
+            for token_txt, _, logprob in data["meta_info"]["token_logprob"][prev:]:
+                print(f"{token_txt}\t{logprob}", flush=True)
+            prev = len(data["meta_info"]["token_logprob"])
+        else:
+            output = data["text"].strip()
+            print(output[prev:], end="", flush=True)
+            prev = len(output)
     print("")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="http://127.0.0.1")
+    parser.add_argument("--port", type=int, default=30000)
+    args = parser.parse_args()
+
+    url = f"{args.host}:{args.port}"
+    test_decode_stream(url, False)
+    test_decode_stream(url, True)
@@ -18,15 +18,26 @@ import argparse
 import openai
 
 
-def test_completion(args):
+def test_completion(args, echo, logprobs):
     client = openai.Client(api_key="EMPTY", base_url=args.base_url)
     response = client.completions.create(
         model="default",
         prompt="The capital of France is",
         temperature=0,
         max_tokens=32,
+        echo=echo,
+        logprobs=logprobs,
     )
+    text = response.choices[0].text
     print(response.choices[0].text)
+    if echo:
+        assert text.startswith("The capital of France is")
+    if logprobs:
+        assert response.choices[0].logprobs
+        if echo:
+            assert response.choices[0].logprobs.token_logprobs[0] == None
+        else:
+            assert response.choices[0].logprobs.token_logprobs[0] != None
     assert response.id
     assert response.created
     assert response.usage.prompt_tokens > 0
@@ -34,7 +45,7 @@ def test_completion(args):
     assert response.usage.total_tokens > 0
 
 
-def test_completion_stream(args):
+def test_completion_stream(args, echo, logprobs):
     client = openai.Client(api_key="EMPTY", base_url=args.base_url)
     response = client.completions.create(
         model="default",
@@ -42,9 +53,23 @@ def test_completion_stream(args):
         temperature=0,
         max_tokens=32,
         stream=True,
+        echo=echo,
+        logprobs=logprobs,
     )
+    first = True
     for r in response:
-        print(r.choices[0].text, end="", flush=True)
+        if first:
+            if echo:
+                assert r.choices[0].text.startswith("The capital of France is")
+            first = False
+
+        if logprobs:
+            print(
+                f"{r.choices[0].text:12s}\t"
+                f"{r.choices[0].logprobs.token_logprobs}",
+                flush=True
+            )
+        else:
+            print(r.choices[0].text, end="", flush=True)
         assert r.id
         assert r.usage.prompt_tokens > 0
         assert r.usage.completion_tokens > 0
@@ -135,8 +160,14 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    test_completion(args)
-    test_completion_stream(args)
+    test_completion(args, echo=False, logprobs=False)
+    test_completion(args, echo=True, logprobs=False)
+    test_completion(args, echo=False, logprobs=True)
+    test_completion(args, echo=True, logprobs=True)
+    test_completion_stream(args, echo=False, logprobs=False)
+    test_completion_stream(args, echo=True, logprobs=False)
+    test_completion_stream(args, echo=False, logprobs=True)
+    test_completion_stream(args, echo=True, logprobs=True)
     test_chat_completion(args)
     test_chat_completion_stream(args)
     if args.test_image:
...