Unverified Commit 0a409bd4 authored by Lianmin Zheng, committed by GitHub

Fix return_log_probs with cuda graph (#775)

parent e4db4e5b
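
Before this change, `CudaGraphRunner.replay` asserted `not batch.return_logprob`, so any decode batch that requested logprobs could not use the CUDA-graph fast path at all. Because a CUDA graph replays a frozen kernel sequence over fixed buffers, per-request options such as `return_logprob` cannot be branched on inside the captured forward pass; the fix instead derives the logprobs from the replayed logits after the graph has run (a `log_softmax` plus an optional per-request top-k). A toy illustration of that capture-then-post-process pattern in plain PyTorch; none of this is sglang code, the shapes are made up, and it needs a CUDA device:

```python
import torch

static_in = torch.zeros(8, 64, device="cuda")
weight = torch.randn(64, 32000, device="cuda")
static_out = torch.zeros(8, 32000, device="cuda")

# Warm up on a side stream, as recommended before graph capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_in @ weight)
torch.cuda.current_stream().wait_stream(s)

# Capture the fixed-shape forward (here just a matmul) into a graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out.copy_(static_in @ weight)

# Replay with new inputs; optional work such as logprob extraction
# has to happen out here, outside the frozen kernel sequence.
static_in.copy_(torch.randn(8, 64, device="cuda"))
graph.replay()
return_logprob = True  # stand-in for batch.return_logprob
if return_logprob:
    logprobs = torch.nn.functional.log_softmax(static_out, dim=-1)
```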
"""Logits processing.""" """Logits processing."""
import dataclasses import dataclasses
from typing import List, Union from typing import List, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -34,11 +34,11 @@ class LogitProcessorOutput: ...@@ -34,11 +34,11 @@ class LogitProcessorOutput:
@dataclasses.dataclass @dataclasses.dataclass
class LogitsMetadata: class LogitsMetadata:
forward_mode: ForwardMode forward_mode: ForwardMode
return_logprob: bool return_logprob: bool = False
extend_seq_lens: torch.Tensor = None extend_seq_lens: Optional[torch.Tensor] = None
extend_start_loc: torch.Tensor = None extend_start_loc: Optional[torch.Tensor] = None
top_logprobs_nums: List[int] = None top_logprobs_nums: Optional[List[int]] = None
@classmethod @classmethod
def from_input_metadata(cls, input_metadata: InputMetadata): def from_input_metadata(cls, input_metadata: InputMetadata):
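
The dataclass changes are what make the runner-side fix possible: with `return_logprob` defaulting to `False` and the extend-phase fields `Optional`, a decode-only `LogitsMetadata` can be built from just two fields, which is how `CudaGraphRunner.replay` uses it further down. A small construction sketch (the top-k values are made up):

```python
from sglang.srt.layers.logits_processor import LogitsMetadata
from sglang.srt.managers.controller.infer_batch import ForwardMode

# extend_seq_lens, extend_start_loc, and return_logprob all fall back
# to their new defaults; only the decode-relevant fields are supplied.
metadata = LogitsMetadata(
    forward_mode=ForwardMode.DECODE,
    top_logprobs_nums=[5, 0, 1],  # per-request top-k; 0 means none
)
```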
@@ -79,7 +79,8 @@ class LogitsProcessor(nn.Module):

         return normalized_prompt_logprobs

-    def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
+    @staticmethod
+    def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
         # TODO: vectorize the code below
         if logits_metadata.forward_mode == ForwardMode.DECODE:
             decode_top_logprobs = []
@@ -156,36 +157,48 @@ class LogitsProcessor(nn.Module):
         else:
             # When logprob is requested, compute the logits for all tokens.
             if logits_metadata.forward_mode == ForwardMode.DECODE:
-                all_logits = last_logits
-            else:
-                all_logits = torch.matmul(hidden_states, weight.T)
-                if self.tp_size > 1:
-                    all_logits = tensor_model_parallel_all_gather(all_logits)
-                all_logits = all_logits[:, : self.config.vocab_size].float()
-
-            all_logprobs = all_logits
-            del all_logits
-            all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
-
-            # Get the logprob of top-k tokens
-            return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
-            if return_top_logprob:
-                prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
-                    all_logprobs, logits_metadata
-                )
-            else:
-                prefill_top_logprobs = decode_top_logprobs = None
-
-            if logits_metadata.forward_mode == ForwardMode.DECODE:
+                last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
+
+                # Get the logprob of top-k tokens
+                return_top_logprob = any(
+                    x > 0 for x in logits_metadata.top_logprobs_nums
+                )
+                if return_top_logprob:
+                    decode_top_logprobs = self.get_top_logprobs(
+                        last_logprobs, logits_metadata
+                    )[1]
+                else:
+                    decode_top_logprobs = None
+
                 return LogitProcessorOutput(
                     next_token_logits=last_logits,
-                    next_token_logprobs=all_logprobs,
+                    next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=None,
                     prefill_token_logprobs=None,
                     prefill_top_logprobs=None,
                     decode_top_logprobs=decode_top_logprobs,
                 )
             else:
+                all_logits = torch.matmul(hidden_states, weight.T)
+                if self.tp_size > 1:
+                    all_logits = tensor_model_parallel_all_gather(all_logits)
+                all_logits = all_logits[:, : self.config.vocab_size].float()
+
+                all_logprobs = all_logits
+                del all_logits
+                all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+
+                # Get the logprob of top-k tokens
+                return_top_logprob = any(
+                    x > 0 for x in logits_metadata.top_logprobs_nums
+                )
+                if return_top_logprob:
+                    prefill_top_logprobs, decode_top_logprobs = self.get_top_logprobs(
+                        all_logprobs, logits_metadata
+                    )
+                else:
+                    prefill_top_logprobs = decode_top_logprobs = None
+
                 last_logprobs = all_logprobs[last_index]

             # Compute the logprobs and normalized logprobs for the prefill tokens.
...
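
Making `_get_top_logprobs` a public `@staticmethod` is what lets `CudaGraphRunner` in the next file call it without holding a `LogitsProcessor` instance. Judging from the `[1]` indexing above, it returns a `(prefill_top_logprobs, decode_top_logprobs)` pair; a usage sketch under that assumption, with made-up shapes:

```python
import torch
from sglang.srt.layers.logits_processor import LogitsMetadata, LogitsProcessor
from sglang.srt.managers.controller.infer_batch import ForwardMode

# [batch, vocab] log-probabilities for a single decode step.
last_logprobs = torch.randn(2, 32000).log_softmax(dim=-1)
metadata = LogitsMetadata(
    forward_mode=ForwardMode.DECODE,
    top_logprobs_nums=[5, 3],
)
# Index [1] selects the decode-side result, as in the hunk above.
decode_top_logprobs = LogitsProcessor.get_top_logprobs(last_logprobs, metadata)[1]
```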
python/sglang/srt/managers/controller/cuda_graph_runner.py

@@ -9,7 +9,11 @@
 from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
-from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.layers.logits_processor import (
+    LogitProcessorOutput,
+    LogitsMetadata,
+    LogitsProcessor,
+)
 from sglang.srt.managers.controller.infer_batch import (
     Batch,
     ForwardMode,
@@ -185,7 +189,6 @@ class CudaGraphRunner:

     def replay(self, batch: Batch):
         assert batch.out_cache_loc is not None
-        assert not batch.return_logprob
         raw_bs = len(batch.reqs)

         # Pad
@@ -218,23 +221,29 @@ class CudaGraphRunner:
         output = self.output_buffers[bs]

         # Unpad
-        if bs == raw_bs:
-            return output
-        else:
+        if bs != raw_bs:
             output = LogitProcessorOutput(
                 next_token_logits=output.next_token_logits[:raw_bs],
-                next_token_logprobs=(
-                    output.next_token_logprobs[:raw_bs]
-                    if output.next_token_logprobs is not None
-                    else None
-                ),
+                next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
                 prefill_token_logprobs=None,
                 prefill_top_logprobs=None,
-                decode_top_logprobs=(
-                    output.decode_top_logprobs[:raw_bs]
-                    if output.decode_top_logprobs is not None
-                    else None
-                ),
+                decode_top_logprobs=None,
             )
-            return output
+
+        # Extract logprobs
+        if batch.return_logprob:
+            output.next_token_logprobs = torch.nn.functional.log_softmax(
+                output.next_token_logits, dim=-1
+            )
+            return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
+            if return_top_logprob:
+                logits_metadata = LogitsMetadata(
+                    forward_mode=ForwardMode.DECODE,
+                    top_logprobs_nums=batch.top_logprobs_nums,
+                )
+                output.decode_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    output.next_token_logprobs, logits_metadata
+                )[1]
+
+        return output
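
Read together, the new `# Extract logprobs` block is the whole fix: replay itself produces only `next_token_logits`, and logprobs are reconstructed on demand once the graph has run. The same logic as a free-standing helper (the function name is illustrative, not from the commit; the body mirrors the hunk above):

```python
import torch
from sglang.srt.layers.logits_processor import LogitsMetadata, LogitsProcessor
from sglang.srt.managers.controller.infer_batch import ForwardMode

def extract_logprobs_after_replay(output, batch):
    """Hypothetical standalone form of the post-replay block above."""
    if not batch.return_logprob:
        return output
    output.next_token_logprobs = torch.nn.functional.log_softmax(
        output.next_token_logits, dim=-1
    )
    if any(k > 0 for k in batch.top_logprobs_nums):
        logits_metadata = LogitsMetadata(
            forward_mode=ForwardMode.DECODE,
            top_logprobs_nums=batch.top_logprobs_nums,
        )
        # [1] picks the decode-side result of (prefill, decode).
        output.decode_top_logprobs = LogitsProcessor.get_top_logprobs(
            output.next_token_logprobs, logits_metadata
        )[1]
    return output
```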