Unverified Commit 23196d52 authored by Lianmin Zheng, committed by GitHub

Simplify logits processor (#2974)


Co-authored-by: SangBin Cho <rkooo567@gmail.com>
parent 93b77c8e
@@ -14,6 +14,7 @@
 """Logits processing."""
 
 import dataclasses
+import logging
 from typing import List, Optional, Union
 
 import torch
@@ -32,6 +33,8 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 
+logger = logging.getLogger(__name__)
+
 
 @dataclasses.dataclass
 class LogitsProcessorOutput:
@@ -136,50 +139,61 @@ class LogitsProcessor(nn.Module):
             logits_metadata.forward_mode.is_decode_or_idle()
             or logits_metadata.forward_mode.is_target_verify()
         ):
-            last_index = None
-            last_hidden = hidden_states
-        else:
+            pruned_states = hidden_states
+            sample_indices = None
+        elif (
+            logits_metadata.forward_mode.is_extend()
+            and not logits_metadata.extend_return_logprob
+        ):
+            # Prefill without input logprobs.
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
-            last_hidden = hidden_states[last_index]
-
-        # Compute logits
-        last_logits = self._get_logits(last_hidden, lm_head)
+            pruned_states = hidden_states[last_index]
+            sample_indices = None
+        else:
+            # Slice the requested tokens to compute logprob
+            sample_index_pt = -1
+            sample_indices = []
+            pt, pruned_states, pruned_input_ids = 0, [], []
+            for start_len, extend_len in zip(
+                logits_metadata.extend_logprob_start_lens_cpu,
+                logits_metadata.extend_seq_lens_cpu,
+            ):
+                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+                sample_index_pt += extend_len - start_len
+                sample_indices.append(sample_index_pt)
+                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
+                pt += extend_len
+            pruned_states = torch.cat(pruned_states)
+
+        # Compute logits for both input and sampled tokens.
+        logits = self._get_logits(pruned_states, lm_head, logits_metadata)
+        sampled_logits = (
+            logits[sample_indices] if sample_indices is not None else logits
+        )
 
         if (
             not logits_metadata.extend_return_logprob
             or logits_metadata.capture_hidden_mode.need_capture()
         ):
             # Decode mode or extend mode without return_logprob.
             return LogitsProcessorOutput(
-                next_token_logits=last_logits,
+                next_token_logits=sampled_logits,
                 hidden_states=(
                     hidden_states
                     if logits_metadata.capture_hidden_mode.is_full()
                     else (
-                        last_hidden
+                        pruned_states
                         if logits_metadata.capture_hidden_mode.is_last()
                         else None
                     )
                 ),
             )
         else:
-            # Slice the requested tokens to compute logprob
-            pt, pruned_states, pruned_input_ids = 0, [], []
-            for start_len, extend_len in zip(
-                logits_metadata.extend_logprob_start_lens_cpu,
-                logits_metadata.extend_seq_lens_cpu,
-            ):
-                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
-                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
-                pt += extend_len
-
-            # Compute the logits of all required tokens
-            pruned_states = torch.cat(pruned_states)
-            del hidden_states
-            input_token_logits = self._get_logits(pruned_states, lm_head)
-            del pruned_states
+            input_logprobs = logits
+            del hidden_states, logits
 
             # Normalize the logprob w/o temperature, top-p
-            input_logprobs = input_token_logits
             input_logprobs = self.compute_temp_top_p_normalized_logprobs(
                 input_logprobs, logits_metadata
             )
@@ -194,17 +208,17 @@ class LogitsProcessor(nn.Module):
                 input_top_logprobs_val = input_top_logprobs_idx = None
 
             input_token_logprobs = input_logprobs[
-                torch.arange(input_logprobs.shape[0], device="cuda"),
+                torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
                 torch.cat(
                     [
                         torch.cat(pruned_input_ids)[1:],
-                        torch.tensor([0], device="cuda"),
+                        torch.tensor([0], device=input_logprobs.device),
                     ]
                 ),
             ]
 
             return LogitsProcessorOutput(
-                next_token_logits=last_logits,
+                next_token_logits=sampled_logits,
                 input_token_logprobs=input_token_logprobs,
                 input_top_logprobs_val=input_top_logprobs_val,
                 input_top_logprobs_idx=input_top_logprobs_idx,
@@ -214,8 +228,11 @@ class LogitsProcessor(nn.Module):
         self,
         hidden_states: torch.Tensor,
         lm_head: VocabParallelEmbedding,
+        logits_metadata: LogitsMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        """Get logits from hidden_states."""
+
         if hasattr(lm_head, "weight"):
             logits = torch.matmul(hidden_states, lm_head.weight.T)
         else:
...
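
The core of this change can be exercised outside of sglang. Below is a minimal, self-contained sketch (toy tensor sizes, made-up request lengths, and a plain matmul standing in for `_get_logits`; not the actual sglang implementation) of the new pruning logic: hidden states are sliced once per request, the position of each request's last kept token inside the pruned tensor is recorded in `sample_indices`, and a single logits computation then serves both the input-logprob path and next-token sampling.

    import torch

    hidden_size, vocab_size = 8, 32
    lm_head_weight = torch.randn(vocab_size, hidden_size)

    # Two requests packed back to back: extend lengths 5 and 3.
    extend_seq_lens_cpu = [5, 3]
    # Input logprobs are requested starting at these offsets within each request.
    extend_logprob_start_lens_cpu = [2, 0]
    total_len = sum(extend_seq_lens_cpu)
    hidden_states = torch.randn(total_len, hidden_size)
    input_ids = torch.randint(0, vocab_size, (total_len,))

    sample_index_pt = -1
    sample_indices = []
    pt, pruned_states, pruned_input_ids = 0, [], []
    for start_len, extend_len in zip(
        extend_logprob_start_lens_cpu, extend_seq_lens_cpu
    ):
        # Keep only the suffix of each request that needs logprobs.
        pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
        # The last kept row of this request is the one to sample from.
        sample_index_pt += extend_len - start_len
        sample_indices.append(sample_index_pt)
        pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
        pt += extend_len
    pruned_states = torch.cat(pruned_states)

    # One logits computation now covers both input tokens and sampled tokens,
    # replacing the separate last_hidden / pruned_states passes of the old code.
    logits = pruned_states @ lm_head_weight.T
    sampled_logits = logits[sample_indices]
    assert pruned_states.shape[0] == (5 - 2) + (3 - 0)
    assert sampled_logits.shape == (2, vocab_size)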
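The input-token logprob gather also becomes device-agnostic in this commit: the index tensors are created on `input_logprobs.device` instead of a hard-coded `"cuda"`, so the same code runs on CPU or other backends. Here is a small sketch with toy values (assumed shapes, not the real call site). The logprob for the token at position i + 1 is read from row i, so the concatenated ids are shifted left by one and padded with a dummy 0; the value gathered at the final row is a placeholder, not a real logprob.

    import torch

    vocab_size = 32
    # Pruned ids for two requests, matching six pruned rows in total.
    pruned_input_ids = [torch.tensor([4, 7, 9]), torch.tensor([1, 2, 3])]
    input_logprobs = torch.log_softmax(torch.randn(6, vocab_size), dim=-1)

    input_token_logprobs = input_logprobs[
        torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
        torch.cat(
            [
                # Shift left: row i scores the id at position i + 1.
                torch.cat(pruned_input_ids)[1:],
                # Dummy pad so the index tensor matches the number of rows.
                torch.tensor([0], device=input_logprobs.device),
            ]
        ),
    ]
    assert input_token_logprobs.shape == (6,)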