Unverified Commit f25b76c0, authored by Liangsheng Yin, committed by GitHub

add `LogitsMetadata` (#604)

parent f4e885b7
```diff
@@ -48,9 +48,9 @@ def generate_lines(random_words, num_lines, redirect_ratio):
         )
         for i in redirect_indices:
             target_idx = np.random.choice(min(i * 2 + 100, num_lines))
-            lines[i] = (
-                f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
-            )
+            lines[
+                i
+            ] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
             redirects[i] = target_idx

     # Build links and find sources
```
...
"""Logits processing.""" """Logits processing."""
import dataclasses import dataclasses
from typing import List from typing import List, Union
import torch import torch
from torch import nn from torch import nn
```diff
@@ -31,6 +31,27 @@ class LogitProcessorOutput:
     decode_top_logprobs: List


+@dataclasses.dataclass
+class LogitsMetadata:
+    forward_mode: ForwardMode
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+
+    # For logprobs
+    return_logprob: bool
+    top_logprobs_nums: List[int]
+
+    @classmethod
+    def from_input_metadata(cls, input_metadata: InputMetadata):
+        return cls(
+            forward_mode=input_metadata.forward_mode,
+            extend_seq_lens=input_metadata.extend_seq_lens,
+            extend_start_loc=input_metadata.extend_start_loc,
+            return_logprob=input_metadata.return_logprob,
+            top_logprobs_nums=input_metadata.top_logprobs_nums,
+        )
+
+
 class LogitsProcessor(nn.Module):
     def __init__(self, config):
         super().__init__()
```
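The new dataclass narrows the logits path's dependency to exactly the five fields it reads, so tests or external callers can drive `LogitsProcessor` without assembling a full `InputMetadata`. A minimal sketch of direct construction; the batch shape and values here are illustrative assumptions, not taken from this diff:

```python
import torch

# Hypothetical example: two prefill (extend) requests of lengths 3 and 2,
# flattened into one batch; extend_start_loc gives each request's offset.
logits_metadata = LogitsMetadata(
    forward_mode=ForwardMode.EXTEND,  # assuming EXTEND, as used elsewhere in sglang
    extend_seq_lens=torch.tensor([3, 2]),
    extend_start_loc=torch.tensor([0, 3]),
    return_logprob=True,
    top_logprobs_nums=[5, 5],
)
# Existing callers can keep passing InputMetadata; forward() converts it
# via LogitsMetadata.from_input_metadata(), as shown later in this diff.
```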
```diff
@@ -38,14 +59,14 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()

     def _get_normalized_prompt_logprobs(
-        self, prefill_token_logprobs, input_metadata: InputMetadata
+        self, prefill_token_logprobs, logits_metadata: LogitsMetadata
     ):
         logprobs_cumsum = torch.cumsum(
             prefill_token_logprobs, dim=0, dtype=torch.float32
         )

-        start = input_metadata.extend_start_loc.clone()
-        end = start + input_metadata.extend_seq_lens - 2
+        start = logits_metadata.extend_start_loc.clone()
+        end = start + logits_metadata.extend_seq_lens - 2
         start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
         end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
         sum_logp = (
```
```diff
@@ -54,17 +75,17 @@ class LogitsProcessor(nn.Module):
             + prefill_token_logprobs[start]
         )
         normalized_prompt_logprobs = sum_logp / (
-            (input_metadata.extend_seq_lens - 1).clamp(min=1)
+            (logits_metadata.extend_seq_lens - 1).clamp(min=1)
         )
         return normalized_prompt_logprobs

-    def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
+    def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
         # TODO: vectorize the code below
-        if input_metadata.forward_mode == ForwardMode.DECODE:
+        if logits_metadata.forward_mode == ForwardMode.DECODE:
             decode_top_logprobs = []
             for i in range(all_logprobs.shape[0]):
-                k = input_metadata.top_logprobs_nums[i]
+                k = logits_metadata.top_logprobs_nums[i]
                 t = all_logprobs[i].topk(k)
                 v_cpu = t.values.tolist()
                 p_cpu = t.indices.tolist()
```
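The normalization above sums each prompt's token logprobs with a single cumulative sum over the flattened batch, then divides by the prompt length minus one (the first prompt token has no preceding context, hence the `len - 1` divisor). A self-contained numeric sketch of the same arithmetic with made-up values; the exact `sum_logp` expression is partially elided in the hunk, so the line below is an assumption based on the visible `+ prefill_token_logprobs[start]` term:

```python
import torch

# Two requests of lengths 3 and 2, flattened; logprobs are illustrative.
prefill_token_logprobs = torch.tensor([-0.5, -1.0, -0.25, -2.0, -0.75])
extend_start_loc = torch.tensor([0, 3])
extend_seq_lens = torch.tensor([3, 2])

logprobs_cumsum = torch.cumsum(prefill_token_logprobs, dim=0, dtype=torch.float32)
start = extend_start_loc.clone()
end = start + extend_seq_lens - 2  # drop the last position of each prompt
start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
# cumsum[end] - cumsum[start] + x[start] == sum of x[start..end], inclusive
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_token_logprobs[start]
normalized = sum_logp / (extend_seq_lens - 1).clamp(min=1)
print(normalized)  # tensor([-0.7500, -2.0000])
```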
```diff
@@ -73,13 +94,13 @@ class LogitsProcessor(nn.Module):
         else:
             prefill_top_logprobs, decode_top_logprobs = [], []
             pt = 0
-            extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
+            extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                 if extend_seq_len == 0:
                     prefill_top_logprobs.append([])
                     decode_top_logprobs.append([])
                     continue
-                k = input_metadata.top_logprobs_nums[i]
+                k = logits_metadata.top_logprobs_nums[i]
                 t = all_logprobs[pt : pt + extend_seq_len].topk(k)
                 vs_cpu = t.values.tolist()
                 ps_cpu = t.indices.tolist()
```
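In both branches, `Tensor.topk` returns a named tuple of values and indices in one call, which the loop moves to CPU lists per request. A small runnable sketch of that pattern; shapes and k values are illustrative, and pairing values with indices via `zip` is an assumption about the elided loop body:

```python
import torch

# One row of logprobs per decoding request; k can differ per request.
all_logprobs = torch.log_softmax(torch.randn(2, 8), dim=-1)
top_logprobs_nums = [3, 2]
for i in range(all_logprobs.shape[0]):
    t = all_logprobs[i].topk(top_logprobs_nums[i])
    # values are the top-k logprobs, indices the corresponding token ids
    print(list(zip(t.values.tolist(), t.indices.tolist())))
```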
```diff
@@ -91,14 +112,24 @@ class LogitsProcessor(nn.Module):
         return prefill_top_logprobs, decode_top_logprobs

-    def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
+    def forward(
+        self,
+        input_ids,
+        hidden_states,
+        weight,
+        logits_metadata: Union[LogitsMetadata, InputMetadata],
+    ):
+        if isinstance(logits_metadata, InputMetadata):
+            logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
+        assert isinstance(logits_metadata, LogitsMetadata)
+
         # Get the last hidden states and last logits for the next token prediction
-        if input_metadata.forward_mode == ForwardMode.DECODE:
+        if logits_metadata.forward_mode == ForwardMode.DECODE:
             last_index = None
             last_hidden = hidden_states
         else:
             last_index = (
-                torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
+                torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                 - 1
             )
             last_hidden = hidden_states[last_index]
```
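For extend batches, each request's last hidden state sits at offset `cumsum(lengths) - 1` in the flattened batch, and only those rows are needed to predict each request's next token. A quick standalone check with illustrative sizes:

```python
import torch

extend_seq_lens = torch.tensor([3, 2])  # two requests, 5 tokens total
last_index = torch.cumsum(extend_seq_lens, dim=0, dtype=torch.long) - 1
print(last_index)                       # tensor([2, 4])

hidden_states = torch.randn(5, 16)      # illustrative hidden size
last_hidden = hidden_states[last_index] # shape: (2, 16)
```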
```diff
@@ -114,7 +145,7 @@ class LogitsProcessor(nn.Module):
             last_logits *= self.config.final_logit_softcapping

         # Return only last_logits if logprob is not requested
-        if not input_metadata.return_logprob:
+        if not logits_metadata.return_logprob:
             return LogitProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
```
```diff
@@ -125,7 +156,7 @@ class LogitsProcessor(nn.Module):
             )
         else:
             # When logprob is requested, compute the logits for all tokens.
-            if input_metadata.forward_mode == ForwardMode.DECODE:
+            if logits_metadata.forward_mode == ForwardMode.DECODE:
                 all_logits = last_logits
             else:
                 all_logits = torch.matmul(hidden_states, weight.T)
```
```diff
@@ -138,15 +169,15 @@ class LogitsProcessor(nn.Module):
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

             # Get the logprob of top-k tokens
-            return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
+            return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
             if return_top_logprob:
                 prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
-                    all_logprobs, input_metadata
+                    all_logprobs, logits_metadata
                 )
             else:
                 prefill_top_logprobs = decode_top_logprobs = None

-            if input_metadata.forward_mode == ForwardMode.DECODE:
+            if logits_metadata.forward_mode == ForwardMode.DECODE:
                 return LogitProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=all_logprobs,
```
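The `all_logprobs[:] = ...` line above normalizes in place: the slice assignment writes the log-softmax result back into the existing `(num_tokens, vocab)` storage instead of rebinding the name to a fresh tensor, so any other references to that buffer see the normalized values. A toy illustration:

```python
import torch

# The slice assignment copies the result into the same storage; every row
# then sums to ~1.0 in probability space.
all_logprobs = torch.randn(4, 10)
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
print(all_logprobs.exp().sum(dim=-1))  # ~tensor([1., 1., 1., 1.])
```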
```diff
@@ -166,7 +197,7 @@ class LogitsProcessor(nn.Module):
                 ]
                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                    prefill_token_logprobs, input_metadata
+                    prefill_token_logprobs, logits_metadata
                 )

                 return LogitProcessorOutput(
```
...
```diff
@@ -2,9 +2,8 @@

 import numpy as np
 import torch
-from torch import nn
-
 from flashinfer.cascade import merge_state
+from torch import nn

 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
```
...
```diff
@@ -334,15 +334,15 @@ class TokenizerManager:
                 ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
             )
             if top_logprobs_num > 0:
-                ret["meta_info"]["prefill_top_logprobs"] = (
-                    self.detokenize_top_logprobs_tokens(
-                        ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
-                    )
-                )
-                ret["meta_info"]["decode_top_logprobs"] = (
-                    self.detokenize_top_logprobs_tokens(
-                        ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
-                    )
-                )
+                ret["meta_info"][
+                    "prefill_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                )
+                ret["meta_info"][
+                    "decode_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                )

         return ret
```
...
```diff
@@ -81,7 +81,6 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding


 class GemmaRotaryEmbedding(RotaryEmbedding):
-
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
         inv_freq = 1.0 / (
@@ -95,7 +94,6 @@ class GemmaRotaryEmbedding(RotaryEmbedding):


 class Gemma2MLP(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,
@@ -127,7 +125,6 @@ class Gemma2MLP(nn.Module):


 class Gemma2Attention(nn.Module):
-
     def __init__(
         self,
         layer_idx: int,
@@ -218,7 +215,6 @@ class Gemma2Attention(nn.Module):


 class Gemma2DecoderLayer(nn.Module):
-
     def __init__(
         self,
         layer_idx: int,
@@ -287,7 +283,6 @@ class Gemma2DecoderLayer(nn.Module):


 class Gemma2Model(nn.Module):
-
     def __init__(
         self,
         config: PretrainedConfig,
```
...
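The `_compute_inv_freq` body is truncated in the hunk above; per the linked Hugging Face code, Gemma computes the standard RoPE inverse frequencies, 1 / base^(2i / rotary_dim), but builds the `arange` in int64 before casting to float. A hedged reconstruction of that formula, not copied from this diff:

```python
import torch

# Standard RoPE inverse frequencies; the int64 arange mirrors the HF link.
base, rotary_dim = 10000.0, 8
inv_freq = 1.0 / (
    base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim)
)
print(inv_freq)  # tensor([1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03])
```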
```diff
@@ -163,9 +163,9 @@ class LlamaDecoderLayer(nn.Module):
         if rope_scaling is not None and getattr(
             config, "original_max_position_embeddings", None
         ):
-            rope_scaling["original_max_position_embeddings"] = (
-                config.original_max_position_embeddings
-            )
+            rope_scaling[
+                "original_max_position_embeddings"
+            ] = config.original_max_position_embeddings
         rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.self_attn = LlamaAttention(
```
...
```diff
@@ -459,6 +459,7 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     """
     import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt

     setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
```
...
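The patch above rebinds a module-level attribute, so any vllm code that resolves `gpu_p2p_access_check` through the module gets a stub that always reports peer-to-peer access as available. The same pattern on a stand-in module, so the snippet runs without vllm installed:

```python
import types

# Stand-in module with a hypothetical strict check as the original.
fake = types.ModuleType("fake_all_reduce_utils")
fake.gpu_p2p_access_check = lambda src, tgt: False

# Rebind the attribute: every caller resolving it via the module now sees
# the permissive stub, exactly like the setattr(...) in the diff.
setattr(fake, "gpu_p2p_access_check", lambda *args, **kwargs: True)
assert fake.gpu_p2p_access_check(0, 1) is True
```

The usual caveat applies: code that imported the function by name before the patch keeps its old reference, which is why patching the module attribute that callers look up dynamically is the safer choice here.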