"vscode:/vscode.git/clone" did not exist on "5df02fc171ba2bc1e559874cdd4f88661a7dd1d6"
Unverified Commit 23196d52 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Simplify logits processor (#2974)


Co-authored-by: default avatarSangBin Cho <rkooo567@gmail.com>
parent 93b77c8e
......@@ -14,6 +14,7 @@
"""Logits processing."""
import dataclasses
import logging
from typing import List, Optional, Union
import torch
......@@ -32,6 +33,8 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardMode,
)
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class LogitsProcessorOutput:
......@@ -136,50 +139,61 @@ class LogitsProcessor(nn.Module):
logits_metadata.forward_mode.is_decode_or_idle()
or logits_metadata.forward_mode.is_target_verify()
):
last_index = None
last_hidden = hidden_states
else:
pruned_states = hidden_states
sample_indices = None
elif (
logits_metadata.forward_mode.is_extend()
and not logits_metadata.extend_return_logprob
):
# Prefill without input logprobs.
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
last_hidden = hidden_states[last_index]
pruned_states = hidden_states[last_index]
sample_indices = None
else:
# Slice the requested tokens to compute logprob
sample_index_pt = -1
sample_indices = []
pt, pruned_states, pruned_input_ids = 0, [], []
for start_len, extend_len in zip(
logits_metadata.extend_logprob_start_lens_cpu,
logits_metadata.extend_seq_lens_cpu,
):
pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
sample_index_pt += extend_len - start_len
sample_indices.append(sample_index_pt)
pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
pt += extend_len
pruned_states = torch.cat(pruned_states)
# Compute logits for both input and sampled tokens.
logits = self._get_logits(pruned_states, lm_head, logits_metadata)
sampled_logits = (
logits[sample_indices] if sample_indices is not None else logits
)
# Compute logits
last_logits = self._get_logits(last_hidden, lm_head)
if (
not logits_metadata.extend_return_logprob
or logits_metadata.capture_hidden_mode.need_capture()
):
# Decode mode or extend mode without return_logprob.
return LogitsProcessorOutput(
next_token_logits=last_logits,
next_token_logits=sampled_logits,
hidden_states=(
hidden_states
if logits_metadata.capture_hidden_mode.is_full()
else (
last_hidden
pruned_states
if logits_metadata.capture_hidden_mode.is_last()
else None
)
),
)
else:
# Slice the requested tokens to compute logprob
pt, pruned_states, pruned_input_ids = 0, [], []
for start_len, extend_len in zip(
logits_metadata.extend_logprob_start_lens_cpu,
logits_metadata.extend_seq_lens_cpu,
):
pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
pt += extend_len
# Compute the logits of all required tokens
pruned_states = torch.cat(pruned_states)
del hidden_states
input_token_logits = self._get_logits(pruned_states, lm_head)
del pruned_states
input_logprobs = logits
del hidden_states, logits
# Normalize the logprob w/o temperature, top-p
input_logprobs = input_token_logits
input_logprobs = self.compute_temp_top_p_normalized_logprobs(
input_logprobs, logits_metadata
)
......@@ -194,17 +208,17 @@ class LogitsProcessor(nn.Module):
input_top_logprobs_val = input_top_logprobs_idx = None
input_token_logprobs = input_logprobs[
torch.arange(input_logprobs.shape[0], device="cuda"),
torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
torch.cat(
[
torch.cat(pruned_input_ids)[1:],
torch.tensor([0], device="cuda"),
torch.tensor([0], device=input_logprobs.device),
]
),
]
return LogitsProcessorOutput(
next_token_logits=last_logits,
next_token_logits=sampled_logits,
input_token_logprobs=input_token_logprobs,
input_top_logprobs_val=input_top_logprobs_val,
input_top_logprobs_idx=input_top_logprobs_idx,
......@@ -214,8 +228,11 @@ class LogitsProcessor(nn.Module):
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
logits_metadata: LogitsMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Get logits from hidden_states."""
if hasattr(lm_head, "weight"):
logits = torch.matmul(hidden_states, lm_head.weight.T)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment