"sgl-kernel/python/vscode:/vscode.git/clone" did not exist on "58f10679e1850fdc86046057c23bac5193156de9"
Unverified Commit 303ef888 authored by Lianmin Zheng, committed by GitHub

Clean up logits processor (#558)

parent 92cb93f3
"""Logits processing."""
import dataclasses
from typing import List
import torch
from torch import nn
from vllm.distributed import (
@@ -10,6 +13,24 @@ from vllm.distributed import (
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
@dataclasses.dataclass
class LogitProcessorOutput:
# The logits of the next tokens. shape: [#seq, vocab_size]
next_token_logits: torch.Tensor
# The logprobs of the next tokens. shape: [#seq, vocab_size]
next_token_logprobs: torch.Tensor
    # The normalized logprobs of prompts. shape: [#seq]
normalized_prompt_logprobs: torch.Tensor
# The logprobs of prefill tokens. shape: [#token, vocab_size]
prefill_token_logprobs: torch.Tensor
# The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
prefill_top_logprobs: List
# The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
decode_top_logprobs: List
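
# Illustrative sketch, not part of this commit: the dataclass above lets callers
# read results by field name instead of unpacking a five-element positional
# tuple. `_example_read_output` is a hypothetical helper, not used anywhere.
def _example_read_output(output: LogitProcessorOutput):
    sampled_from = output.next_token_logits       # [#seq, vocab_size]
    chosen_logprobs = output.next_token_logprobs  # may be None if not requested
    return sampled_from, chosen_logprobs
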
class LogitsProcessor(nn.Module):
def __init__(self, config):
super().__init__()
@@ -39,6 +60,7 @@ class LogitsProcessor(nn.Module):
return normalized_prompt_logprobs
def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
# TODO: vectorize the code below
if input_metadata.forward_mode == ForwardMode.DECODE:
decode_top_logprobs = []
for i in range(all_logprobs.shape[0]):
@@ -51,7 +73,6 @@
else:
prefill_top_logprobs, decode_top_logprobs = [], []
pt = 0
# NOTE: the GPU-CPU overhead can be reduced
extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
if extend_seq_len == 0:
@@ -71,18 +92,15 @@
return prefill_top_logprobs, decode_top_logprobs
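
    # Illustrative sketch, not part of this commit: if every sequence requested
    # the same k, the per-row loops above could in principle be replaced by one
    # batched torch.topk call. `_example_batched_topk` is a hypothetical helper,
    # not used anywhere.
    def _example_batched_topk(all_logprobs: torch.Tensor, k: int):
        # all_logprobs: [#seq, vocab_size] log-probabilities
        values, indices = torch.topk(all_logprobs, k, dim=-1)
        # One list of (logprob, token_id) pairs per sequence, mirroring the
        # format of decode_top_logprobs returned above.
        return [list(zip(v.tolist(), i.tolist())) for v, i in zip(values, indices)]
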
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
        # Get the last hidden states and last logits for the next token prediction
        if input_metadata.forward_mode == ForwardMode.DECODE:
            last_index = None
            last_hidden = hidden_states
        else:
            last_index = (
                torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                - 1
            )
            last_hidden = hidden_states[last_index]
last_logits = torch.matmul(last_hidden, weight.T)
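        # Illustrative note, not part of this commit: with extend_seq_lens =
        # [3, 2, 4], the packed hidden_states hold 3 + 2 + 4 = 9 token positions
        # and torch.cumsum(...) - 1 = [2, 4, 8] selects the last position of
        # each request, so last_hidden has shape [3, hidden_size].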
@@ -92,8 +110,14 @@
# Return only last_logits if logprob is not requested
        if not input_metadata.return_logprob:
            hidden_states = None
            return LogitProcessorOutput(
                next_token_logits=last_logits,
                next_token_logprobs=None,
                normalized_prompt_logprobs=None,
                prefill_token_logprobs=None,
                prefill_top_logprobs=None,
                decode_top_logprobs=None,
            )
else:
# When logprob is requested, compute the logits for all tokens.
if input_metadata.forward_mode == ForwardMode.DECODE:
@@ -108,6 +132,7 @@
del all_logits
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
# Get the logprob of top-k tokens
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
if return_top_logprob:
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
@@ -117,16 +142,15 @@
prefill_top_logprobs = decode_top_logprobs = None
            if input_metadata.forward_mode == ForwardMode.DECODE:
                return LogitProcessorOutput(
                    next_token_logits=last_logits,
                    next_token_logprobs=all_logprobs,
                    normalized_prompt_logprobs=None,
                    prefill_token_logprobs=None,
                    prefill_top_logprobs=None,
                    decode_top_logprobs=decode_top_logprobs,
                )
else:
# Compute the logprobs for the last token of each request.
last_logprobs = all_logprobs[last_index]
# Compute the logprobs and normalized logprobs for the prefill tokens.
@@ -139,12 +163,14 @@
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
prefill_token_logprobs, input_metadata
)
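                # Illustrative note, not part of this commit: the normalized value
                # is a length-normalized prompt score, roughly the average logprob
                # per scored prompt token, e.g. token logprobs [-0.5, -1.2, -0.3]
                # give (-0.5 - 1.2 - 0.3) / 3 = -0.667.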
                return LogitProcessorOutput(
                    next_token_logits=last_logits,
                    next_token_logprobs=last_logprobs,
                    normalized_prompt_logprobs=normalized_prompt_logprobs,
                    prefill_token_logprobs=prefill_token_logprobs,
                    prefill_top_logprobs=prefill_top_logprobs,
                    decode_top_logprobs=decode_top_logprobs,
                )
@@ -441,33 +441,25 @@ class ModelTpServer:
self.model_config.vocab_size, self.int_token_logit_bias
)
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
            next_token_ids, _ = batch.sample(output.next_token_logits)

            # Move logprobs to cpu
            if output.next_token_logprobs is not None:
                output.next_token_logprobs = output.next_token_logprobs[
                    torch.arange(len(next_token_ids), device=next_token_ids.device),
                    next_token_ids,
                ].tolist()
                output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
                output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist()

            next_token_ids = next_token_ids.tolist()
else:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
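
        # Illustrative note, not part of this commit: the indexing above with
        # torch.arange keeps only one logprob per sequence (the logprob of the
        # sampled token) before .tolist(), e.g. for a [4, vocab_size] tensor and
        # sampled ids [1, 0, 7, 3] it transfers 4 floats to the CPU instead of
        # the full [4, vocab_size] tensor.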
        # Check finish conditions
pt = 0
for i, req in enumerate(batch.reqs):
req.completion_tokens_wo_jump_forward += 1
@@ -475,57 +467,59 @@ class ModelTpServer:
req.check_finished()
if req.return_logprob:
                self.add_logprob_return_values(i, req, pt, next_token_ids, output)
                pt += req.extend_input_len

        self.handle_finished_requests(batch)

    def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

if req.prefill_token_logprobs is None:
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
req.prefill_token_logprobs = list(
zip(
output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
req.input_ids[-req.extend_input_len + 1 :],
)
)
if req.logprob_start_len == 0:
req.prefill_token_logprobs = [
(None, req.input_ids[0])
] + req.prefill_token_logprobs
        if req.last_update_decode_tokens != 0:
            req.decode_token_logprobs.extend(
                list(
                    zip(
                        output.prefill_token_logprobs[
                            pt
                            + req.extend_input_len
                            - req.last_update_decode_tokens : pt
                            + req.extend_input_len
                            - 1
                        ],
                        req.input_ids[-req.last_update_decode_tokens + 1 :],
                    )
                )
            )
        req.decode_token_logprobs.append(
            (output.next_token_logprobs[i], next_token_ids[i])
        )
if req.top_logprobs_num > 0:
if req.prefill_top_logprobs is None:
req.prefill_top_logprobs = output.prefill_top_logprobs[i]
if req.logprob_start_len == 0:
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
if req.last_update_decode_tokens != 0:
req.decode_top_logprobs.extend(
output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
)
req.decode_top_logprobs.append(output.decode_top_logprobs[i])
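
    # Illustrative note, not part of this commit: a prompt yields one logprob per
    # token after the first (each token is scored given its prefix), so for
    # input_ids = [11, 22, 33, 44] and prefill logprobs [-0.5, -1.2, -0.3] the
    # pairing above gives [(None, 11), (-0.5, 22), (-1.2, 33), (-0.3, 44)]
    # when logprob_start_len == 0.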
def cache_filled_batch(self, batch: Batch):
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
@@ -540,7 +534,7 @@ class ModelTpServer:
req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
def forward_decode_batch(self, batch: Batch):
        # Check if decode out of memory
if not batch.check_decode_mem():
old_ratio = self.new_token_ratio
self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
@@ -559,9 +553,8 @@ class ModelTpServer:
)
if not self.disable_regex_jump_forward:
            # Check for jump-forward
jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
self.forward_queue.extend(jump_forward_reqs)
if batch.is_empty():
return
@@ -570,23 +563,19 @@ class ModelTpServer:
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
batch.prepare_for_decode()
        # Forward and sample the next tokens
        output = self.model_runner.forward(batch, ForwardMode.DECODE)
        next_token_ids, _ = batch.sample(output.next_token_logits)

        # Move logprobs to cpu
        if output.next_token_logprobs is not None:
            next_token_logprobs = output.next_token_logprobs[
                torch.arange(len(next_token_ids), device=next_token_ids.device),
                next_token_ids,
            ].tolist()

        next_token_ids = next_token_ids.tolist()
# Check finish condition
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
req.completion_tokens_wo_jump_forward += 1
@@ -594,10 +583,9 @@ class ModelTpServer:
req.check_finished()
if req.return_logprob:
                req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id))
                if req.top_logprobs_num > 0:
                    req.decode_top_logprobs.append(output.decode_top_logprobs[i])
self.handle_finished_requests(batch)
@@ -253,7 +253,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
try:
requests.get(url + "/get_model_info", timeout=5, headers=headers)
break
        except requests.exceptions.RequestException:
pass
# Send a warmup request
@@ -265,14 +265,14 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
"text": "The capital city of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
"max_new_tokens": 8,
},
},
headers=headers,
timeout=600,
)
assert res.status_code == 200
    except Exception as e:
if pipe_finish_writer is not None:
pipe_finish_writer.send(get_exception_traceback())
print(f"Initialization failed. warmup error: {e}")