Unverified Commit 381dd57b authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Sampler cudagraph (#1253)

parent 8153168c
...@@ -200,16 +200,16 @@ def extend(reqs, model_runner): ...@@ -200,16 +200,16 @@ def extend(reqs, model_runner):
tree_cache=None, tree_cache=None,
) )
batch.prepare_for_extend(model_runner.model_config.vocab_size) batch.prepare_for_extend(model_runner.model_config.vocab_size)
output = model_runner.forward(batch, ForwardMode.EXTEND) sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids = batch.sample(output.next_token_logits) next_token_ids = sample_output.batch_next_token_ids.tolist()
return next_token_ids, output.next_token_logits, batch return next_token_ids, logits_output.next_token_logits, batch
def decode(input_token_ids, batch, model_runner): def decode(input_token_ids, batch, model_runner):
batch.prepare_for_decode(input_token_ids.cpu().numpy()) batch.prepare_for_decode(input_token_ids)
output = model_runner.forward(batch, ForwardMode.DECODE) sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids = batch.sample(output.next_token_logits) next_token_ids = sample_output.batch_next_token_ids.tolist()
return next_token_ids, output.next_token_logits return next_token_ids, logits_output.next_token_logits
@torch.inference_mode() @torch.inference_mode()
......
...@@ -29,7 +29,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad ...@@ -29,7 +29,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
@dataclasses.dataclass @dataclasses.dataclass
class LogitProcessorOutput: class LogitsProcessorOutput:
# The logits of the next tokens. shape: [#seq, vocab_size] # The logits of the next tokens. shape: [#seq, vocab_size]
next_token_logits: torch.Tensor next_token_logits: torch.Tensor
# The logprobs of the next tokens. shape: [#seq, vocab_size] # The logprobs of the next tokens. shape: [#seq, vocab_size]
...@@ -185,7 +185,7 @@ class LogitsProcessor(nn.Module): ...@@ -185,7 +185,7 @@ class LogitsProcessor(nn.Module):
# Return only last_logits if logprob is not requested # Return only last_logits if logprob is not requested
if not logits_metadata.return_logprob: if not logits_metadata.return_logprob:
return LogitProcessorOutput( return LogitsProcessorOutput(
next_token_logits=last_logits, next_token_logits=last_logits,
next_token_logprobs=None, next_token_logprobs=None,
normalized_prompt_logprobs=None, normalized_prompt_logprobs=None,
...@@ -209,7 +209,7 @@ class LogitsProcessor(nn.Module): ...@@ -209,7 +209,7 @@ class LogitsProcessor(nn.Module):
else: else:
output_top_logprobs = None output_top_logprobs = None
return LogitProcessorOutput( return LogitsProcessorOutput(
next_token_logits=last_logits, next_token_logits=last_logits,
next_token_logprobs=last_logprobs, next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=None, normalized_prompt_logprobs=None,
...@@ -278,7 +278,7 @@ class LogitsProcessor(nn.Module): ...@@ -278,7 +278,7 @@ class LogitsProcessor(nn.Module):
# Remove the last token logprob for the prefill tokens. # Remove the last token logprob for the prefill tokens.
input_token_logprobs = input_token_logprobs[:-1] input_token_logprobs = input_token_logprobs[:-1]
return LogitProcessorOutput( return LogitsProcessorOutput(
next_token_logits=last_logits, next_token_logits=last_logits,
next_token_logprobs=last_logprobs, next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=normalized_prompt_logprobs, normalized_prompt_logprobs=normalized_prompt_logprobs,
......
import dataclasses
import logging import logging
from typing import Union
import torch import torch
from flashinfer.sampling import ( from flashinfer.sampling import (
...@@ -9,6 +11,8 @@ from flashinfer.sampling import ( ...@@ -9,6 +11,8 @@ from flashinfer.sampling import (
) )
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
# TODO: move this dict to another place # TODO: move this dict to another place
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
...@@ -16,30 +20,71 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo ...@@ -16,30 +20,71 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclasses.dataclass
class SampleOutput:
success: torch.Tensor
probs: torch.Tensor
batch_next_token_ids: torch.Tensor
class Sampler(CustomOp): class Sampler(CustomOp):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo): def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
# min-token, presence, frequency
if sampling_info.linear_penalties is not None:
logits += sampling_info.linear_penalties
# repetition
if sampling_info.scaling_penalties is not None:
logits = torch.where(
logits > 0,
logits / sampling_info.scaling_penalties,
logits * sampling_info.scaling_penalties,
)
return logits
def _get_probs(
self,
logits: torch.Tensor,
sampling_info: SamplingBatchInfo,
is_torch_compile: bool = False,
):
# Post process logits # Post process logits
logits = logits.contiguous() logits = logits.contiguous()
logits.div_(sampling_info.temperatures) logits.div_(sampling_info.temperatures)
if is_torch_compile:
# FIXME: Temporary workaround for unknown bugs in torch.compile
logits.add_(0)
if sampling_info.logit_bias is not None: if sampling_info.logit_bias is not None:
logits.add_(sampling_info.logit_bias) logits.add_(sampling_info.logit_bias)
if sampling_info.vocab_mask is not None: if sampling_info.vocab_mask is not None:
logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf")) logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
logits = sampling_info.penalizer_orchestrator.apply(logits) logits = self._apply_penalties(logits, sampling_info)
probs = torch.softmax(logits, dim=-1) return torch.softmax(logits, dim=-1)
def forward_cuda(
self,
logits: Union[torch.Tensor, LogitsProcessorOutput],
sampling_info: SamplingBatchInfo,
):
if isinstance(logits, LogitsProcessorOutput):
logits = logits.next_token_logits
probs = self._get_probs(logits, sampling_info)
if not global_server_args_dict["disable_flashinfer_sampling"]: if not global_server_args_dict["disable_flashinfer_sampling"]:
max_top_k_round, batch_size = 32, probs.shape[0] max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand( uniform_samples = torch.rand(
(max_top_k_round, batch_size), device=probs.device (max_top_k_round, batch_size), device=probs.device
) )
if sampling_info.min_ps.any(): if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks) probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps) probs = top_p_renorm_prob(probs, sampling_info.top_ps)
batch_next_token_ids, success = min_p_sampling_from_probs( batch_next_token_ids, success = min_p_sampling_from_probs(
...@@ -55,18 +100,23 @@ class Sampler(CustomOp): ...@@ -55,18 +100,23 @@ class Sampler(CustomOp):
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
) )
if not torch.all(success): return SampleOutput(success, probs, batch_next_token_ids)
logging.warning("Sampling failed, fallback to top_k=1 strategy")
probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
batch_next_token_ids = torch.where(
success, batch_next_token_ids, argmax_ids
)
return batch_next_token_ids def forward_native(
self,
logits: Union[torch.Tensor, LogitsProcessorOutput],
sampling_info: SamplingBatchInfo,
):
if isinstance(logits, LogitsProcessorOutput):
logits = logits.next_token_logits
probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
)
def forward_native(): return SampleOutput(success, probs, batch_next_token_ids)
raise NotImplementedError("Native forward is not implemented yet.")
def top_k_top_p_min_p_sampling_from_probs_torch( def top_k_top_p_min_p_sampling_from_probs_torch(
...@@ -87,7 +137,10 @@ def top_k_top_p_min_p_sampling_from_probs_torch( ...@@ -87,7 +137,10 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0 probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0]) probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
try: try:
sampled_index = torch.multinomial(probs_sort, num_samples=1) # FIXME: torch.multiomial does not support num_samples = 1
sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
:, :1
]
except RuntimeError as e: except RuntimeError as e:
logger.warning(f"Sampling error: {e}") logger.warning(f"Sampling error: {e}")
batch_next_token_ids = torch.zeros( batch_next_token_ids = torch.zeros(
......
from __future__ import annotations
""" """
Copyright 2023-2024 SGLang Team Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
...@@ -17,7 +19,7 @@ limitations under the License. ...@@ -17,7 +19,7 @@ limitations under the License.
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Union from typing import TYPE_CHECKING, List, Optional, Union
import torch import torch
...@@ -29,6 +31,10 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache ...@@ -29,6 +31,10 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
if TYPE_CHECKING:
from sglang.srt.layers.sampler import SampleOutput
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
# Put some global args for easy access # Put some global args for easy access
...@@ -678,11 +684,17 @@ class ScheduleBatch: ...@@ -678,11 +684,17 @@ class ScheduleBatch:
self.top_logprobs_nums.extend(other.top_logprobs_nums) self.top_logprobs_nums.extend(other.top_logprobs_nums)
self.return_logprob = any(req.return_logprob for req in self.reqs) self.return_logprob = any(req.return_logprob for req in self.reqs)
def sample(self, logits: torch.Tensor): def check_sample_results(self, sample_output: SampleOutput):
from sglang.srt.layers.sampler import Sampler if not torch.all(sample_output.success):
probs = sample_output.probs
sampler = Sampler() batch_next_token_ids = sample_output.batch_next_token_ids
logging.warning("Sampling failed, fallback to top_k=1 strategy")
batch_next_token_ids = sampler(logits, self.sampling_info) probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
batch_next_token_ids = torch.where(
sample_output.success, batch_next_token_ids, argmax_ids
)
sample_output.probs = probs
sample_output.batch_next_token_ids = batch_next_token_ids
return batch_next_token_ids return sample_output.batch_next_token_ids
...@@ -31,7 +31,7 @@ from sglang.global_config import global_config ...@@ -31,7 +31,7 @@ from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
BatchEmbeddingOut, BatchEmbeddingOut,
...@@ -504,21 +504,29 @@ class ModelTpServer: ...@@ -504,21 +504,29 @@ class ModelTpServer:
if self.model_runner.is_generation: if self.model_runner.is_generation:
# Forward and sample the next tokens # Forward and sample the next tokens
if batch.extend_num_tokens != 0: if batch.extend_num_tokens != 0:
output = self.model_runner.forward(batch, ForwardMode.EXTEND) sample_output, logits_output = self.model_runner.forward(
next_token_ids = batch.sample(output.next_token_logits) batch, ForwardMode.EXTEND
)
next_token_ids = batch.check_sample_results(sample_output)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids next_token_ids
) )
# Move logprobs to cpu # Move logprobs to cpu
if output.next_token_logprobs is not None: if logits_output.next_token_logprobs is not None:
output.next_token_logprobs = output.next_token_logprobs[ logits_output.next_token_logprobs = (
torch.arange(len(next_token_ids), device=next_token_ids.device), logits_output.next_token_logprobs[
next_token_ids, torch.arange(
].tolist() len(next_token_ids), device=next_token_ids.device
output.input_token_logprobs = output.input_token_logprobs.tolist() ),
output.normalized_prompt_logprobs = ( next_token_ids,
output.normalized_prompt_logprobs.tolist() ].tolist()
)
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.tolist()
)
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.tolist()
) )
next_token_ids = next_token_ids.tolist() next_token_ids = next_token_ids.tolist()
...@@ -557,12 +565,14 @@ class ModelTpServer: ...@@ -557,12 +565,14 @@ class ModelTpServer:
self.req_to_token_pool.free(req.req_pool_idx) self.req_to_token_pool.free(req.req_pool_idx)
if req.return_logprob: if req.return_logprob:
self.add_logprob_return_values(i, req, pt, next_token_ids, output) self.add_logprob_return_values(
i, req, pt, next_token_ids, logits_output
)
pt += req.extend_input_len pt += req.extend_input_len
else: else:
assert batch.extend_num_tokens != 0 assert batch.extend_num_tokens != 0
output = self.model_runner.forward(batch, ForwardMode.EXTEND) logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
embeddings = output.embeddings.tolist() embeddings = logits_output.embeddings.tolist()
# Check finish conditions # Check finish conditions
for i, req in enumerate(batch.reqs): for i, req in enumerate(batch.reqs):
...@@ -590,7 +600,7 @@ class ModelTpServer: ...@@ -590,7 +600,7 @@ class ModelTpServer:
req: Req, req: Req,
pt: int, pt: int,
next_token_ids: List[int], next_token_ids: List[int],
output: LogitProcessorOutput, output: LogitsProcessorOutput,
): ):
if req.normalized_prompt_logprob is None: if req.normalized_prompt_logprob is None:
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
...@@ -672,15 +682,17 @@ class ModelTpServer: ...@@ -672,15 +682,17 @@ class ModelTpServer:
batch.prepare_for_decode() batch.prepare_for_decode()
# Forward and sample the next tokens # Forward and sample the next tokens
output = self.model_runner.forward(batch, ForwardMode.DECODE) sample_output, logits_output = self.model_runner.forward(
next_token_ids = batch.sample(output.next_token_logits) batch, ForwardMode.DECODE
)
next_token_ids = batch.check_sample_results(sample_output)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids next_token_ids
) )
# Move logprobs to cpu # Move logprobs to cpu
if output.next_token_logprobs is not None: if logits_output.next_token_logprobs is not None:
next_token_logprobs = output.next_token_logprobs[ next_token_logprobs = logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=next_token_ids.device), torch.arange(len(next_token_ids), device=next_token_ids.device),
next_token_ids, next_token_ids,
].tolist() ].tolist()
...@@ -706,7 +718,7 @@ class ModelTpServer: ...@@ -706,7 +718,7 @@ class ModelTpServer:
(next_token_logprobs[i], next_token_id) (next_token_logprobs[i], next_token_id)
) )
if req.top_logprobs_num > 0: if req.top_logprobs_num > 0:
req.output_top_logprobs.append(output.output_top_logprobs[i]) req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
self.handle_finished_requests(batch) self.handle_finished_requests(batch)
......
...@@ -26,16 +26,18 @@ from vllm.distributed.parallel_state import graph_capture ...@@ -26,16 +26,18 @@ from vllm.distributed.parallel_state import graph_capture
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.logits_processor import ( from sglang.srt.layers.logits_processor import (
LogitProcessorOutput,
LogitsMetadata, LogitsMetadata,
LogitsProcessor, LogitsProcessor,
LogitsProcessorOutput,
) )
from sglang.srt.layers.sampler import SampleOutput
from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.model_executor.forward_batch_info import (
ForwardMode, ForwardMode,
InputMetadata, InputMetadata,
update_flashinfer_indices, update_flashinfer_indices,
) )
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.utils import monkey_patch_vllm_all_gather from sglang.srt.utils import monkey_patch_vllm_all_gather
...@@ -144,6 +146,10 @@ class CudaGraphRunner: ...@@ -144,6 +146,10 @@ class CudaGraphRunner:
self.flashinfer_kv_indices.clone(), self.flashinfer_kv_indices.clone(),
] ]
# Sampling inputs
vocab_size = model_runner.model_config.vocab_size
self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else [] self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
if use_torch_compile: if use_torch_compile:
...@@ -235,6 +241,7 @@ class CudaGraphRunner: ...@@ -235,6 +241,7 @@ class CudaGraphRunner:
def run_once(): def run_once():
input_metadata = InputMetadata( input_metadata = InputMetadata(
forward_mode=ForwardMode.DECODE, forward_mode=ForwardMode.DECODE,
sampling_info=self.sampling_info[:bs],
batch_size=bs, batch_size=bs,
req_pool_indices=req_pool_indices, req_pool_indices=req_pool_indices,
seq_lens=seq_lens, seq_lens=seq_lens,
...@@ -299,27 +306,35 @@ class CudaGraphRunner: ...@@ -299,27 +306,35 @@ class CudaGraphRunner:
self.flashinfer_handlers[bs], self.flashinfer_handlers[bs],
) )
# Sampling inputs
self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
# Replay # Replay
torch.cuda.synchronize() torch.cuda.synchronize()
self.graphs[bs].replay() self.graphs[bs].replay()
torch.cuda.synchronize() torch.cuda.synchronize()
output = self.output_buffers[bs] sample_output, logits_output = self.output_buffers[bs]
# Unpad # Unpad
if bs != raw_bs: if bs != raw_bs:
output = LogitProcessorOutput( logits_output = LogitsProcessorOutput(
next_token_logits=output.next_token_logits[:raw_bs], next_token_logits=logits_output.next_token_logits[:raw_bs],
next_token_logprobs=None, next_token_logprobs=None,
normalized_prompt_logprobs=None, normalized_prompt_logprobs=None,
input_token_logprobs=None, input_token_logprobs=None,
input_top_logprobs=None, input_top_logprobs=None,
output_top_logprobs=None, output_top_logprobs=None,
) )
sample_output = SampleOutput(
sample_output.success[:raw_bs],
sample_output.probs[:raw_bs],
sample_output.batch_next_token_ids[:raw_bs],
)
# Extract logprobs # Extract logprobs
if batch.return_logprob: if batch.return_logprob:
output.next_token_logprobs = torch.nn.functional.log_softmax( logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
output.next_token_logits, dim=-1 logits_output.next_token_logits, dim=-1
) )
return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums) return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
if return_top_logprob: if return_top_logprob:
...@@ -327,8 +342,8 @@ class CudaGraphRunner: ...@@ -327,8 +342,8 @@ class CudaGraphRunner:
forward_mode=ForwardMode.DECODE, forward_mode=ForwardMode.DECODE,
top_logprobs_nums=batch.top_logprobs_nums, top_logprobs_nums=batch.top_logprobs_nums,
) )
output.output_top_logprobs = LogitsProcessor.get_top_logprobs( logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
output.next_token_logprobs, logits_metadata logits_output.next_token_logprobs, logits_metadata
)[1] )[1]
return output return sample_output, logits_output
from __future__ import annotations
""" """
Copyright 2023-2024 SGLang Team Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
...@@ -26,6 +28,7 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool ...@@ -26,6 +28,7 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
class ForwardMode(IntEnum): class ForwardMode(IntEnum):
...@@ -42,6 +45,7 @@ class InputMetadata: ...@@ -42,6 +45,7 @@ class InputMetadata:
"""Store all inforamtion of a forward pass.""" """Store all inforamtion of a forward pass."""
forward_mode: ForwardMode forward_mode: ForwardMode
sampling_info: SamplingBatchInfo
batch_size: int batch_size: int
req_pool_indices: torch.Tensor req_pool_indices: torch.Tensor
seq_lens: torch.Tensor seq_lens: torch.Tensor
...@@ -169,6 +173,7 @@ class InputMetadata: ...@@ -169,6 +173,7 @@ class InputMetadata:
): ):
ret = cls( ret = cls(
forward_mode=forward_mode, forward_mode=forward_mode,
sampling_info=batch.sampling_info,
batch_size=batch.batch_size(), batch_size=batch.batch_size(),
req_pool_indices=batch.req_pool_indices, req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens, seq_lens=batch.seq_lens,
...@@ -179,6 +184,8 @@ class InputMetadata: ...@@ -179,6 +184,8 @@ class InputMetadata:
top_logprobs_nums=batch.top_logprobs_nums, top_logprobs_nums=batch.top_logprobs_nums,
) )
ret.sampling_info.prepare_penalties()
ret.compute_positions(batch) ret.compute_positions(batch)
ret.compute_extend_infos(batch) ret.compute_extend_infos(batch)
......
...@@ -21,7 +21,7 @@ import importlib.resources ...@@ -21,7 +21,7 @@ import importlib.resources
import logging import logging
import pkgutil import pkgutil
from functools import lru_cache from functools import lru_cache
from typing import Optional, Type from typing import Optional, Tuple, Type
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -44,6 +44,8 @@ from vllm.model_executor.model_loader import get_model ...@@ -44,6 +44,8 @@ from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import SampleOutput
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import ( from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool, MHATokenToKVPool,
...@@ -524,7 +526,11 @@ class ModelRunner: ...@@ -524,7 +526,11 @@ class ModelRunner:
@torch.inference_mode() @torch.inference_mode()
def forward_decode(self, batch: ScheduleBatch): def forward_decode(self, batch: ScheduleBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)): if (
self.cuda_graph_runner
and self.cuda_graph_runner.can_run(len(batch.reqs))
and not batch.sampling_info.has_bias()
):
return self.cuda_graph_runner.replay(batch) return self.cuda_graph_runner.replay(batch)
input_metadata = InputMetadata.from_schedule_batch( input_metadata = InputMetadata.from_schedule_batch(
...@@ -573,7 +579,9 @@ class ModelRunner: ...@@ -573,7 +579,9 @@ class ModelRunner:
input_metadata.image_offsets, input_metadata.image_offsets,
) )
def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode): def forward(
self, batch: ScheduleBatch, forward_mode: ForwardMode
) -> Tuple[SampleOutput, LogitsProcessorOutput]:
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND: if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
return self.forward_extend_multi_modal(batch) return self.forward_extend_multi_modal(batch)
elif forward_mode == ForwardMode.DECODE: elif forward_mode == ForwardMode.DECODE:
......
...@@ -31,20 +31,18 @@ from vllm.model_executor.layers.linear import ( ...@@ -31,20 +31,18 @@ from vllm.model_executor.layers.linear import (
) )
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
LoraConfig = None LoraConfig = None
...@@ -383,17 +381,11 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -383,17 +381,11 @@ class ChatGLMForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata) hidden_states = self.transformer(input_ids, positions, input_metadata)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
def sample( return sample_output, logits_output
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
......
...@@ -64,6 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -64,6 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
...@@ -326,6 +327,7 @@ class CohereForCausalLM(nn.Module): ...@@ -326,6 +327,7 @@ class CohereForCausalLM(nn.Module):
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
self.model = CohereModel(config, quant_config) self.model = CohereModel(config, quant_config)
@torch.no_grad() @torch.no_grad()
...@@ -340,9 +342,11 @@ class CohereForCausalLM(nn.Module): ...@@ -340,9 +342,11 @@ class CohereForCausalLM(nn.Module):
positions, positions,
input_metadata, input_metadata,
) )
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -45,6 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig ...@@ -45,6 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
...@@ -382,6 +383,7 @@ class DbrxForCausalLM(nn.Module): ...@@ -382,6 +383,7 @@ class DbrxForCausalLM(nn.Module):
padding_size=DEFAULT_VOCAB_PADDING_SIZE, padding_size=DEFAULT_VOCAB_PADDING_SIZE,
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
...@@ -391,9 +393,11 @@ class DbrxForCausalLM(nn.Module): ...@@ -391,9 +393,11 @@ class DbrxForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata) hidden_states = self.transformer(input_ids, positions, input_metadata)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_params_mapping = [ expert_params_mapping = [
......
...@@ -46,6 +46,7 @@ from sglang.srt.layers.activation import SiluAndMul ...@@ -46,6 +46,7 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
...@@ -385,6 +386,7 @@ class DeepseekForCausalLM(nn.Module): ...@@ -385,6 +386,7 @@ class DeepseekForCausalLM(nn.Module):
config.vocab_size, config.hidden_size, quant_config=quant_config config.vocab_size, config.hidden_size, quant_config=quant_config
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
...@@ -394,9 +396,11 @@ class DeepseekForCausalLM(nn.Module): ...@@ -394,9 +396,11 @@ class DeepseekForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata) hidden_states = self.model(input_ids, positions, input_metadata)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -45,6 +45,7 @@ from sglang.srt.layers.activation import SiluAndMul ...@@ -45,6 +45,7 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
...@@ -632,6 +633,7 @@ class DeepseekV2ForCausalLM(nn.Module): ...@@ -632,6 +633,7 @@ class DeepseekV2ForCausalLM(nn.Module):
config.vocab_size, config.hidden_size, quant_config=quant_config config.vocab_size, config.hidden_size, quant_config=quant_config
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
def forward( def forward(
self, self,
...@@ -640,9 +642,11 @@ class DeepseekV2ForCausalLM(nn.Module): ...@@ -640,9 +642,11 @@ class DeepseekV2ForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata) hidden_states = self.model(input_ids, positions, input_metadata)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -37,6 +37,7 @@ from sglang.srt.layers.activation import GeluAndMul ...@@ -37,6 +37,7 @@ from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
...@@ -287,6 +288,7 @@ class GemmaForCausalLM(nn.Module): ...@@ -287,6 +288,7 @@ class GemmaForCausalLM(nn.Module):
self.quant_config = quant_config self.quant_config = quant_config
self.model = GemmaModel(config, quant_config=quant_config) self.model = GemmaModel(config, quant_config=quant_config)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
...@@ -297,9 +299,11 @@ class GemmaForCausalLM(nn.Module): ...@@ -297,9 +299,11 @@ class GemmaForCausalLM(nn.Module):
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return (sample_output, logits_output)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -37,6 +37,7 @@ from sglang.srt.layers.activation import GeluAndMul ...@@ -37,6 +37,7 @@ from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import GemmaRMSNorm from sglang.srt.layers.layernorm import GemmaRMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
...@@ -346,6 +347,7 @@ class Gemma2ForCausalLM(nn.Module): ...@@ -346,6 +347,7 @@ class Gemma2ForCausalLM(nn.Module):
self.quant_config = quant_config self.quant_config = quant_config
self.model = Gemma2Model(config, cache_config, quant_config) self.model = Gemma2Model(config, cache_config, quant_config)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
...@@ -356,9 +358,11 @@ class Gemma2ForCausalLM(nn.Module): ...@@ -356,9 +358,11 @@ class Gemma2ForCausalLM(nn.Module):
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def get_attention_sliding_window_size(self): def get_attention_sliding_window_size(self):
return get_attention_sliding_window_size(self.config) return get_attention_sliding_window_size(self.config)
......
...@@ -35,6 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -35,6 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
...@@ -261,6 +262,7 @@ class GPTBigCodeForCausalLM(nn.Module): ...@@ -261,6 +262,7 @@ class GPTBigCodeForCausalLM(nn.Module):
if lora_config: if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
...@@ -270,9 +272,11 @@ class GPTBigCodeForCausalLM(nn.Module): ...@@ -270,9 +272,11 @@ class GPTBigCodeForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata) hidden_states = self.transformer(input_ids, positions, input_metadata)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
......
...@@ -46,6 +46,7 @@ from sglang.srt.layers.fused_moe import FusedMoE ...@@ -46,6 +46,7 @@ from sglang.srt.layers.fused_moe import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
...@@ -297,6 +298,7 @@ class Grok1ForCausalLM(nn.Module): ...@@ -297,6 +298,7 @@ class Grok1ForCausalLM(nn.Module):
self.model = Grok1Model(config, quant_config=quant_config) self.model = Grok1Model(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
# Monkey patch _prepare_weights to load pre-sharded weights # Monkey patch _prepare_weights to load pre-sharded weights
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights) setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
...@@ -313,9 +315,11 @@ class Grok1ForCausalLM(nn.Module): ...@@ -313,9 +315,11 @@ class Grok1ForCausalLM(nn.Module):
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -40,6 +40,7 @@ from sglang.srt.layers.activation import SiluAndMul ...@@ -40,6 +40,7 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
...@@ -262,6 +263,7 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -262,6 +263,7 @@ class InternLM2ForCausalLM(nn.Module):
self.model = InternLM2Model(config, quant_config) self.model = InternLM2Model(config, quant_config)
self.output = ParallelLMHead(config.vocab_size, config.hidden_size) self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
...@@ -272,9 +274,11 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -272,9 +274,11 @@ class InternLM2ForCausalLM(nn.Module):
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.output.weight, input_metadata input_ids, hidden_states, self.output.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -39,8 +39,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -39,8 +39,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
...@@ -302,6 +303,7 @@ class LlamaForCausalLM(nn.Module): ...@@ -302,6 +303,7 @@ class LlamaForCausalLM(nn.Module):
self.model = LlamaModel(config, quant_config=quant_config) self.model = LlamaModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
...@@ -310,11 +312,13 @@ class LlamaForCausalLM(nn.Module): ...@@ -310,11 +312,13 @@ class LlamaForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> LogitProcessorOutput: ) -> LogitsProcessorOutput:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def get_module_name(self, name): def get_module_name(self, name):
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -24,7 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank ...@@ -24,7 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.models.llama2 import LlamaModel from sglang.srt.models.llama2 import LlamaModel
...@@ -65,7 +65,7 @@ class LlamaForClassification(nn.Module): ...@@ -65,7 +65,7 @@ class LlamaForClassification(nn.Module):
(input_metadata.batch_size, self.config.classification_out_size) (input_metadata.batch_size, self.config.classification_out_size)
).to(input_ids.device) ).to(input_ids.device)
return LogitProcessorOutput( return LogitsProcessorOutput(
next_token_logits=scores, next_token_logits=scores,
next_token_logprobs=scores, next_token_logprobs=scores,
normalized_prompt_logprobs=scores, normalized_prompt_logprobs=scores,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment