Unverified Commit 428dd144 authored by afeldman-nm's avatar afeldman-nm Committed by GitHub
Browse files

[Core] Logprobs support in Multi-step (#7652)

parent 4abed65c
...@@ -9,7 +9,8 @@ from vllm.config import CacheConfig, ModelConfig ...@@ -9,7 +9,8 @@ from vllm.config import CacheConfig, ModelConfig
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip, from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
get_open_port, make_async) get_open_port, make_async)
......
...@@ -12,7 +12,8 @@ from vllm.executor.distributed_gpu_executor import ( # yapf: disable ...@@ -12,7 +12,8 @@ from vllm.executor.distributed_gpu_executor import ( # yapf: disable
from vllm.executor.msgspec_utils import encode_hook from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, get_distributed_init_method, from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
get_ip, get_open_port, get_vllm_instance_id, get_ip, get_open_port, get_vllm_instance_id,
make_async) make_async)
......
...@@ -10,7 +10,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase ...@@ -10,7 +10,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.executor.tpu_executor import TPUExecutor from vllm.executor.tpu_executor import TPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async) get_vllm_instance_id, make_async)
......
...@@ -5,7 +5,8 @@ import torch ...@@ -5,7 +5,8 @@ import torch
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async) make_async)
......
...@@ -9,7 +9,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -9,7 +9,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase from vllm.worker.worker_base import WorkerBase
......
"""A layer that samples the next tokens from the model's outputs.""" """A layer that samples the next tokens from the model's outputs."""
import itertools import itertools
import warnings import warnings
from dataclasses import dataclass
from importlib.util import find_spec from importlib.util import find_spec
from math import inf from math import inf
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple, Union
import msgspec
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
if HAS_TRITON: if HAS_TRITON:
...@@ -19,8 +22,7 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata, ...@@ -19,8 +22,7 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SequenceGroupToSample) SequenceGroupToSample)
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SamplerOutput, PromptLogprobs, SampleLogprobs, SequenceOutput)
SequenceOutput)
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling import flashinfer.sampling
...@@ -35,6 +37,116 @@ else: ...@@ -35,6 +37,116 @@ else:
# (num_token_ids, num_parent_ids) per sequence group. # (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]] SampleResultType = List[Tuple[List[int], List[int]]]
# Types of temporary data structures used for
# computing sample_result
SampleMetadataType = Dict[SamplingType, Tuple[List[int],
List[SequenceGroupToSample]]]
MultinomialSamplesType = Dict[SamplingType, torch.Tensor]
SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]]
# Encapsulates temporary data structures for computing
# sample_result.
#
# * For multi-step scheduling: must be returned
# by `Sampler.forward()` and used later to compute the pythonized
# sample_result
#
# * For single-step scheduling: consumed immediately
# inside `Sampler.forward()` to compute pythonized sample_result.
@dataclass
class SampleResultArgsType:
sample_metadata: SampleMetadataType
multinomial_samples: MultinomialSamplesType
sample_results_dict: SampleResultsDictType
sampling_metadata: SamplingMetadata
greedy_samples: Optional[torch.Tensor]
beam_search_logprobs: Optional[torch.Tensor]
# Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling)
# sample result types
MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType]
# Abbreviation of the _sample() return type
SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]]
class SamplerOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""For each sequence group, we generate a list of SequenceOutput object,
each of which contains one possible candidate for the next token.
This data structure implements methods, so it can be used like a list, but
also has optional fields for device tensors.
"""
outputs: List[CompletionSequenceGroupOutput]
# On-device tensor containing probabilities of each token.
sampled_token_probs: Optional[torch.Tensor] = None
# On-device tensor containing the logprobs of each token.
logprobs: Optional["torch.Tensor"] = None
# Holds either (1) the pythonized sampler result (single-step scheduling)
# or (2) what will be arguments for later deferred pythonization of the
# sampler result (muliti-step scheduling)
deferred_sample_results_args: Optional[SampleResultArgsType] = None
# On-device tensor containing the sampled token ids.
sampled_token_ids: Optional[torch.Tensor] = None
# CPU tensor containing the sampled token ids. Used during multi-step to
# return the sampled token ids from last rank to AsyncLLMEngine to be
# 'broadcasted' to all other PP ranks for next step.
sampled_token_ids_cpu: Optional[torch.Tensor] = None
# Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
# Optional last hidden states from the model.
hidden_states: Optional[torch.Tensor] = None
# Optional prefill hidden states from the model
# (used for models like EAGLE).
prefill_hidden_states: Optional[torch.Tensor] = None
# Time taken in the forward pass for this across all workers
model_forward_time: Optional[float] = None
# Time taken in the model execute function. This will include model forward,
# block/sync across workers, cpu-gpu sync time and sampling time.
model_execute_time: Optional[float] = None
def __getitem__(self, idx: int):
return self.outputs[idx]
def __setitem__(self, idx: int, value):
self.outputs[idx] = value
def __len__(self):
return len(self.outputs)
def __eq__(self, other: object):
return isinstance(other,
self.__class__) and self.outputs == other.outputs
def __repr__(self) -> str:
"""Show the shape of a tensor instead of its values to reduce noise.
"""
sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
else self.sampled_token_probs.shape)
sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
self.sampled_token_ids.shape)
return (
f"SamplerOutput(outputs={self.outputs}, "
f"sampled_token_probs={sampled_token_probs_repr}, "
f"sampled_token_ids={sampled_token_ids_repr}, "
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
class Sampler(nn.Module): class Sampler(nn.Module):
"""Samples the next tokens from the model's outputs. """Samples the next tokens from the model's outputs.
...@@ -98,6 +210,19 @@ class Sampler(nn.Module): ...@@ -98,6 +210,19 @@ class Sampler(nn.Module):
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
""" """
Single-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Pythonize sampling result & logprobs tensor
Multi-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Defer Pythonization of sampling result & logprobs
tensor
* Encapsulate arguments required for deferred Pythonization
in the :class:`SamplerOutput` structure
Args: Args:
logits: (num_tokens, vocab_size). logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling. sampling_metadata: Metadata for sampling.
...@@ -150,7 +275,7 @@ class Sampler(nn.Module): ...@@ -150,7 +275,7 @@ class Sampler(nn.Module):
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens. # Sample the next tokens.
sample_results, maybe_sampled_tokens_tensor = _sample( maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
probs, probs,
logprobs, logprobs,
sampling_metadata, sampling_metadata,
...@@ -160,20 +285,28 @@ class Sampler(nn.Module): ...@@ -160,20 +285,28 @@ class Sampler(nn.Module):
) )
if self.include_gpu_probs_tensor: if self.include_gpu_probs_tensor:
# Since we will defer sampler result Pythonization,
# preserve GPU-side tensors in support of later
# deferred pythonization of logprobs
assert maybe_sampled_tokens_tensor is not None assert maybe_sampled_tokens_tensor is not None
on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
else: else:
# Since Pythonization has already happened, don't preserve
# GPU-side tensors.
on_device_tensors = None on_device_tensors = None
# Get the logprobs query results. # Get the logprobs query results.
prompt_logprobs = None prompt_logprobs = None
sample_logprobs = None sample_logprobs = None
if not sampling_metadata.skip_sampler_cpu_output: if not sampling_metadata.skip_sampler_cpu_output:
prompt_logprobs, sample_logprobs = _get_logprobs( # Pythonize logprobs now (GPU -> CPU); do not defer.
logprobs, sampling_metadata, sample_results) assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
prompt_logprobs, sample_logprobs = get_logprobs(
logprobs, sampling_metadata, maybe_deferred_sample_results)
return _build_sampler_output( return _build_sampler_output(
sample_results, maybe_deferred_sample_results,
sampling_metadata, sampling_metadata,
prompt_logprobs, prompt_logprobs,
sample_logprobs, sample_logprobs,
...@@ -543,6 +676,60 @@ def _top_k_top_p_multinomial_with_flashinfer( ...@@ -543,6 +676,60 @@ def _top_k_top_p_multinomial_with_flashinfer(
return batch_next_token_ids.view(-1, num_samples) return batch_next_token_ids.view(-1, num_samples)
def get_pythonized_sample_results(
sample_result_args: SampleResultArgsType) -> SampleResultType:
'''This function consumes GPU-side sampler results and computes
Pythonized CPU-side sampler results (GPU -> CPU sync.)
Single-step scheduling: this function is invoked at sampling-time
for immediate Pythonization.
Multi-step scheduling: Pythonization is deferred until after multiple
GPU-side steps have been completed.
Args:
sample_result_args: GPU-side inputs to the Pythonization process
Returns:
Pythonized sampler results
'''
(
sample_metadata,
sampling_metadata,
greedy_samples,
multinomial_samples,
beam_search_logprobs,
sample_results_dict,
) = (
sample_result_args.sample_metadata,
sample_result_args.sampling_metadata,
sample_result_args.greedy_samples,
sample_result_args.multinomial_samples,
sample_result_args.beam_search_logprobs,
sample_result_args.sample_results_dict,
)
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
continue
(seq_group_id, seq_groups) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(seq_groups,
multinomial_samples[sampling_type])
elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups,
beam_search_logprobs)
sample_results_dict.update(zip(seq_group_id, sample_results))
return [
sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups))
]
def _sample_with_torch( def _sample_with_torch(
probs: torch.Tensor, probs: torch.Tensor,
logprobs: torch.Tensor, logprobs: torch.Tensor,
...@@ -550,7 +737,19 @@ def _sample_with_torch( ...@@ -550,7 +737,19 @@ def _sample_with_torch(
sampling_tensors: SamplingTensors, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, include_gpu_probs_tensor: bool,
modify_greedy_probs: bool, modify_greedy_probs: bool,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]: ) -> SampleReturnType:
'''Torch-oriented _sample() implementation.
Single-step scheduling:
* Perform GPU-side sampling computation
* Immediately Pythonize sampling result
Multi-step scheduling:
* Perform GPU-side sampling computation
* Defer Pythonization & preserve GPU-side
tensors required for Pythonization
'''
categorized_seq_group_ids: Dict[SamplingType, categorized_seq_group_ids: Dict[SamplingType,
List[int]] = {t: [] List[int]] = {t: []
for t in SamplingType} for t in SamplingType}
...@@ -560,10 +759,11 @@ def _sample_with_torch( ...@@ -560,10 +759,11 @@ def _sample_with_torch(
sampling_type = sampling_params.sampling_type sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i) categorized_seq_group_ids[sampling_type].append(i)
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} sample_results_dict: SampleResultsDictType = {}
sample_metadata: Dict[SamplingType, sample_metadata: SampleMetadataType = {}
Tuple[List[int], List[SequenceGroupToSample]]] = {} multinomial_samples: MultinomialSamplesType = {}
multinomial_samples: Dict[SamplingType, torch.Tensor] = {} greedy_samples: Optional[torch.Tensor] = None
beam_search_logprobs: Optional[torch.Tensor] = None
# Create output tensor for sampled token ids. # Create output tensor for sampled token ids.
if include_gpu_probs_tensor: if include_gpu_probs_tensor:
...@@ -638,32 +838,29 @@ def _sample_with_torch( ...@@ -638,32 +838,29 @@ def _sample_with_torch(
else: else:
raise ValueError(f"Unsupported sampling type: {sampling_type}") raise ValueError(f"Unsupported sampling type: {sampling_type}")
# GPU<->CPU sync happens in the loop below. # Encapsulate arguments for computing Pythonized sampler
# This also converts the sample output to Python objects. # results, whether deferred or otherwise.
maybe_deferred_args = SampleResultArgsType(
sampling_metadata=sampling_metadata,
sample_metadata=sample_metadata,
multinomial_samples=multinomial_samples,
greedy_samples=greedy_samples,
beam_search_logprobs=beam_search_logprobs,
sample_results_dict=sample_results_dict)
if not sampling_metadata.skip_sampler_cpu_output: if not sampling_metadata.skip_sampler_cpu_output:
for sampling_type in SamplingType: # GPU<->CPU sync happens here.
if sampling_type not in sample_metadata: # This also converts the sampler output to a Python object.
continue # Return Pythonized sampler result & sampled token ids
(seq_group_id, seq_groups) = sample_metadata[sampling_type] return get_pythonized_sample_results(
if sampling_type == SamplingType.GREEDY: maybe_deferred_args), sampled_token_ids_tensor
sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type in (SamplingType.RANDOM,
SamplingType.RANDOM_SEED):
sample_results = _random_sample(
seq_groups, multinomial_samples[sampling_type])
elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups,
beam_search_logprobs)
sample_results_dict.update(zip(seq_group_id, sample_results))
sample_results = [
sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups))
]
else: else:
sample_results = [] # Defer sampler result Pythonization; return deferred
# Pythonization args & sampled token ids
return sample_results, sampled_token_ids_tensor return (
maybe_deferred_args,
sampled_token_ids_tensor,
)
def _sample_with_triton_kernel( def _sample_with_triton_kernel(
...@@ -755,7 +952,7 @@ def _sample( ...@@ -755,7 +952,7 @@ def _sample(
sampling_tensors: SamplingTensors, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, include_gpu_probs_tensor: bool,
modify_greedy_probs: bool, modify_greedy_probs: bool,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]: ) -> SampleReturnType:
""" """
Args: Args:
probs: (num_query_tokens_in_batch, num_vocab) probs: (num_query_tokens_in_batch, num_vocab)
...@@ -803,7 +1000,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: ...@@ -803,7 +1000,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
return result.sum(1).add_(1) return result.sum(1).add_(1)
def _get_logprobs( def get_logprobs(
logprobs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
sample_results: SampleResultType, sample_results: SampleResultType,
...@@ -1126,7 +1323,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, ...@@ -1126,7 +1323,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
def _build_sampler_output( def _build_sampler_output(
sample_results: SampleResultType, maybe_deferred_sample_results: MaybeDeferredSampleResultType,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
prompt_logprobs: Optional[List[Optional[PromptLogprobs]]], prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
sample_logprobs: Optional[List[SampleLogprobs]], sample_logprobs: Optional[List[SampleLogprobs]],
...@@ -1143,14 +1340,21 @@ def _build_sampler_output( ...@@ -1143,14 +1340,21 @@ def _build_sampler_output(
speculative decoding rejection sampling. speculative decoding rejection sampling.
""" """
sampler_output: List[CompletionSequenceGroupOutput] = [] sampler_output: List[CompletionSequenceGroupOutput] = []
if not skip_sampler_cpu_output:
if skip_sampler_cpu_output:
assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
deferred_sample_results_args = maybe_deferred_sample_results
else:
assert prompt_logprobs is not None assert prompt_logprobs is not None
assert sample_logprobs is not None assert sample_logprobs is not None
assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
deferred_sample_results_args = None
for (seq_group, sample_result, group_prompt_logprobs, for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(sampling_metadata.seq_groups, group_sample_logprobs) in zip(sampling_metadata.seq_groups,
sample_results, prompt_logprobs, maybe_deferred_sample_results,
sample_logprobs): prompt_logprobs, sample_logprobs):
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
next_token_ids, parent_ids = sample_result next_token_ids, parent_ids = sample_result
seq_outputs: List[SequenceOutput] = [] seq_outputs: List[SequenceOutput] = []
...@@ -1176,7 +1380,7 @@ def _build_sampler_output( ...@@ -1176,7 +1380,7 @@ def _build_sampler_output(
sampled_token_probs=sampled_token_probs, sampled_token_probs=sampled_token_probs,
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
logprobs=logprobs_tensor, logprobs=logprobs_tensor,
) deferred_sample_results_args=deferred_sample_results_args)
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
......
...@@ -10,9 +10,8 @@ from transformers import PretrainedConfig ...@@ -10,9 +10,8 @@ from transformers import PretrainedConfig
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
TORCH_DTYPE_TO_NEURON_AMP = { TORCH_DTYPE_TO_NEURON_AMP = {
"auto": "f32", "auto": "f32",
......
...@@ -15,9 +15,8 @@ from vllm.config import DeviceConfig, ModelConfig ...@@ -15,9 +15,8 @@ from vllm.config import DeviceConfig, ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import (LogitsProcessor, from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
_prune_hidden_states) _prune_hidden_states)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -23,13 +23,13 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -23,13 +23,13 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.deepspeedfp import ( from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig, DeepSpeedFPParameter) DeepSpeedFPConfig, DeepSpeedFPParameter)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.arctic import ArcticConfig from vllm.transformers_utils.configs.arctic import ArcticConfig
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -38,12 +38,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -38,12 +38,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
......
...@@ -34,12 +34,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -34,12 +34,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -13,13 +13,13 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs ...@@ -13,13 +13,13 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.opt import OPTModel from vllm.model_executor.models.opt import OPTModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData) SequenceData)
from .blip import (BlipVisionModel, dummy_image_for_blip, from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens) get_max_blip_image_tokens)
......
...@@ -34,12 +34,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -34,12 +34,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
......
...@@ -22,7 +22,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -22,7 +22,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -33,7 +33,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -33,7 +33,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData) SequenceData)
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
......
...@@ -20,12 +20,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -20,12 +20,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
......
...@@ -38,14 +38,14 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -38,14 +38,14 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, row_parallel_weight_loader) default_weight_loader, row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
@torch.compile @torch.compile
......
...@@ -17,13 +17,13 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -17,13 +17,13 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.dbrx import DbrxConfig from vllm.transformers_utils.configs.dbrx import DbrxConfig
......
...@@ -43,12 +43,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -43,12 +43,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
class DeepseekMLP(nn.Module): class DeepseekMLP(nn.Module):
......
...@@ -43,12 +43,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -43,12 +43,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
......
...@@ -5,12 +5,13 @@ import torch.nn as nn ...@@ -5,12 +5,13 @@ import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.eagle import EAGLEConfig from vllm.transformers_utils.configs.eagle import EAGLEConfig
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment