Unverified Commit 0a409bd4 authored by Lianmin Zheng, committed by GitHub

Fix return_log_probs with cuda graph (#775)

parent e4db4e5b
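Rough context for the fix: the CUDA graph is captured without any logprob computation (note the new return_logprob default of False below), so a replayed graph only produces next_token_logits. The change therefore derives next_token_logprobs and the per-request top-k logprobs after replay, from the replayed logits. The sketch below only illustrates that kind of post-replay extraction; the function name, shapes, and result layout are assumptions for illustration, not the sglang API.

from typing import List

import torch


def extract_decode_logprobs(next_token_logits: torch.Tensor, top_logprobs_nums: List[int]):
    """Illustrative post-replay logprob extraction; not the sglang implementation."""
    # Logprobs are computed outside the captured graph, so the capture itself
    # never needs to know about return_logprob.
    next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)

    # Per-request top-k logprobs; k == 0 means the request did not ask for them.
    decode_top_logprobs = []
    for i, k in enumerate(top_logprobs_nums):
        if k > 0:
            values, indices = torch.topk(next_token_logprobs[i], k)
            decode_top_logprobs.append(list(zip(values.tolist(), indices.tolist())))
        else:
            decode_top_logprobs.append(None)
    return next_token_logprobs, decode_top_logprobs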
"""Logits processing."""
import dataclasses
from typing import List, Union
from typing import List, Optional, Union
import torch
from torch import nn
@@ -34,11 +34,11 @@ class LogitProcessorOutput:
 @dataclasses.dataclass
 class LogitsMetadata:
     forward_mode: ForwardMode
-    return_logprob: bool
+    return_logprob: bool = False
 
-    extend_seq_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    top_logprobs_nums: List[int] = None
+    extend_seq_lens: Optional[torch.Tensor] = None
+    extend_start_loc: Optional[torch.Tensor] = None
+    top_logprobs_nums: Optional[List[int]] = None
 
     @classmethod
     def from_input_metadata(cls, input_metadata: InputMetadata):
@@ -79,7 +79,8 @@ class LogitsProcessor(nn.Module):
         return normalized_prompt_logprobs
 
-    def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
+    @staticmethod
+    def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
         # TODO: vectorize the code below
         if logits_metadata.forward_mode == ForwardMode.DECODE:
             decode_top_logprobs = []
@@ -156,36 +157,48 @@ class LogitsProcessor(nn.Module):
         else:
             # When logprob is requested, compute the logits for all tokens.
             if logits_metadata.forward_mode == ForwardMode.DECODE:
-                all_logits = last_logits
-            else:
-                all_logits = torch.matmul(hidden_states, weight.T)
-                if self.tp_size > 1:
-                    all_logits = tensor_model_parallel_all_gather(all_logits)
-                all_logits = all_logits[:, : self.config.vocab_size].float()
-
-            all_logprobs = all_logits
-            del all_logits
-            all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
-
-            # Get the logprob of top-k tokens
-            return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
-            if return_top_logprob:
-                prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
-                    all_logprobs, logits_metadata
-                )
-            else:
-                prefill_top_logprobs = decode_top_logprobs = None
-
-            if logits_metadata.forward_mode == ForwardMode.DECODE:
+                last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
+
+                # Get the logprob of top-k tokens
+                return_top_logprob = any(
+                    x > 0 for x in logits_metadata.top_logprobs_nums
+                )
+                if return_top_logprob:
+                    decode_top_logprobs = self.get_top_logprobs(
+                        last_logprobs, logits_metadata
+                    )[1]
+                else:
+                    decode_top_logprobs = None
+
                 return LogitProcessorOutput(
                     next_token_logits=last_logits,
-                    next_token_logprobs=all_logprobs,
+                    next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=None,
                     prefill_token_logprobs=None,
                     prefill_top_logprobs=None,
                     decode_top_logprobs=decode_top_logprobs,
                 )
             else:
+                all_logits = torch.matmul(hidden_states, weight.T)
+                if self.tp_size > 1:
+                    all_logits = tensor_model_parallel_all_gather(all_logits)
+                all_logits = all_logits[:, : self.config.vocab_size].float()
+
+                all_logprobs = all_logits
+                del all_logits
+                all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+
+                # Get the logprob of top-k tokens
+                return_top_logprob = any(
+                    x > 0 for x in logits_metadata.top_logprobs_nums
+                )
+                if return_top_logprob:
+                    prefill_top_logprobs, decode_top_logprobs = self.get_top_logprobs(
+                        all_logprobs, logits_metadata
+                    )
+                else:
+                    prefill_top_logprobs = decode_top_logprobs = None
+
                 last_logprobs = all_logprobs[last_index]
 
                 # Compute the logprobs and normalized logprobs for the prefill tokens.
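A side note on the old decode path above: all_logprobs aliased last_logits, so the in-place log_softmax write also overwrote the tensor later returned as next_token_logits. The new decode path computes last_logprobs into a fresh tensor instead. A minimal sketch of that aliasing behaviour (tensor sizes are made up):

import torch

last_logits = torch.randn(4, 32000)  # one row per decoding request (sizes made up)

# Old decode path: all_logprobs aliases last_logits, so the in-place write below
# also clobbers last_logits itself.
all_logprobs = last_logits
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
assert last_logits.data_ptr() == all_logprobs.data_ptr()  # same storage

# New decode path: log_softmax allocates a fresh tensor; last_logits stays intact.
fresh_logits = torch.randn(4, 32000)
last_logprobs = torch.nn.functional.log_softmax(fresh_logits, dim=-1)
assert fresh_logits.data_ptr() != last_logprobs.data_ptr()

The remaining hunks are in the CUDA graph runner (class CudaGraphRunner).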
@@ -9,7 +9,11 @@ from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
-from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.layers.logits_processor import (
+    LogitProcessorOutput,
+    LogitsMetadata,
+    LogitsProcessor,
+)
 from sglang.srt.managers.controller.infer_batch import (
     Batch,
     ForwardMode,
@@ -185,7 +189,6 @@ class CudaGraphRunner:
     def replay(self, batch: Batch):
         assert batch.out_cache_loc is not None
-        assert not batch.return_logprob
         raw_bs = len(batch.reqs)
 
         # Pad
@@ -218,23 +221,29 @@ class CudaGraphRunner:
         output = self.output_buffers[bs]
 
         # Unpad
-        if bs == raw_bs:
-            return output
-        else:
+        if bs != raw_bs:
             output = LogitProcessorOutput(
                 next_token_logits=output.next_token_logits[:raw_bs],
-                next_token_logprobs=(
-                    output.next_token_logprobs[:raw_bs]
-                    if output.next_token_logprobs is not None
-                    else None
-                ),
+                next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
                 prefill_token_logprobs=None,
                 prefill_top_logprobs=None,
-                decode_top_logprobs=(
-                    output.decode_top_logprobs[:raw_bs]
-                    if output.decode_top_logprobs is not None
-                    else None
-                ),
+                decode_top_logprobs=None,
             )
+
+        # Extract logprobs
+        if batch.return_logprob:
+            output.next_token_logprobs = torch.nn.functional.log_softmax(
+                output.next_token_logits, dim=-1
+            )
+            return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
+            if return_top_logprob:
+                logits_metadata = LogitsMetadata(
+                    forward_mode=ForwardMode.DECODE,
+                    top_logprobs_nums=batch.top_logprobs_nums,
+                )
+                output.decode_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    output.next_token_logprobs, logits_metadata
+                )[1]
+
         return output
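For completeness, a small usage sketch of why get_top_logprobs was turned into a @staticmethod: CudaGraphRunner.replay can now call it directly on the class, without constructing a LogitsProcessor (which needs a model config). The tensor contents and top-k sizes below are made up for illustration.

import torch

from sglang.srt.layers.logits_processor import LogitsMetadata, LogitsProcessor
from sglang.srt.managers.controller.infer_batch import ForwardMode

# Fake decode-step logprobs: one row per running request (sizes made up).
next_token_logprobs = torch.randn(2, 32000).log_softmax(dim=-1)

logits_metadata = LogitsMetadata(
    forward_mode=ForwardMode.DECODE,
    top_logprobs_nums=[5, 3],  # per-request top-k sizes
)

# get_top_logprobs returns (prefill_top_logprobs, decode_top_logprobs);
# the decode path only needs index [1], mirroring replay() above.
decode_top_logprobs = LogitsProcessor.get_top_logprobs(
    next_token_logprobs, logits_metadata
)[1]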