Unverified Commit 303ef888 authored by Lianmin Zheng, committed by GitHub

Clean up logits processor (#558)

parent 92cb93f3
"""Logits processing.""" """Logits processing."""
import dataclasses
from typing import List
import torch import torch
from torch import nn from torch import nn
from vllm.distributed import ( from vllm.distributed import (
...@@ -10,6 +13,24 @@ from vllm.distributed import ( ...@@ -10,6 +13,24 @@ from vllm.distributed import (
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
+@dataclasses.dataclass
+class LogitProcessorOutput:
+    # The logits of the next tokens. shape: [#seq, vocab_size]
+    next_token_logits: torch.Tensor
+    # The logprobs of the next tokens. shape: [#seq, vocab_size]
+    next_token_logprobs: torch.Tensor
+    # The normalized logprobs of prompts. shape: [#seq]
+    normalized_prompt_logprobs: torch.Tensor
+    # The logprobs of prefill tokens. shape: [#token, vocab_size]
+    prefill_token_logprobs: torch.Tensor
+    # The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    prefill_top_logprobs: List
+    # The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    decode_top_logprobs: List
+
+
 class LogitsProcessor(nn.Module):
     def __init__(self, config):
         super().__init__()
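Commentary, not part of the diff: a minimal sketch of what the new dataclass buys callers, using a trimmed-down copy of LogitProcessorOutput with only two of its fields and made-up tensor shapes. Fields are read by name instead of by tuple position, which is what the ModelTpServer changes further down rely on.

    import dataclasses
    import torch

    @dataclasses.dataclass
    class LogitProcessorOutputSketch:  # trimmed-down illustration, not the real class
        next_token_logits: torch.Tensor      # [#seq, vocab_size]
        next_token_logprobs: torch.Tensor    # [#seq, vocab_size]

    logits = torch.randn(2, 8)  # 2 sequences, toy vocab of 8
    out = LogitProcessorOutputSketch(
        next_token_logits=logits,
        next_token_logprobs=torch.log_softmax(logits, dim=-1),
    )
    # Old style: logits, (_, _, _, _, last_logprobs) = ...   (positional, easy to misorder)
    # New style: out.next_token_logits, out.next_token_logprobs (named, self-documenting)
    print(out.next_token_logprobs.shape)  # torch.Size([2, 8])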
@@ -39,6 +60,7 @@ class LogitsProcessor(nn.Module):
         return normalized_prompt_logprobs
 
     def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
+        # TODO: vectorize the code below
         if input_metadata.forward_mode == ForwardMode.DECODE:
             decode_top_logprobs = []
             for i in range(all_logprobs.shape[0]):
@@ -51,7 +73,6 @@ class LogitsProcessor(nn.Module):
         else:
             prefill_top_logprobs, decode_top_logprobs = [], []
             pt = 0
-            # NOTE: the GPU-CPU overhead can be reduced
             extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                 if extend_seq_len == 0:
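Commentary, not part of the diff: the new "TODO: vectorize the code below" presumably means replacing the per-row Python loop with one batched torch.topk call. A hedged sketch of one way that could look, assuming top_logprobs_nums holds each request's k (the variable k still needs a Python-level slice at the end):

    import torch

    def batched_top_logprobs(all_logprobs: torch.Tensor, top_logprobs_nums: list):
        # One topk over the whole batch at the largest requested k,
        # then slice each row down to its own k on the CPU side.
        max_k = max(top_logprobs_nums)
        values, indices = all_logprobs.topk(max_k, dim=-1)   # [#seq, max_k]
        values, indices = values.tolist(), indices.tolist()
        return [
            list(zip(values[i][:k], indices[i][:k]))
            for i, k in enumerate(top_logprobs_nums)
        ]

    # Toy usage: 2 rows over a vocab of 5, asking for k=2 and k=3.
    logprobs = torch.log_softmax(torch.randn(2, 5), dim=-1)
    print(batched_top_logprobs(logprobs, [2, 3]))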
@@ -71,18 +92,15 @@ class LogitsProcessor(nn.Module):
             return prefill_top_logprobs, decode_top_logprobs
 
     def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
-        # Get last index for next token prediction, except for DECODE mode.
-        last_index = None
-        if input_metadata.forward_mode != ForwardMode.DECODE:
+        # Get the last hidden states and last logits for the next token prediction
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            last_index = None
+            last_hidden = hidden_states
+        else:
             last_index = (
                 torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                 - 1
             )
-
-        # Get the last hidden states and last logits
-        if input_metadata.forward_mode == ForwardMode.DECODE:
-            last_hidden = hidden_states
-        else:
             last_hidden = hidden_states[last_index]
 
         last_logits = torch.matmul(last_hidden, weight.T)
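Commentary, not part of the diff: in EXTEND mode the hidden states of all requests are packed into one flat [#token, hidden] tensor, so `cumsum(extend_seq_lens) - 1` gives the flat index of each request's last token. A small worked example with made-up lengths:

    import torch

    extend_seq_lens = torch.tensor([3, 1, 4])   # three requests
    last_index = torch.cumsum(extend_seq_lens, dim=0, dtype=torch.long) - 1
    print(last_index)                           # tensor([2, 3, 7])

    # Packed hidden states: 3 + 1 + 4 = 8 token positions, hidden size 2 (toy).
    hidden_states = torch.arange(16, dtype=torch.float32).reshape(8, 2)
    last_hidden = hidden_states[last_index]     # [#seq, hidden]
    print(last_hidden.shape)                    # torch.Size([3, 2])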
@@ -92,8 +110,14 @@ class LogitsProcessor(nn.Module):
         # Return only last_logits if logprob is not requested
         if not input_metadata.return_logprob:
-            hidden_states = None
-            return last_logits, (None, None, None, None, None)
+            return LogitProcessorOutput(
+                next_token_logits=last_logits,
+                next_token_logprobs=None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=None,
+            )
         else:
             # When logprob is requested, compute the logits for all tokens.
             if input_metadata.forward_mode == ForwardMode.DECODE:
@@ -108,6 +132,7 @@ class LogitsProcessor(nn.Module):
             del all_logits
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
 
+            # Get the logprob of top-k tokens
             return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
             if return_top_logprob:
                 prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
@@ -117,16 +142,15 @@ class LogitsProcessor(nn.Module):
                 prefill_top_logprobs = decode_top_logprobs = None
 
             if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_logprobs = all_logprobs
-                return last_logits, (
-                    None,
-                    None,
-                    None,
-                    decode_top_logprobs,
-                    last_logprobs,
-                )
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=all_logprobs,
+                    normalized_prompt_logprobs=None,
+                    prefill_token_logprobs=None,
+                    prefill_top_logprobs=None,
+                    decode_top_logprobs=decode_top_logprobs,
+                )
             else:
-                # Compute the logprobs for the last token of each request.
                 last_logprobs = all_logprobs[last_index]
 
                 # Compute the logprobs and normalized logprobs for the prefill tokens.
@@ -139,12 +163,14 @@ class LogitsProcessor(nn.Module):
                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
                     prefill_token_logprobs, input_metadata
                 )
-                return last_logits, (
-                    prefill_token_logprobs,
-                    normalized_prompt_logprobs,
-                    prefill_top_logprobs,
-                    decode_top_logprobs,
-                    last_logprobs,
-                )
+
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=last_logprobs,
+                    normalized_prompt_logprobs=normalized_prompt_logprobs,
+                    prefill_token_logprobs=prefill_token_logprobs,
+                    prefill_top_logprobs=prefill_top_logprobs,
+                    decode_top_logprobs=decode_top_logprobs,
+                )
...
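Commentary, not part of the diff: every logprob field above comes from the `log_softmax` call earlier in this file, so a token's logprob is just an index into the normalized row. A tiny numeric check with made-up logits:

    import torch

    logits = torch.tensor([[2.0, 0.5, -1.0]])               # [#seq=1, vocab=3], made up
    logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
    print(logprobs)                                          # approx tensor([[-0.2414, -1.7414, -3.2414]])
    print(logprobs.exp().sum(dim=-1))                        # tensor([1.]) -- a valid distribution
    print(logprobs[0, 0])                                    # logprob of token id 0 in sequence 0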
@@ -441,33 +441,25 @@ class ModelTpServer:
             self.model_config.vocab_size, self.int_token_logit_bias
         )
 
+        # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
-            # Forward
-            logits, (
-                prefill_token_logprobs,
-                normalized_prompt_logprobs,
-                prefill_top_logprobs,
-                decode_top_logprobs,
-                last_logprobs,
-            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            if prefill_token_logprobs is not None:
-                prefill_token_logprobs = prefill_token_logprobs.tolist()
-                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
-
-            next_token_ids, _ = batch.sample(logits)
-
-            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
-            if last_logprobs is not None:
-                last_token_logprobs = last_logprobs[
-                    torch.arange(len(batch.reqs), device=next_token_ids.device),
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            next_token_ids, _ = batch.sample(output.next_token_logits)
+
+            # Move logprobs to cpu
+            if output.next_token_logprobs is not None:
+                output.next_token_logprobs = output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=next_token_ids.device),
                     next_token_ids,
                 ].tolist()
+                output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
+                output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist()
 
             next_token_ids = next_token_ids.tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
 
-        # Check finish condition
+        # Check finish conditions
         pt = 0
         for i, req in enumerate(batch.reqs):
             req.completion_tokens_wo_jump_forward += 1
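Commentary, not part of the diff: the "Move logprobs to cpu" block transfers only the logprob of the token that was actually sampled for each sequence, using advanced indexing with `torch.arange`. A toy version of that gather:

    import torch

    next_token_logprobs = torch.log_softmax(torch.randn(3, 5), dim=-1)  # [#seq=3, vocab=5]
    next_token_ids = torch.tensor([4, 0, 2])                            # one sampled id per sequence

    selected = next_token_logprobs[
        torch.arange(len(next_token_ids), device=next_token_ids.device),
        next_token_ids,
    ].tolist()
    print(selected)  # 3 Python floats, one per sequence, instead of a 3x5 tensor transfer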
@@ -475,57 +467,59 @@ class ModelTpServer:
             req.check_finished()
 
             if req.return_logprob:
-                if req.normalized_prompt_logprob is None:
-                    req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
-
-                if req.prefill_token_logprobs is None:
-                    # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-                    req.prefill_token_logprobs = list(
-                        zip(
-                            prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
-                            req.input_ids[-req.extend_input_len + 1 :],
-                        )
-                    )
-                    if req.logprob_start_len == 0:
-                        req.prefill_token_logprobs = [
-                            (None, req.input_ids[0])
-                        ] + req.prefill_token_logprobs
-
-                if req.last_update_decode_tokens != 0:
-                    req.decode_token_logprobs.extend(
-                        list(
-                            zip(
-                                prefill_token_logprobs[
-                                    pt
-                                    + req.extend_input_len
-                                    - req.last_update_decode_tokens : pt
-                                    + req.extend_input_len
-                                    - 1
-                                ],
-                                req.input_ids[-req.last_update_decode_tokens + 1 :],
-                            )
-                        )
-                    )
-
-                req.decode_token_logprobs.append(
-                    (last_token_logprobs[i], next_token_ids[i])
-                )
-
-                if req.top_logprobs_num > 0:
-                    if req.prefill_top_logprobs is None:
-                        req.prefill_top_logprobs = prefill_top_logprobs[i]
-                        if req.logprob_start_len == 0:
-                            req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
-
-                    if req.last_update_decode_tokens != 0:
-                        req.decode_top_logprobs.extend(
-                            prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
-                        )
-                    req.decode_top_logprobs.append(decode_top_logprobs[i])
-
-                pt += req.extend_input_len
+                self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+                pt += req.extend_input_len
 
         self.handle_finished_requests(batch)
 
+    def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
+        if req.normalized_prompt_logprob is None:
+            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
+
+        if req.prefill_token_logprobs is None:
+            # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+            req.prefill_token_logprobs = list(
+                zip(
+                    output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                    req.input_ids[-req.extend_input_len + 1 :],
+                )
+            )
+            if req.logprob_start_len == 0:
+                req.prefill_token_logprobs = [
+                    (None, req.input_ids[0])
+                ] + req.prefill_token_logprobs
+
+        if req.last_update_decode_tokens != 0:
+            req.decode_token_logprobs.extend(
+                list(
+                    zip(
+                        output.prefill_token_logprobs[
+                            pt
+                            + req.extend_input_len
+                            - req.last_update_decode_tokens : pt
+                            + req.extend_input_len
+                            - 1
+                        ],
+                        req.input_ids[-req.last_update_decode_tokens + 1 :],
+                    )
+                )
+            )
+
+        req.decode_token_logprobs.append(
+            (output.next_token_logprobs[i], next_token_ids[i])
+        )
+
+        if req.top_logprobs_num > 0:
+            if req.prefill_top_logprobs is None:
+                req.prefill_top_logprobs = output.prefill_top_logprobs[i]
+                if req.logprob_start_len == 0:
+                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+
+            if req.last_update_decode_tokens != 0:
+                req.decode_top_logprobs.extend(
+                    output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                )
+            req.decode_top_logprobs.append(output.decode_top_logprobs[i])
 
     def cache_filled_batch(self, batch: Batch):
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
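Commentary, not part of the diff: in add_logprob_return_values above, `output.prefill_token_logprobs` is flat across the whole batch, so `pt` walks through it request by request and the `[pt : pt + extend_input_len - 1]` slice pairs the logprob at position j with the token at position j + 1 (a token's logprob is predicted from the previous position). A sketch with made-up numbers, assuming the flat buffer holds extend_input_len entries per request, as the `pt += req.extend_input_len` stride implies:

    # Two requests with prompt lengths 3 and 2, so the flat buffer has 3 + 2 = 5 entries.
    prefill_token_logprobs = [-0.1, -0.2, -0.3, -0.4, -0.5]   # made-up values
    reqs = [
        {"input_ids": [11, 12, 13], "extend_input_len": 3},
        {"input_ids": [21, 22], "extend_input_len": 2},
    ]

    pt = 0
    for req in reqs:
        n = req["extend_input_len"]
        # Logprob of input_ids[1:] given the preceding tokens; the first token has no logprob.
        pairs = list(zip(prefill_token_logprobs[pt : pt + n - 1], req["input_ids"][1:]))
        pairs = [(None, req["input_ids"][0])] + pairs          # mirrors the logprob_start_len == 0 case
        print(pairs)   # [(None, 11), (-0.1, 12), (-0.2, 13)] then [(None, 21), (-0.4, 22)]
        pt += n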
@@ -540,7 +534,7 @@ class ModelTpServer:
             req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
 
     def forward_decode_batch(self, batch: Batch):
-        # check if decode out of memory
+        # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
             self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
@@ -559,9 +553,8 @@ class ModelTpServer:
             )
 
         if not self.disable_regex_jump_forward:
-            # check for jump-forward
+            # Check for jump-forward
             jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
             self.forward_queue.extend(jump_forward_reqs)
             if batch.is_empty():
                 return
@@ -570,23 +563,19 @@ class ModelTpServer:
         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         batch.prepare_for_decode()
 
-        # Forward
-        logits, (
-            _,
-            _,
-            _,
-            decode_top_logprobs,
-            last_logprobs,
-        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, _ = batch.sample(logits)
-        next_token_ids = next_token_ids.tolist()
+        # Forward and sample the next tokens
+        output = self.model_runner.forward(batch, ForwardMode.DECODE)
+        next_token_ids, _ = batch.sample(output.next_token_logits)
 
-        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        if last_logprobs is not None:
-            new_token_logprobs = last_logprobs[
-                torch.arange(len(batch.reqs)), next_token_ids
+        # Move logprobs to cpu
+        if output.next_token_logprobs is not None:
+            next_token_logprobs = output.next_token_logprobs[
+                torch.arange(len(next_token_ids), device=next_token_ids.device),
+                next_token_ids,
             ].tolist()
 
+        next_token_ids = next_token_ids.tolist()
+
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req.completion_tokens_wo_jump_forward += 1
@@ -594,10 +583,9 @@ class ModelTpServer:
             req.check_finished()
 
             if req.return_logprob:
-                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
-
-            if req.top_logprobs_num > 0:
-                req.decode_top_logprobs.append(decode_top_logprobs[i])
+                req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id))
+                if req.top_logprobs_num > 0:
+                    req.decode_top_logprobs.append(output.decode_top_logprobs[i])
 
         self.handle_finished_requests(batch)
...
@@ -253,7 +253,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         try:
             requests.get(url + "/get_model_info", timeout=5, headers=headers)
             break
-        except requests.exceptions.RequestException as e:
+        except requests.exceptions.RequestException:
             pass
 
     # Send a warmup request
@@ -265,14 +265,14 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
                 "text": "The capital city of France is",
                 "sampling_params": {
                     "temperature": 0,
-                    "max_new_tokens": 16,
+                    "max_new_tokens": 8,
                 },
             },
            headers=headers,
            timeout=600,
        )
        assert res.status_code == 200
-    except Exception:
+    except Exception as e:
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(get_exception_traceback())
        print(f"Initialization failed. warmup error: {e}")
...
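Commentary, not part of the diff: the changed lines above belong to the server's readiness poll and warmup request. A minimal, self-contained version of that pattern (the URL, headers, retry count, and sleep interval are illustrative, not taken from the diff):

    import time
    import requests

    url = "http://127.0.0.1:30000"   # illustrative address
    headers = {}

    # Poll until the server answers, swallowing connection errors while it boots.
    for _ in range(120):
        try:
            requests.get(url + "/get_model_info", timeout=5, headers=headers)
            break
        except requests.exceptions.RequestException:
            time.sleep(0.5)

    # One tiny generation as a warmup; a small max_new_tokens keeps startup fast.
    res = requests.post(
        url + "/generate",
        json={
            "text": "The capital city of France is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        },
        headers=headers,
        timeout=600,
    )
    assert res.status_code == 200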