Unverified Commit a7334aee authored by Cody Yu, committed by GitHub

Support decode token logprobs (#130)

parent ee1df26a
......@@ -14,18 +14,25 @@ class LogitsProcessor(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
def forward(self, input_ids, hidden_states, weight, input_metadata):
last_index = None
# Compute the last index (the first decode token) of each request
# if we are in prefill or extend mode.
if input_metadata.forward_mode != ForwardMode.DECODE:
last_index = (
torch.cumsum(
input_metadata.seq_lens - input_metadata.prefix_lens,
dim=0,
dtype=torch.long,
)
- 1
)
if not input_metadata.return_logprob:
# When logprob is not requested, only compute the last logits.
if input_metadata.forward_mode == ForwardMode.DECODE:
last_hidden = hidden_states
else:
last_index = (
torch.cumsum(
input_metadata.seq_lens - input_metadata.prefix_lens,
dim=0,
dtype=torch.long,
)
- 1
)
last_hidden = hidden_states[last_index]
hidden_states = None
......@@ -33,41 +40,42 @@ class LogitsProcessor(nn.Module):
if self.tp_size > 1:
last_logits = tensor_model_parallel_all_gather(last_logits)
last_logits = last_logits[:, : self.config.vocab_size]
return last_logits, (None, None)
return last_logits, (None, None, None)
else:
assert input_metadata.forward_mode != ForwardMode.DECODE
last_index = (
torch.cumsum(
input_metadata.seq_lens - input_metadata.prefix_lens,
dim=0,
dtype=torch.long,
)
- 1
)
# When logprob is requested, compute the logits for all tokens.
logits = torch.matmul(hidden_states, weight.T)
if self.tp_size > 1:
logits = tensor_model_parallel_all_gather(logits)
logits = logits[:, : self.config.vocab_size]
all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
logprobs = all_logprobs[
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
]
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
if input_metadata.forward_mode == ForwardMode.DECODE:
last_logits = logits
last_logprobs = all_logprobs
prefill_logprobs = normalized_logprobs = None
else:
# Compute the logprobs for the last token of each request.
last_logits = logits[last_index]
last_logprobs = all_logprobs[last_index]
# Compute the logprobs and normalized logprobs for the prefill tokens.
# Note that we pad a zero at the end of each sequence for easy computation.
prefill_logprobs = all_logprobs[
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
]
logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
start = input_metadata.extend_start_loc.clone()
end = start + input_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=logprobs.shape[0] - 1)
end.clamp_(min=0, max=logprobs.shape[0] - 1)
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
normalized_logprobs = sum_logp / (
(input_metadata.extend_seq_lens - 1).clamp(min=1)
)
start = input_metadata.extend_start_loc.clone()
end = start + input_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
normalized_logprobs = sum_logp / (
(input_metadata.extend_seq_lens - 1).clamp(min=1)
)
last_logits = logits[last_index]
return last_logits, (logprobs, normalized_logprobs)
return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs)
if __name__ == "__main__":
......
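For reference, a minimal standalone sketch of the cumulative-sum trick used above to turn the flat per-token logprobs into a per-request normalized prompt logprob without a Python loop. The helper name and example values are illustrative, not part of the commit:

```python
import torch

def normalized_prompt_logprobs(token_logprobs, extend_start_loc, extend_seq_lens):
    # token_logprobs: 1-D tensor of per-token logprobs for all requests,
    # concatenated in extend order (mirrors prefill_logprobs above).
    cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
    start = extend_start_loc.clone()
    end = start + extend_seq_lens - 2
    start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
    end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
    # Difference of prefix sums gives each request's logprob sum in O(1).
    sum_logp = cumsum[end] - cumsum[start] + token_logprobs[start]
    # Normalize by the number of predicted prompt tokens (at least 1).
    return sum_logp / (extend_seq_lens - 1).clamp(min=1)

# Two requests of lengths 3 and 4 packed into one flat tensor.
logprobs = torch.tensor([-0.5, -1.0, -2.0, -0.1, -0.2, -0.3, -0.4])
print(normalized_prompt_logprobs(logprobs, torch.tensor([0, 3]), torch.tensor([3, 4])))
# tensor([-0.7500, -0.2000])
```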
......@@ -99,3 +99,7 @@ class BatchStrOut:
@dataclass
class FlushCacheReq:
pass
@dataclass
class DetokenizeReqInput:
input_ids: List[int]
......@@ -48,6 +48,7 @@ class Req:
self.last_node = None
self.logprob = None
self.token_logprob = None
self.normalized_logprob = None
# For constrained decoding
......
......@@ -388,24 +388,28 @@ class ModelRpcServer(rpyc.Service):
self.model_config.vocab_size, self.int_token_logit_bias
)
logprobs = None
if batch.extend_num_tokens != 0:
# Forward
logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
batch, ForwardMode.EXTEND, batch.return_logprob
logits, (prefill_logprobs, normalized_logprobs, last_logprobs) = (
self.model_runner.forward(batch, ForwardMode.EXTEND, batch.return_logprob)
)
# print("extend logits", logits)
if logprobs is not None:
logprobs = logprobs.cpu().tolist()
if prefill_logprobs is not None:
logprobs = prefill_logprobs.cpu().tolist()
normalized_logprobs = normalized_logprobs.cpu().tolist()
next_token_ids, next_token_probs = batch.sample(logits)
next_token_ids, _ = batch.sample(logits)
next_token_ids = next_token_ids.cpu().tolist()
else:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
logprobs = normalized_logprobs = None
logits = logprobs = normalized_logprobs = last_logprobs = None
# Check finish condition
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
reqs = batch.reqs
if last_logprobs is not None:
last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
# Check finish condition
pt = 0
for i, req in enumerate(reqs):
req.output_ids = [next_token_ids[i]]
......@@ -414,6 +418,10 @@ class ModelRpcServer(rpyc.Service):
if logprobs is not None:
req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
req.normalized_logprob = normalized_logprobs[i]
token_ids = req.input_ids + [next_token_ids[i]]
token_logprobs = [None] + req.logprob + [last_logprobs[i]]
req.token_logprob = list(zip(token_ids, token_logprobs))
pt += req.extend_input_len
self.handle_finished_requests(batch)
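A hedged sketch of how the per-request `token_logprob` list assembled above lines up token ids with logprobs: the first prompt token is never predicted, so it gets `None`, the remaining prompt tokens use the prefill logprobs, and the freshly sampled token gets the logprob gathered from `last_logprobs`. The helper name and values below are made up for illustration:

```python
def build_token_logprob(input_ids, prompt_logprobs, next_token_id, next_token_logprob):
    # prompt_logprobs has extend_input_len - 1 entries (no logprob for the
    # first prompt token); next_token_logprob is the sampled token's logprob.
    token_ids = input_ids + [next_token_id]
    token_logprobs = [None] + prompt_logprobs + [next_token_logprob]
    return list(zip(token_ids, token_logprobs))

# Example with a 3-token prompt and one sampled token.
print(build_token_logprob([5, 9, 2], [-1.2, -0.4], 7, -0.8))
# [(5, None), (9, -1.2), (2, -0.4), (7, -0.8)]
```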
......@@ -463,15 +471,26 @@ class ModelRpcServer(rpyc.Service):
batch.prepare_for_decode()
# Forward
logits = self.model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids, next_token_probs = batch.sample(logits)
logits, (_, _, last_logprobs) = self.model_runner.forward(
batch,
ForwardMode.DECODE,
batch.return_logprob,
)
next_token_ids, _ = batch.sample(logits)
next_token_ids = next_token_ids.cpu().tolist()
# Check finish condition
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
reqs = batch.reqs
for i in range(len(reqs)):
reqs[i].output_ids.append(next_token_ids[i])
reqs[i].check_finished()
if last_logprobs is not None:
last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].tolist()
# Check finish condition
for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
req.output_ids.append(next_tok_id)
req.check_finished()
if last_logprobs is not None:
req.token_logprob.append((next_tok_id, last_logprobs[i]))
self.handle_finished_requests(batch)
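The comment above notes that only the selected logprobs of the next tokens are transferred to the CPU, in one batched copy. A minimal sketch of that gather (the function name is illustrative):

```python
import torch

def gather_sampled_logprobs(last_logprobs, next_token_ids):
    # last_logprobs: [batch_size, vocab_size] log-probabilities.
    # next_token_ids: one sampled token id per request.
    idx = torch.arange(last_logprobs.shape[0])
    # Advanced indexing selects one scalar per row; a single .cpu() call
    # then moves the whole batch instead of one transfer per request.
    return last_logprobs[idx, torch.as_tensor(next_token_ids)].cpu().tolist()

# Toy batch of 2 requests over a vocabulary of 4 tokens.
lp = torch.log_softmax(torch.randn(2, 4), dim=-1)
print(gather_sampled_logprobs(lp, [3, 0]))
```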
......@@ -513,6 +532,7 @@ class ModelRpcServer(rpyc.Service):
}
if req.return_logprob:
meta_info["prompt_logprob"] = req.logprob
meta_info["token_logprob"] = req.token_logprob
meta_info["normalized_prompt_logprob"] = req.normalized_logprob
output_meta_info.append(meta_info)
output_finished.append(req.finished)
......
......@@ -397,6 +397,7 @@ class ModelRunner:
out_cache_loc,
out_cache_cont_start,
out_cache_cont_end,
return_logprob,
):
input_metadata = InputMetadata.create(
self,
......@@ -409,10 +410,9 @@ class ModelRunner:
out_cache_loc=out_cache_loc,
out_cache_cont_start=out_cache_cont_start,
out_cache_cont_end=out_cache_cont_end,
return_logprob=return_logprob,
)
return self.model.forward(input_ids, input_metadata.positions, input_metadata)[
0
]
return self.model.forward(input_ids, input_metadata.positions, input_metadata)
@torch.inference_mode()
def forward_extend_multi_modal(
......@@ -460,8 +460,8 @@ class ModelRunner:
"prefix_lens": batch.prefix_lens,
"position_ids_offsets": batch.position_ids_offsets,
"out_cache_loc": batch.out_cache_loc,
"return_logprob": return_logprob,
}
kwargs["return_logprob"] = return_logprob
return self.forward_extend_multi_modal(**kwargs)
else:
kwargs = {
......@@ -471,6 +471,7 @@ class ModelRunner:
"prefix_lens": batch.prefix_lens,
"position_ids_offsets": batch.position_ids_offsets,
"out_cache_loc": batch.out_cache_loc,
"return_logprob": return_logprob,
}
if forward_mode == ForwardMode.DECODE:
......@@ -478,10 +479,8 @@ class ModelRunner:
kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
return self.forward_decode(**kwargs)
elif forward_mode == ForwardMode.EXTEND:
kwargs["return_logprob"] = return_logprob
return self.forward_extend(**kwargs)
elif forward_mode == ForwardMode.PREFILL:
kwargs["return_logprob"] = return_logprob
return self.forward_prefill(**kwargs)
else:
raise ValueError(f"Invalid forward mode: {forward_mode}")
......@@ -18,6 +18,7 @@ from sglang.srt.hf_transformers_utils import (
)
from sglang.srt.managers.io_struct import (
BatchStrOut,
DetokenizeReqInput,
FlushCacheReq,
GenerateReqInput,
TokenizedGenerateReqInput,
......@@ -234,6 +235,10 @@ class TokenizerManager:
yield output_list
async def detokenize(self, obj: DetokenizeReqInput):
token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
async def flush_cache(self):
flush_cache_req = FlushCacheReq()
self.send_to_router.send_pyobj(flush_cache_req)
......
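A hedged sketch of the `detokenize` path added above, assuming a Hugging Face tokenizer backend (as suggested by `get_tokenizer` in `sglang.srt.hf_transformers_utils`). `convert_ids_to_tokens` can return `bytes` for some tokenizers, hence the decode guard:

```python
# Standalone illustration; the "gpt2" tokenizer is an assumption, not the
# server's actual model tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
token_ids = tokenizer.encode("The capital of France is")
token_texts = tokenizer.convert_ids_to_tokens(token_ids)
print([t.decode() if isinstance(t, bytes) else t for t in token_texts])
# e.g. ['The', 'Ġcapital', 'Ġof', 'ĠFrance', 'Ġis']
```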
......@@ -30,7 +30,7 @@ from sglang.srt.conversation import (
)
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
from sglang.srt.managers.openai_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
......@@ -44,6 +44,7 @@ from sglang.srt.managers.openai_protocol import (
CompletionResponseStreamChoice,
CompletionStreamResponse,
DeltaMessage,
LogProbs,
UsageInfo,
)
from sglang.srt.managers.router.manager import start_router_process
......@@ -97,6 +98,23 @@ async def stream_generator(obj):
yield out
async def make_openai_style_logprobs(token_logprobs):
ret_logprobs = LogProbs()
# Detokenize
token_ids = [tid for tid, _ in token_logprobs]
token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
ret_logprobs.tokens.append(token_text)
ret_logprobs.token_logprobs.append(token_logprob)
# Not supported yet.
ret_logprobs.top_logprobs.append({})
ret_logprobs.text_offset.append(-1)
return ret_logprobs
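For context, the helper above fills four parallel lists on a `LogProbs` object, mirroring the shape of the OpenAI completions `logprobs` field. An illustrative rendering of the result (values are made up):

```python
example_logprobs = {
    "tokens": ["The", " capital", " of"],     # detokenized token texts
    "token_logprobs": [None, -1.31, -0.07],   # None for the first prompt token
    "top_logprobs": [{}, {}, {}],             # not supported yet, left empty
    "text_offset": [-1, -1, -1],              # not tracked yet, placeholder
}
```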
@app.post("/generate")
async def generate_request(obj: GenerateReqInput):
obj.post_init()
......@@ -132,6 +150,7 @@ async def v1_completions(raw_request: Request):
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
},
return_logprob=request.logprobs is not None,
stream=request.stream,
)
adapted_request.post_init()
......@@ -140,17 +159,34 @@ async def v1_completions(raw_request: Request):
async def generate_stream_resp():
stream_buffer = ""
n_prev_token = 0
async for content in stream_generator(adapted_request):
text = content["text"]
prompt_tokens = content["meta_info"]["prompt_tokens"]
completion_tokens = content["meta_info"]["completion_tokens"]
if not stream_buffer: # The first chunk
if request.echo:
# Prepend prompt in response text.
text = request.prompt + text
else:
# Skip prompt tokens if echo is disabled.
n_prev_token = prompt_tokens
if request.logprobs is not None:
logprobs = await make_openai_style_logprobs(
content["meta_info"]["token_logprob"][n_prev_token:]
)
n_prev_token = len(content["meta_info"]["token_logprob"])
else:
logprobs = None
delta = text[len(stream_buffer) :]
stream_buffer = text
stream_buffer = content["text"]
choice_data = CompletionResponseStreamChoice(
index=0,
text=delta,
logprobs=None,
logprobs=logprobs,
finish_reason=None,
)
chunk = CompletionStreamResponse(
......@@ -172,15 +208,28 @@ async def v1_completions(raw_request: Request):
# Non-streaming response.
ret = await generate_request(adapted_request)
prompt_tokens = ret["meta_info"]["prompt_tokens"]
completion_tokens = ret["meta_info"]["completion_tokens"]
text = ret["text"]
token_logprob_pos = prompt_tokens
if request.echo:
token_logprob_pos = 0
text = request.prompt + text
else:
token_logprob_pos = prompt_tokens
logprobs = (
await make_openai_style_logprobs(ret["meta_info"]["token_logprob"][token_logprob_pos:])
if request.logprobs is not None
else None
)
choice_data = CompletionResponseChoice(
index=0,
text=ret["text"],
logprobs=None,
text=text,
logprobs=logprobs,
finish_reason=None, # TODO(comaniac): Add finish reason.
)
prompt_tokens = ret["meta_info"]["prompt_tokens"]
completion_tokens = ret["meta_info"]["completion_tokens"]
response = CompletionResponse(
id=ret["meta_info"]["id"],
model=request.model,
......@@ -216,7 +265,9 @@ async def v1_chat_completions(raw_request: Request):
if not isinstance(m.content, str):
raise HTTPException(
status_code=503,
detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.",
detail="Structured content requests not supported with "
"HuggingFace Chat Templates. "
"Make sure the server specifies a sglang chat template.",
)
prompt = tokenizer_manager.tokenizer.apply_chat_template(
request.messages, tokenize=False, add_generation_prompt=True
......
......@@ -9,18 +9,10 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
"""
import argparse
import time
import requests
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
url = f"{args.host}:{args.port}"
def test_decode(url, return_logprob):
response = requests.post(
url + "/generate",
json={
......@@ -29,8 +21,19 @@ if __name__ == "__main__":
"temperature": 0,
"max_new_tokens": 32,
},
# "return_logprob": True,
# "logprob_start_len": 0,
"return_logprob": return_logprob,
"logprob_start_len": 0,
},
)
print(response.json())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
url = f"{args.host}:{args.port}"
test_decode(url, False)
test_decode(url, True)
......@@ -9,27 +9,20 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
import argparse
import json
import time
import requests
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
url = f"{args.host}:{args.port}"
def test_decode_stream(url, return_logprob):
response = requests.post(
url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 512,
"max_new_tokens": 128,
},
"stream": True,
"return_logprob": return_logprob,
},
stream=True,
)
......@@ -41,7 +34,29 @@ if __name__ == "__main__":
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
output = data["text"].strip()
print(output[prev:], end="", flush=True)
prev = len(output)
if return_logprob:
assert data["meta_info"]["prompt_logprob"] is not None
assert data["meta_info"]["token_logprob"] is not None
assert data["meta_info"]["normalized_prompt_logprob"] is not None
if prev == 0: # Skip prompt logprobs
prev = data["meta_info"]["prompt_tokens"]
for token_txt, _, logprob in data["meta_info"]["token_logprob"][prev:]:
print(f"{token_txt}\t{logprob}", flush=True)
prev = len(data["meta_info"]["token_logprob"])
else:
output = data["text"].strip()
print(output[prev:], end="", flush=True)
prev = len(output)
print("")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
url = f"{args.host}:{args.port}"
test_decode_stream(url, False)
test_decode_stream(url, True)
......@@ -18,15 +18,26 @@ import argparse
import openai
def test_completion(args):
def test_completion(args, echo, logprobs):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.completions.create(
model="default",
prompt="The capital of France is",
temperature=0,
max_tokens=32,
echo=echo,
logprobs=logprobs,
)
text = response.choices[0].text
print(response.choices[0].text)
if echo:
assert text.startswith("The capital of France is")
if logprobs:
assert response.choices[0].logprobs
if echo:
assert response.choices[0].logprobs.token_logprobs[0] is None
else:
assert response.choices[0].logprobs.token_logprobs[0] is not None
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
......@@ -34,7 +45,7 @@ def test_completion(args):
assert response.usage.total_tokens > 0
def test_completion_stream(args):
def test_completion_stream(args, echo, logprobs):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.completions.create(
model="default",
......@@ -42,9 +53,23 @@ def test_completion_stream(args):
temperature=0,
max_tokens=32,
stream=True,
echo=echo,
logprobs=logprobs,
)
first = True
for r in response:
print(r.choices[0].text, end="", flush=True)
if first:
if echo:
assert r.choices[0].text.startswith("The capital of France is")
first = False
if logprobs:
print(
f"{r.choices[0].text:12s}\t"
f"{r.choices[0].logprobs.token_logprobs}",
flush=True
)
else:
print(r.choices[0].text, end="", flush=True)
assert r.id
assert r.usage.prompt_tokens > 0
assert r.usage.completion_tokens > 0
......@@ -135,8 +160,14 @@ if __name__ == "__main__":
)
args = parser.parse_args()
test_completion(args)
test_completion_stream(args)
test_completion(args, echo=False, logprobs=False)
test_completion(args, echo=True, logprobs=False)
test_completion(args, echo=False, logprobs=True)
test_completion(args, echo=True, logprobs=True)
test_completion_stream(args, echo=False, logprobs=False)
test_completion_stream(args, echo=True, logprobs=False)
test_completion_stream(args, echo=False, logprobs=True)
test_completion_stream(args, echo=True, logprobs=True)
test_chat_completion(args)
test_chat_completion_stream(args)
if args.test_image:
......