Commit 0640f227 authored by zhuwenwen
Browse files

Merge tag 'v0.6.0' into v0.6.0-dev

parents 82f1ffdf 32e7db25
...@@ -503,8 +503,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): ...@@ -503,8 +503,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
dtype: torch.dtype, dtype: torch.dtype,
short_factor: List[float], short_factor: List[float],
long_factor: List[float], long_factor: List[float],
short_mscale: float = 1.0, short_mscale: Optional[float] = None,
long_mscale: float = 1.0, long_mscale: Optional[float] = None,
): ):
super().__init__() super().__init__()
...@@ -523,18 +523,22 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): ...@@ -523,18 +523,22 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
self.base = base self.base = base
self.short_factor = short_factor self.short_factor = short_factor
self.long_factor = long_factor self.long_factor = long_factor
self.short_mscale = short_mscale
self.long_mscale = long_mscale
scale = (self.max_position_embeddings /
self.original_max_position_embeddings)
scale = self.max_position_embeddings / \
self.original_max_position_embeddings
if scale <= 1.0: if scale <= 1.0:
self.scaling_factor = 1.0 scaling_factor = 1.0
else: else:
self.scaling_factor = math.sqrt( scaling_factor = math.sqrt(
1 + math.log(scale) / 1 + math.log(scale) /
math.log(self.original_max_position_embeddings)) math.log(self.original_max_position_embeddings))
if short_mscale is None:
short_mscale = scaling_factor
if long_mscale is None:
long_mscale = scaling_factor
self.short_mscale = short_mscale
self.long_mscale = long_mscale
short_cache = self._compute_cos_sin_cache( short_cache = self._compute_cos_sin_cache(
original_max_position_embeddings, short_factor, short_mscale) original_max_position_embeddings, short_factor, short_mscale)
...@@ -571,8 +575,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): ...@@ -571,8 +575,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
inv_freq = self._compute_inv_freq(rescale_factors) inv_freq = self._compute_inv_freq(rescale_factors)
t = torch.arange(max_position_embeddings, dtype=torch.float) t = torch.arange(max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq) freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() * mscale * self.scaling_factor cos = freqs.cos() * mscale
sin = freqs.sin() * mscale * self.scaling_factor sin = freqs.sin() * mscale
cache = torch.cat((cos, sin), dim=-1) cache = torch.cat((cos, sin), dim=-1)
return cache return cache
......
"""A layer that samples the next tokens from the model's outputs.""" """A layer that samples the next tokens from the model's outputs."""
import itertools import itertools
import warnings import warnings
from dataclasses import dataclass
from importlib.util import find_spec from importlib.util import find_spec
from math import inf from math import inf
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple, Union
import msgspec
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
if HAS_TRITON: if HAS_TRITON:
...@@ -19,8 +22,7 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata, ...@@ -19,8 +22,7 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SequenceGroupToSample) SequenceGroupToSample)
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SamplerOutput, PromptLogprobs, SampleLogprobs, SequenceOutput)
SequenceOutput)
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling import flashinfer.sampling
...@@ -35,6 +37,116 @@ else: ...@@ -35,6 +37,116 @@ else:
# (num_token_ids, num_parent_ids) per sequence group. # (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]] SampleResultType = List[Tuple[List[int], List[int]]]
# Types of temporary data structures used for
# computing sample_result
SampleMetadataType = Dict[SamplingType, Tuple[List[int],
List[SequenceGroupToSample]]]
MultinomialSamplesType = Dict[SamplingType, torch.Tensor]
SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]]
# Encapsulates temporary data structures for computing
# sample_result.
#
# * For multi-step scheduling: must be returned
# by `Sampler.forward()` and used later to compute the pythonized
# sample_result
#
# * For single-step scheduling: consumed immediately
# inside `Sampler.forward()` to compute pythonized sample_result.
@dataclass
class SampleResultArgsType:
    """Intermediate sampler state needed to build the Pythonized result.

    An instance is what `get_pythonized_sample_results()` consumes. With
    single-step scheduling it is consumed right away inside the sampler;
    with multi-step scheduling it is returned from the sampler so that the
    Pythonization can be performed later.
    """

    # Per sampling type: (indices of the sequence groups using that type,
    # the corresponding sequence groups to sample).
    sample_metadata: SampleMetadataType
    # Per sampling type: tensor of multinomial samples (looked up for the
    # RANDOM and RANDOM_SEED sampling types).
    multinomial_samples: MultinomialSamplesType
    # Sequence-group index -> (next token ids, parent ids); filled in while
    # the result is being Pythonized.
    sample_results_dict: SampleResultsDictType
    # Sampling metadata of the batch; its `seq_groups` fixes the length and
    # ordering of the final result list.
    sampling_metadata: SamplingMetadata
    # Tensor passed to the greedy sampling helper, if any.
    greedy_samples: Optional[torch.Tensor]
    # Tensor passed to the beam-search sampling helper, if any.
    beam_search_logprobs: Optional[torch.Tensor]
# Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling)
# sample result types
MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType]
# Abbreviation of the _sample() return type
SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]]
class SamplerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """Result of one sampler invocation, one entry per sequence group.

    Every element of `outputs` is the list of SequenceOutput candidates
    generated for one sequence group, each holding one possible next token.

    Indexing, item assignment and `len()` are forwarded to `outputs`, so the
    object can be handled like a list; the remaining optional fields carry
    device tensors, deferred-Pythonization arguments, metrics and timings.
    """

    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # Holds either (1) the pythonized sampler result (single-step scheduling)
    # or (2) what will be arguments for later deferred pythonization of the
    # sampler result (multi-step scheduling)
    deferred_sample_results_args: Optional[SampleResultArgsType] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None

    # CPU tensor containing the sampled token ids. Used during multi-step to
    # return the sampled token ids from last rank to AsyncLLMEngine to be
    # 'broadcasted' to all other PP ranks for next step.
    sampled_token_ids_cpu: Optional[torch.Tensor] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

    # Optional prefill hidden states from the model
    # (used for models like EAGLE).
    prefill_hidden_states: Optional[torch.Tensor] = None

    # Time taken in the forward pass for this across all workers
    model_forward_time: Optional[float] = None

    # Time taken in the model execute function. This will include model
    # forward, block/sync across workers, cpu-gpu sync time and sampling time.
    model_execute_time: Optional[float] = None

    def __getitem__(self, idx: int):
        """Return the output stored at position `idx` of `outputs`."""
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        """Replace the output stored at position `idx` of `outputs`."""
        self.outputs[idx] = value

    def __len__(self):
        """Return the number of entries in `outputs`."""
        return len(self.outputs)

    def __eq__(self, other: object):
        """Compare by `outputs` only; other fields are not considered."""
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs

    def __repr__(self) -> str:
        """Render tensors by shape rather than by value to keep logs short."""
        probs = self.sampled_token_probs
        token_ids = self.sampled_token_ids
        probs_repr = "None" if probs is None else probs.shape
        token_ids_repr = "None" if token_ids is None else token_ids.shape
        return (
            f"SamplerOutput(outputs={self.outputs}, "
            f"sampled_token_probs={probs_repr}, "
            f"sampled_token_ids={token_ids_repr}, "
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
class Sampler(nn.Module): class Sampler(nn.Module):
"""Samples the next tokens from the model's outputs. """Samples the next tokens from the model's outputs.
...@@ -98,6 +210,19 @@ class Sampler(nn.Module): ...@@ -98,6 +210,19 @@ class Sampler(nn.Module):
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
""" """
Single-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Pythonize sampling result & logprobs tensor
Multi-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Defer Pythonization of sampling result & logprobs
tensor
* Encapsulate arguments required for deferred Pythonization
in the :class:`SamplerOutput` structure
Args: Args:
logits: (num_tokens, vocab_size). logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling. sampling_metadata: Metadata for sampling.
...@@ -150,7 +275,7 @@ class Sampler(nn.Module): ...@@ -150,7 +275,7 @@ class Sampler(nn.Module):
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens. # Sample the next tokens.
sample_results, maybe_sampled_tokens_tensor = _sample( maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
probs, probs,
logprobs, logprobs,
sampling_metadata, sampling_metadata,
...@@ -160,20 +285,28 @@ class Sampler(nn.Module): ...@@ -160,20 +285,28 @@ class Sampler(nn.Module):
) )
if self.include_gpu_probs_tensor: if self.include_gpu_probs_tensor:
# Since we will defer sampler result Pythonization,
# preserve GPU-side tensors in support of later
# deferred pythonization of logprobs
assert maybe_sampled_tokens_tensor is not None assert maybe_sampled_tokens_tensor is not None
on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
else: else:
# Since Pythonization has already happened, don't preserve
# GPU-side tensors.
on_device_tensors = None on_device_tensors = None
# Get the logprobs query results. # Get the logprobs query results.
prompt_logprobs = None prompt_logprobs = None
sample_logprobs = None sample_logprobs = None
if not sampling_metadata.skip_sampler_cpu_output: if not sampling_metadata.skip_sampler_cpu_output:
prompt_logprobs, sample_logprobs = _get_logprobs( # Pythonize logprobs now (GPU -> CPU); do not defer.
logprobs, sampling_metadata, sample_results) assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
prompt_logprobs, sample_logprobs = get_logprobs(
logprobs, sampling_metadata, maybe_deferred_sample_results)
return _build_sampler_output( return _build_sampler_output(
sample_results, maybe_deferred_sample_results,
sampling_metadata, sampling_metadata,
prompt_logprobs, prompt_logprobs,
sample_logprobs, sample_logprobs,
...@@ -543,6 +676,60 @@ def _top_k_top_p_multinomial_with_flashinfer( ...@@ -543,6 +676,60 @@ def _top_k_top_p_multinomial_with_flashinfer(
return batch_next_token_ids.view(-1, num_samples) return batch_next_token_ids.view(-1, num_samples)
def get_pythonized_sample_results(
        sample_result_args: SampleResultArgsType) -> SampleResultType:
    """Turn the sampler's intermediate state into Python sample results.

    This is the step where sampler results are converted to Python objects
    (GPU -> CPU sync).

    Single-step scheduling: called at sampling time, so the conversion
    happens immediately.

    Multi-step scheduling: called later, once several GPU-side steps have
    been completed.

    Note that `sample_result_args.sample_results_dict` is updated in place.

    Args:
        sample_result_args: GPU-side inputs to the Pythonization process

    Returns:
        One `(next_token_ids, parent_ids)` tuple per sequence group, in the
        order of `sampling_metadata.seq_groups`; groups without a result get
        `([], [])`.
    """
    args = sample_result_args
    per_type_metadata = args.sample_metadata
    results_by_group = args.sample_results_dict

    active_types = (t for t in SamplingType if t in per_type_metadata)
    for sampling_type in active_types:
        group_ids, groups = per_type_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(groups, args.greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(
                groups, args.multinomial_samples[sampling_type])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(groups,
                                                 args.beam_search_logprobs)
        # Record each group's result under its sequence-group index.
        results_by_group.update(zip(group_ids, sample_results))

    num_groups = len(args.sampling_metadata.seq_groups)
    return [results_by_group.get(i, ([], [])) for i in range(num_groups)]
def _sample_with_torch( def _sample_with_torch(
probs: torch.Tensor, probs: torch.Tensor,
logprobs: torch.Tensor, logprobs: torch.Tensor,
...@@ -550,7 +737,19 @@ def _sample_with_torch( ...@@ -550,7 +737,19 @@ def _sample_with_torch(
sampling_tensors: SamplingTensors, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, include_gpu_probs_tensor: bool,
modify_greedy_probs: bool, modify_greedy_probs: bool,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]: ) -> SampleReturnType:
'''Torch-oriented _sample() implementation.
Single-step scheduling:
* Perform GPU-side sampling computation
* Immediately Pythonize sampling result
Multi-step scheduling:
* Perform GPU-side sampling computation
* Defer Pythonization & preserve GPU-side
tensors required for Pythonization
'''
categorized_seq_group_ids: Dict[SamplingType, categorized_seq_group_ids: Dict[SamplingType,
List[int]] = {t: [] List[int]] = {t: []
for t in SamplingType} for t in SamplingType}
...@@ -560,10 +759,11 @@ def _sample_with_torch( ...@@ -560,10 +759,11 @@ def _sample_with_torch(
sampling_type = sampling_params.sampling_type sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i) categorized_seq_group_ids[sampling_type].append(i)
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} sample_results_dict: SampleResultsDictType = {}
sample_metadata: Dict[SamplingType, sample_metadata: SampleMetadataType = {}
Tuple[List[int], List[SequenceGroupToSample]]] = {} multinomial_samples: MultinomialSamplesType = {}
multinomial_samples: Dict[SamplingType, torch.Tensor] = {} greedy_samples: Optional[torch.Tensor] = None
beam_search_logprobs: Optional[torch.Tensor] = None
# Create output tensor for sampled token ids. # Create output tensor for sampled token ids.
if include_gpu_probs_tensor: if include_gpu_probs_tensor:
...@@ -638,32 +838,29 @@ def _sample_with_torch( ...@@ -638,32 +838,29 @@ def _sample_with_torch(
else: else:
raise ValueError(f"Unsupported sampling type: {sampling_type}") raise ValueError(f"Unsupported sampling type: {sampling_type}")
# GPU<->CPU sync happens in the loop below. # Encapsulate arguments for computing Pythonized sampler
# This also converts the sample output to Python objects. # results, whether deferred or otherwise.
maybe_deferred_args = SampleResultArgsType(
sampling_metadata=sampling_metadata,
sample_metadata=sample_metadata,
multinomial_samples=multinomial_samples,
greedy_samples=greedy_samples,
beam_search_logprobs=beam_search_logprobs,
sample_results_dict=sample_results_dict)
if not sampling_metadata.skip_sampler_cpu_output: if not sampling_metadata.skip_sampler_cpu_output:
for sampling_type in SamplingType: # GPU<->CPU sync happens here.
if sampling_type not in sample_metadata: # This also converts the sampler output to a Python object.
continue # Return Pythonized sampler result & sampled token ids
(seq_group_id, seq_groups) = sample_metadata[sampling_type] return get_pythonized_sample_results(
if sampling_type == SamplingType.GREEDY: maybe_deferred_args), sampled_token_ids_tensor
sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type in (SamplingType.RANDOM,
SamplingType.RANDOM_SEED):
sample_results = _random_sample(
seq_groups, multinomial_samples[sampling_type])
elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups,
beam_search_logprobs)
sample_results_dict.update(zip(seq_group_id, sample_results))
sample_results = [
sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups))
]
else: else:
sample_results = [] # Defer sampler result Pythonization; return deferred
# Pythonization args & sampled token ids
return sample_results, sampled_token_ids_tensor return (
maybe_deferred_args,
sampled_token_ids_tensor,
)
def _sample_with_triton_kernel( def _sample_with_triton_kernel(
...@@ -755,7 +952,7 @@ def _sample( ...@@ -755,7 +952,7 @@ def _sample(
sampling_tensors: SamplingTensors, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, include_gpu_probs_tensor: bool,
modify_greedy_probs: bool, modify_greedy_probs: bool,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]: ) -> SampleReturnType:
""" """
Args: Args:
probs: (num_query_tokens_in_batch, num_vocab) probs: (num_query_tokens_in_batch, num_vocab)
...@@ -803,7 +1000,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: ...@@ -803,7 +1000,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
return result.sum(1).add_(1) return result.sum(1).add_(1)
def _get_logprobs( def get_logprobs(
logprobs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
sample_results: SampleResultType, sample_results: SampleResultType,
...@@ -1126,7 +1323,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, ...@@ -1126,7 +1323,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
def _build_sampler_output( def _build_sampler_output(
sample_results: SampleResultType, maybe_deferred_sample_results: MaybeDeferredSampleResultType,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
prompt_logprobs: Optional[List[Optional[PromptLogprobs]]], prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
sample_logprobs: Optional[List[SampleLogprobs]], sample_logprobs: Optional[List[SampleLogprobs]],
...@@ -1143,14 +1340,21 @@ def _build_sampler_output( ...@@ -1143,14 +1340,21 @@ def _build_sampler_output(
speculative decoding rejection sampling. speculative decoding rejection sampling.
""" """
sampler_output: List[CompletionSequenceGroupOutput] = [] sampler_output: List[CompletionSequenceGroupOutput] = []
if not skip_sampler_cpu_output:
if skip_sampler_cpu_output:
assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
deferred_sample_results_args = maybe_deferred_sample_results
else:
assert prompt_logprobs is not None assert prompt_logprobs is not None
assert sample_logprobs is not None assert sample_logprobs is not None
assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
deferred_sample_results_args = None
for (seq_group, sample_result, group_prompt_logprobs, for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(sampling_metadata.seq_groups, group_sample_logprobs) in zip(sampling_metadata.seq_groups,
sample_results, prompt_logprobs, maybe_deferred_sample_results,
sample_logprobs): prompt_logprobs, sample_logprobs):
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
next_token_ids, parent_ids = sample_result next_token_ids, parent_ids = sample_result
seq_outputs: List[SequenceOutput] = [] seq_outputs: List[SequenceOutput] = []
...@@ -1176,7 +1380,7 @@ def _build_sampler_output( ...@@ -1176,7 +1380,7 @@ def _build_sampler_output(
sampled_token_probs=sampled_token_probs, sampled_token_probs=sampled_token_probs,
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
logprobs=logprobs_tensor, logprobs=logprobs_tensor,
) deferred_sample_results_args=deferred_sample_results_args)
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
......
...@@ -130,29 +130,35 @@ class SpecDecodeBaseSampler(nn.Module): ...@@ -130,29 +130,35 @@ class SpecDecodeBaseSampler(nn.Module):
def _raise_if_incorrect_input( def _raise_if_incorrect_input(
self, self,
target_probs: torch.Tensor, target_with_bonus_probs: torch.Tensor,
draft_token_ids: torch.Tensor, draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None, draft_probs: Optional[torch.Tensor] = None,
) -> None: ) -> None:
self._raise_if_incorrect_shape(target_probs, draft_token_ids, self._raise_if_incorrect_shape(target_with_bonus_probs,
bonus_token_ids, draft_probs) draft_token_ids, bonus_token_ids,
self._raise_if_incorrect_dtype(target_probs, draft_token_ids, draft_probs)
bonus_token_ids, draft_probs) self._raise_if_incorrect_dtype(target_with_bonus_probs,
self._raise_if_inconsistent_device(target_probs, draft_token_ids, draft_token_ids, bonus_token_ids,
bonus_token_ids, draft_probs) draft_probs)
self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], self._raise_if_inconsistent_device(target_with_bonus_probs,
draft_token_ids, bonus_token_ids,
draft_probs)
self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1],
draft_token_ids, bonus_token_ids) draft_token_ids, bonus_token_ids)
def _raise_if_incorrect_shape( def _raise_if_incorrect_shape(
self, self,
target_probs: torch.Tensor, target_with_bonus_probs: torch.Tensor,
draft_token_ids: torch.Tensor, draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None, draft_probs: Optional[torch.Tensor] = None,
) -> None: ) -> None:
(target_batch_size, num_target_probs, (target_batch_size, num_target_probs,
target_vocab_size) = target_probs.shape target_vocab_size) = target_with_bonus_probs.shape
# Does not count the extra token
num_target_probs -= 1
# validate the shape of draft token ids. # validate the shape of draft token ids.
draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
...@@ -175,12 +181,12 @@ class SpecDecodeBaseSampler(nn.Module): ...@@ -175,12 +181,12 @@ class SpecDecodeBaseSampler(nn.Module):
def _raise_if_incorrect_dtype( def _raise_if_incorrect_dtype(
self, self,
target_probs: torch.Tensor, target_with_bonus_probs: torch.Tensor,
draft_token_ids: torch.Tensor, draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None, draft_probs: Optional[torch.Tensor] = None,
) -> None: ) -> None:
assert target_probs.dtype == self.probs_dtype assert target_with_bonus_probs.dtype == self.probs_dtype
assert draft_token_ids.dtype == self.token_id_dtype assert draft_token_ids.dtype == self.token_id_dtype
assert bonus_token_ids.dtype == self.token_id_dtype assert bonus_token_ids.dtype == self.token_id_dtype
if draft_probs is not None: if draft_probs is not None:
...@@ -188,15 +194,16 @@ class SpecDecodeBaseSampler(nn.Module): ...@@ -188,15 +194,16 @@ class SpecDecodeBaseSampler(nn.Module):
def _raise_if_inconsistent_device( def _raise_if_inconsistent_device(
self, self,
target_probs: torch.Tensor, target_with_bonus_probs: torch.Tensor,
draft_token_ids: torch.Tensor, draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None, draft_probs: Optional[torch.Tensor] = None,
) -> None: ) -> None:
devices = [ devices = [
t.device for t in t.device for t in [
[target_probs, bonus_token_ids, draft_probs, draft_token_ids] target_with_bonus_probs, bonus_token_ids, draft_probs,
if t is not None draft_token_ids
] if t is not None
] ]
assert all([devices[0] == device for device in devices]) assert all([devices[0] == device for device in devices])
...@@ -220,7 +227,7 @@ class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler): ...@@ -220,7 +227,7 @@ class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
@abstractmethod @abstractmethod
def forward( def forward(
self, self,
target_probs: torch.Tensor, target_with_bonus_probs: torch.Tensor,
bonus_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor, draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor, draft_token_ids: torch.Tensor,
...@@ -236,7 +243,7 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler): ...@@ -236,7 +243,7 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
@abstractmethod @abstractmethod
def forward( def forward(
self, self,
target_probs: torch.Tensor, target_with_bonus_probs: torch.Tensor,
bonus_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor, draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor, draft_token_ids: torch.Tensor,
......
...@@ -41,7 +41,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -41,7 +41,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
def forward( def forward(
self, self,
target_probs: torch.Tensor, target_with_bonus_probs: torch.Tensor,
bonus_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor, draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor, draft_token_ids: torch.Tensor,
...@@ -80,8 +80,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -80,8 +80,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
# Only perform shape/dtype/device checking in strict mode, as it adds # Only perform shape/dtype/device checking in strict mode, as it adds
# overhead. # overhead.
if self._strict_mode: if self._strict_mode:
self._raise_if_incorrect_input(target_probs, draft_token_ids, self._raise_if_incorrect_input(target_with_bonus_probs,
bonus_token_ids) draft_token_ids, bonus_token_ids)
target_probs = target_with_bonus_probs[:, :-1]
accepted = self._evaluate_accepted_tokens(target_probs, accepted = self._evaluate_accepted_tokens(target_probs,
draft_token_ids) draft_token_ids)
recovered_token_ids = self._replacement_token_ids(target_probs) recovered_token_ids = self._replacement_token_ids(target_probs)
......
...@@ -10,6 +10,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -10,6 +10,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64 DEFAULT_VOCAB_PADDING_SIZE = 64
...@@ -351,7 +352,10 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -351,7 +352,10 @@ class VocabParallelEmbedding(torch.nn.Module):
param.weight_type = loaded_weight.item() param.weight_type = loaded_weight.item()
return return
elif isinstance(param, UninitializedParameter): elif isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) shape = list(loaded_weight.shape)
if output_dim is not None:
shape[output_dim] = shape[output_dim] // self.tp_size
param.materialize(tuple(shape), dtype=loaded_weight.dtype)
# If parameter does not have output dim, then it should # If parameter does not have output dim, then it should
# be copied onto all gpus (e.g. g_idx for act_order gptq). # be copied onto all gpus (e.g. g_idx for act_order gptq).
...@@ -367,10 +371,12 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -367,10 +371,12 @@ class VocabParallelEmbedding(torch.nn.Module):
# If param packed on the same dim we are sharding on, then # If param packed on the same dim we are sharding on, then
# need to adjust offsets of loaded weight by pack_factor. # need to adjust offsets of loaded weight by pack_factor.
if packed_dim is not None and packed_dim == output_dim: if packed_dim is not None and packed_dim == output_dim:
packed_factor = param.packed_factor if isinstance(
param, BasevLLMParameter) else param.pack_factor
assert loaded_weight.shape[output_dim] == (self.org_vocab_size // assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
param.pack_factor) param.packed_factor)
start_idx = start_idx // param.pack_factor start_idx = start_idx // packed_factor
shard_size = shard_size // param.pack_factor shard_size = shard_size // packed_factor
else: else:
assert loaded_weight.shape[output_dim] == self.org_vocab_size assert loaded_weight.shape[output_dim] == self.org_vocab_size
......
...@@ -774,7 +774,11 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -774,7 +774,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return pt_weights_iterator(hf_weights_files) return pt_weights_iterator(hf_weights_files)
def _get_quantized_weights_iterator( def _get_quantized_weights_iterator(
self, model_name_or_path: str, revision: Optional[str], pre_quant: bool self,
model_name_or_path: str,
revision: Optional[str],
pre_quant: bool,
load_8bit: bool,
) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
Any]]: Any]]:
"""Get an iterator to the model weights with bitsandbytes quantization, """Get an iterator to the model weights with bitsandbytes quantization,
...@@ -783,11 +787,9 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -783,11 +787,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# only load the bitsandbytes module when needed # only load the bitsandbytes module when needed
try: try:
import bitsandbytes import bitsandbytes
from bitsandbytes.functional import QuantState
if bitsandbytes.__version__ < "0.42.0": if bitsandbytes.__version__ < "0.42.0":
raise ImportError("bitsandbytes version is wrong. Please " raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.42.0.") "install bitsandbytes>=0.42.0.")
from bitsandbytes.functional import quantize_4bit
except ImportError as err: except ImportError as err:
raise ImportError("Please install bitsandbytes>=0.42.0 via " raise ImportError("Please install bitsandbytes>=0.42.0 via "
"`pip install bitsandbytes>=0.42.0` to use " "`pip install bitsandbytes>=0.42.0` to use "
...@@ -796,80 +798,111 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -796,80 +798,111 @@ class BitsAndBytesModelLoader(BaseModelLoader):
hf_weights_files, use_safetensors = self._prepare_weights( hf_weights_files, use_safetensors = self._prepare_weights(
model_name_or_path, revision) model_name_or_path, revision)
quant_state_dict = {} quant_state_dict: Dict[str, Any] = {}
def quantized_checkpoint() -> Generator:
# First iterate over all quant state weights
weight_iterator = self._hf_weight_iter(hf_weights_files,
use_safetensors)
temp_state_dict = {}
for weight_name, weight_tensor in weight_iterator:
if weight_name.endswith(".weight"):
continue
# TODO: only nf4 quantization is supported for now
if weight_name.endswith(".quant_state.bitsandbytes__fp4"):
raise NotImplementedError(
"Only bitsandbytes_nf4 quantization"
f"is supported for now. {weight_name} is fp4 quantized"
)
temp_state_dict[weight_name] = weight_tensor
# Closure to parse quant_state for each prequant weight if pre_quant:
def _parse_quant_state(param_name: str, if load_8bit:
temp_state_dict: Dict) -> QuantState: return self._quantized_8bit_generator(
quant_state = {} hf_weights_files, use_safetensors,
for k in temp_state_dict: quant_state_dict), quant_state_dict
if param_name + "." in k: else:
quant_state[k] = temp_state_dict[k] return self._quantized_4bit_generator(
# bitsandbytes library requires hf_weights_files, use_safetensors,
# weight.quant_state.bitsandbytes__nf4 in CPU quant_state_dict), quant_state_dict
quant_state[param_name +
".quant_state.bitsandbytes__nf4"] = quant_state[
param_name +
".quant_state.bitsandbytes__nf4"].cpu().data
return QuantState.from_dict(quant_state, device="cuda")
# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
# Filter out all weights whose suffix is not ".weight"
if not weight_name.endswith(".weight"):
continue
if weight_name + ".quant_state.bitsandbytes__nf4" \
in temp_state_dict:
quant_state = _parse_quant_state(weight_name,
temp_state_dict)
weight_name = weight_name.replace(".weight", ".qweight")
quant_state_dict[weight_name] = quant_state
yield weight_name.replace(".weight",
".qweight"), weight_tensor
else:
yield weight_name, weight_tensor
def generator() -> Generator:
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if any(target_module in weight_name
for target_module in self.target_modules):
weight_name = weight_name.replace(".weight", ".qweight")
# bitsandbytes requires data in GPU
loaded_weight = weight_tensor.cuda().data
with set_default_torch_dtype(torch.float32):
processed_weight, quant_state = quantize_4bit(
loaded_weight,
compress_statistics=True,
quant_type="nf4")
quant_state_dict[weight_name] = quant_state
else:
processed_weight = weight_tensor
yield weight_name, processed_weight return self._unquantized_generator(hf_weights_files, use_safetensors,
quant_state_dict), quant_state_dict
if pre_quant: def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
return quantized_checkpoint(), quant_state_dict quant_state_dict) -> Generator:
return generator(), quant_state_dict for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if not weight_name.lower().endswith(".scb"):
continue
weight_key = weight_name.lower().replace(".scb", ".qweight")
quant_state_dict[weight_key] = weight_tensor
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if not weight_name.endswith(".weight"):
continue
qweight_name = weight_name.replace(".weight", ".qweight")
if qweight_name in quant_state_dict:
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
yield qweight_name, weight_tensor
else:
yield weight_name, weight_tensor
def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
                              quant_state_dict) -> Generator:
    """Yield weights from a checkpoint pre-quantized with bitsandbytes 4-bit.

    The checkpoint is walked twice via ``self._hf_weight_iter``: once to
    collect the quantization metadata tensors, and once to yield the
    ``.weight`` tensors themselves.

    Args:
        hf_weights_files: Forwarded unchanged to ``self._hf_weight_iter``.
        use_safetensors: Forwarded unchanged to ``self._hf_weight_iter``.
        quant_state_dict: Output mapping, filled in place. For every
            pre-quantized weight it receives
            ``"<prefix>.qweight" -> QuantState``.

    Yields:
        ``(name, tensor)`` pairs for every tensor whose name ends with
        ``".weight"``. Pre-quantized weights are yielded under the name
        with the trailing ``".weight"`` replaced by ``".qweight"``; all
        other weights keep their name. Tensors that do not end with
        ``".weight"`` are never yielded.
    """
    from bitsandbytes.functional import QuantState

    weight_suffix = ".weight"
    quant_state_markers = (".quant_state.bitsandbytes__nf4",
                           ".quant_state.bitsandbytes__fp4")

    # First iterate over all quant state weights
    temp_state_dict = {}
    for weight_name, weight_tensor in self._hf_weight_iter(
            hf_weights_files, use_safetensors):
        if weight_name.endswith(weight_suffix):
            continue
        # bitsandbytes library requires
        # weight.quant_state.bitsandbytes__* in CPU
        if "quant_state.bitsandbytes" in weight_name:
            temp_state_dict[weight_name] = weight_tensor.cpu().data
        else:
            temp_state_dict[weight_name] = weight_tensor

    # Closure to parse quant_state for each prequant weight
    def _parse_quant_state(param_name: str,
                           temp_state_dict: Dict) -> QuantState:
        # Match on the key *prefix*: a plain substring test would also pick
        # up metadata of another parameter whose name merely contains
        # `param_name` (e.g. "a.weight" vs "x.a.weight.absmax") and hand
        # those foreign tensors to QuantState.from_dict.
        prefix = param_name + "."
        quant_state = {
            key: value
            for key, value in temp_state_dict.items()
            if key.startswith(prefix)
        }
        return QuantState.from_dict(quant_state, device="cuda")

    # Second iterate over all prequant and normal weights
    # pre quantized weights would have a quant_state
    for weight_name, weight_tensor in self._hf_weight_iter(
            hf_weights_files, use_safetensors):
        # Filter out all weights whose suffix is not ".weight"
        if not weight_name.endswith(weight_suffix):
            continue
        if any(weight_name + marker in temp_state_dict
               for marker in quant_state_markers):
            quant_state = _parse_quant_state(weight_name, temp_state_dict)
            # Rename only the trailing ".weight"; str.replace would also
            # rewrite an earlier ".weight" occurring inside the name.
            qweight_name = weight_name[:-len(weight_suffix)] + ".qweight"
            quant_state_dict[qweight_name] = quant_state
            yield qweight_name, weight_tensor
        else:
            yield weight_name, weight_tensor
def _unquantized_generator(self, hf_weights_files, use_safetensors,
                           quant_state_dict) -> Generator:
    """Yield checkpoint weights, quantizing target-module weights to NF4.

    Args:
        hf_weights_files: Forwarded unchanged to ``self._hf_weight_iter``.
        use_safetensors: Forwarded unchanged to ``self._hf_weight_iter``.
        quant_state_dict: Output mapping, filled in place with
            ``"<prefix>.qweight" -> quant_state`` for every tensor that was
            quantized here.

    Yields:
        ``(name, tensor)`` for every tensor in the checkpoint. A tensor is
        quantized (and renamed from ``.weight`` to ``.qweight``) only when
        its name contains one of ``self.target_modules`` AND ends with
        ``".weight"``; everything else is passed through untouched.
    """
    from bitsandbytes.functional import quantize_4bit

    weight_suffix = ".weight"
    for weight_name, weight_tensor in self._hf_weight_iter(
            hf_weights_files, use_safetensors):
        in_target_module = any(target_module in weight_name
                               for target_module in self.target_modules)
        # Only the ".weight" tensor of a target module is quantized.
        # Without the suffix check, biases and any other tensor living
        # under a target module would be quantized as well and recorded in
        # quant_state_dict under their original (non-qweight) name.
        if in_target_module and weight_name.endswith(weight_suffix):
            qweight_name = weight_name[:-len(weight_suffix)] + ".qweight"
            # bitsandbytes requires data in GPU
            loaded_weight = weight_tensor.cuda().data
            with set_default_torch_dtype(torch.float32):
                processed_weight, quant_state = quantize_4bit(
                    loaded_weight,
                    compress_statistics=True,
                    quant_type="nf4")
            quant_state_dict[qweight_name] = quant_state
            yield qweight_name, processed_weight
        else:
            yield weight_name, weight_tensor
def _load_weights(self, model_config: ModelConfig, def _load_weights(self, model_config: ModelConfig,
model: nn.Module) -> None: model: nn.Module) -> None:
...@@ -886,16 +919,26 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -886,16 +919,26 @@ class BitsAndBytesModelLoader(BaseModelLoader):
logger.info("Loading weights with BitsAndBytes quantization. " logger.info("Loading weights with BitsAndBytes quantization. "
" May take a while ...") " May take a while ...")
is_quantized_checkpoint = False
quant_config = getattr(model_config.hf_config, "quantization_config", quant_config = getattr(model_config.hf_config, "quantization_config",
None) None)
if quant_config is not None and quant_config.get(
'quant_method') == "bitsandbytes": pre_quant = False
is_quantized_checkpoint = True if quant_config is not None:
quant_method = quant_config.get('quant_method')
if quant_method == "bitsandbytes":
pre_quant = True
else:
raise ValueError(
f"BitsAndBytes loader does not support {quant_method} "
"quantization")
load_8bit = False
if pre_quant:
load_8bit = quant_config.get('load_in_8bit', False)
qweight_iterator, quant_state_dict = \ qweight_iterator, quant_state_dict = \
self._get_quantized_weights_iterator( self._get_quantized_weights_iterator(
model_config.model, model_config.revision, is_quantized_checkpoint) model_config.model, model_config.revision, pre_quant, load_8bit)
model.load_weights(qweight_iterator) model.load_weights(qweight_iterator)
...@@ -945,6 +988,10 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -945,6 +988,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
offsets = np.concatenate(([0], np.cumsum(num_elements))) offsets = np.concatenate(([0], np.cumsum(num_elements)))
set_weight_attrs(param, {"bnb_shard_offsets": offsets}) set_weight_attrs(param, {"bnb_shard_offsets": offsets})
if load_8bit:
set_weight_attrs(
param, {"matmul_state": [None] * len(quant_states)})
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
......
"""Utilities for selecting and loading neuron models.""" """Utilities for selecting and loading neuron models."""
import importlib import importlib
import os import os
from typing import Dict, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -10,9 +10,9 @@ from transformers import PretrainedConfig ...@@ -10,9 +10,9 @@ from transformers import PretrainedConfig
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
TORCH_DTYPE_TO_NEURON_AMP = { TORCH_DTYPE_TO_NEURON_AMP = {
"auto": "f32", "auto": "f32",
...@@ -82,8 +82,7 @@ class NeuronCasualLM(nn.Module): ...@@ -82,8 +82,7 @@ class NeuronCasualLM(nn.Module):
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
split_model_dir = f"{model_name_or_path}-split" split_model_dir = f"{model_name_or_path}-split"
if os.path.isdir(os.path.join(model_name_or_path, if _is_pretrained_neuron_checkpoint(model_name_or_path):
"pytorch_model.bin")):
split_model_dir = model_name_or_path split_model_dir = model_name_or_path
elif not os.path.exists(f"{model_name_or_path}-split"): elif not os.path.exists(f"{model_name_or_path}-split"):
hf_model_cls = getattr(transformers, hf_model_cls_name) hf_model_cls = getattr(transformers, hf_model_cls_name)
...@@ -98,6 +97,23 @@ class NeuronCasualLM(nn.Module): ...@@ -98,6 +97,23 @@ class NeuronCasualLM(nn.Module):
self.model.to_neuron() self.model.to_neuron()
def _is_pretrained_neuron_checkpoint(model_name_or_path: str) -> bool:
    """Return True if the path looks like an already-split neuron checkpoint.

    Two layouts are recognised:

    * old format: ``<path>/pytorch_model.bin`` exists and is a directory;
    * new format: ``<path>`` contains both ``config.json`` and
      ``generation_config.json`` as regular files, plus at least one entry
      whose name ends with ``.safetensors``.
    """
    # Old format.
    legacy_dir = os.path.join(model_name_or_path, "pytorch_model.bin")
    if os.path.isdir(legacy_dir):
        return True

    # New format: both config files must be present before the directory
    # is listed, so a non-existent path never reaches os.listdir.
    required_files = ("config.json", "generation_config.json")
    has_required_files = all(
        os.path.isfile(os.path.join(model_name_or_path, required))
        for required in required_files)
    if not has_required_files:
        return False

    return any(
        entry.endswith(".safetensors")
        for entry in os.listdir(model_name_or_path))
def _get_model_architecture(config: PretrainedConfig) -> str: def _get_model_architecture(config: PretrainedConfig) -> str:
architectures = getattr(config, "architectures", []) architectures = getattr(config, "architectures", [])
for arch in architectures: for arch in architectures:
...@@ -109,28 +125,75 @@ def _get_model_architecture(config: PretrainedConfig) -> str: ...@@ -109,28 +125,75 @@ def _get_model_architecture(config: PretrainedConfig) -> str:
f"{list(_NEURON_SUPPORTED_MODELS.keys())}") f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
def _get_buckets(env: str, default_value: List[int]) -> List[int]:
    """Read a comma-separated list of integers from an environment variable.

    Args:
        env: Name of the environment variable to read.
        default_value: Returned as-is (same object) when the variable is
            not set.

    Returns:
        The parsed integers in their original order. Entries that are empty
        or whitespace-only are skipped, so a variable set to an empty
        string yields an empty list.

    Raises:
        ValueError: If a non-blank entry is not a valid integer literal.
    """
    raw_value = os.getenv(env)
    if raw_value is None:
        return default_value
    return [int(entry) for entry in raw_value.split(",") if entry.strip()]
def _get_default_neuron_config(model_config: ModelConfig,
                               parallel_config: ParallelConfig,
                               scheduler_config: SchedulerConfig):
    """Build the default keyword arguments for the neuron configuration.

    Args:
        model_config: Supplies ``dtype`` (used to look up the dequantization
            dtype in ``TORCH_DTYPE_TO_NEURON_AMP``) and ``quantization``.
        parallel_config: Accepted for signature symmetry; not read here.
        scheduler_config: Supplies ``max_num_seqs`` for the continuous
            batching configuration.

    Returns:
        A dict with the keys ``collectives_layout``, ``attention_layout``,
        ``fuse_qkv``, ``quant``, ``continuous_batching`` and
        ``weight_tiling``. ``quant`` is ``None`` and ``weight_tiling`` is
        ``False`` when ``model_config.quantization`` is falsy.
    """
    from transformers_neuronx.config import ContinuousBatchingConfig
    from transformers_neuronx.constants import LAYOUT_BSH

    batching_config = ContinuousBatchingConfig(
        batch_size_for_shared_caches=scheduler_config.max_num_seqs)
    quant_settings = {
        "dequant_dtype": TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        "quantize_method": "vector_dynamic",
    }

    # Resolve the quant method only when quantization was requested.
    quantization = model_config.quantization
    if quantization:
        quant_method = get_quantization_config(quantization).from_config(
            quant_settings).get_quant_method(None, "")
    else:
        quant_method = None

    # TODO: Add Paged attention config to the default neuron arguments.
    return {
        "collectives_layout": LAYOUT_BSH,
        "attention_layout": LAYOUT_BSH,
        "fuse_qkv": True,
        "quant": quant_method,
        "continuous_batching": batching_config,
        "weight_tiling": bool(quantization),
    }
def _get_neuron_config_after_override(default_neuron_config,
                                      overridden_neuron_config):
    """Apply user overrides to the defaults and build a ``NeuronConfig``.

    Args:
        default_neuron_config: Dict of default keyword arguments. It is
            updated in place with the overrides.
        overridden_neuron_config: Mapping of overriding keyword arguments;
            any falsy value (``None``, ``{}``) means "no overrides".

    Returns:
        ``NeuronConfig(**default_neuron_config)`` after the merge.
    """
    from transformers_neuronx.config import NeuronConfig

    if overridden_neuron_config:
        # Overrides win over defaults; the caller's dict is mutated.
        default_neuron_config.update(overridden_neuron_config)
    return NeuronConfig(**default_neuron_config)
def get_neuron_model(model_config: ModelConfig, def get_neuron_model(model_config: ModelConfig,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module: scheduler_config: SchedulerConfig) -> nn.Module:
from transformers_neuronx.config import (ContinuousBatchingConfig,
NeuronConfig)
# Create a model instance. # Create a model instance.
model = NeuronCasualLM(model_config.hf_config) model = NeuronCasualLM(model_config.hf_config)
continuous_batching_config = ContinuousBatchingConfig( default_neuron_config_args = _get_default_neuron_config(
batch_size_for_shared_caches=scheduler_config.max_num_seqs) model_config, parallel_config, scheduler_config)
neuron_config = NeuronConfig(
continuous_batching=continuous_batching_config) neuron_config = _get_neuron_config_after_override(
default_neuron_config_args, model_config.override_neuron_config)
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
[scheduler_config.max_model_len])
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
[scheduler_config.max_model_len])
# Load the weights from the cached or downloaded files. # Load the weights from the cached or downloaded files.
model.load_weights( model.load_weights(model_config.model,
model_config.model, tp_degree=parallel_config.tensor_parallel_size,
tp_degree=parallel_config.tensor_parallel_size, amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], neuron_config=neuron_config,
neuron_config=neuron_config, context_length_estimate=context_length_estimates,
context_length_estimate=[scheduler_config.max_model_len], n_positions=n_positions,
n_positions=[scheduler_config.max_model_len], batch_size=scheduler_config.max_num_seqs)
batch_size=scheduler_config.max_num_seqs)
return model.eval() return model.eval()
...@@ -15,9 +15,8 @@ from vllm.config import DeviceConfig, ModelConfig ...@@ -15,9 +15,8 @@ from vllm.config import DeviceConfig, ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import (LogitsProcessor, from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
_prune_hidden_states) _prune_hidden_states)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -42,11 +42,11 @@ def get_model_architecture( ...@@ -42,11 +42,11 @@ def get_model_architecture(
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
mixtral_supported = ["fp8", "compressed-tensors"]
if (model_config.quantization is not None if (model_config.quantization is not None
and model_config.quantization != "fp8" and model_config.quantization not in mixtral_supported
and "MixtralForCausalLM" in architectures): and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"] architectures = ["QuantMixtralForCausalLM"]
return ModelRegistry.resolve_model_cls(architectures) return ModelRegistry.resolve_model_cls(architectures)
......
...@@ -22,6 +22,7 @@ _GENERATION_MODELS = { ...@@ -22,6 +22,7 @@ _GENERATION_MODELS = {
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
...@@ -49,6 +50,7 @@ _GENERATION_MODELS = { ...@@ -49,6 +50,7 @@ _GENERATION_MODELS = {
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
...@@ -63,6 +65,7 @@ _GENERATION_MODELS = { ...@@ -63,6 +65,7 @@ _GENERATION_MODELS = {
"EAGLEModel": ("eagle", "EAGLE"), "EAGLEModel": ("eagle", "EAGLE"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"GraniteForCausalLM": ("granite", "GraniteForCausalLM")
} }
_EMBEDDING_MODELS = { _EMBEDDING_MODELS = {
......
...@@ -23,13 +23,13 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -23,13 +23,13 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.deepspeedfp import ( from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig, DeepSpeedFPParameter) DeepSpeedFPConfig, DeepSpeedFPParameter)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.arctic import ArcticConfig from vllm.transformers_utils.configs.arctic import ArcticConfig
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -40,12 +40,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -40,12 +40,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
......
...@@ -34,12 +34,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -34,12 +34,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -10,15 +10,23 @@ from transformers import Blip2VisionConfig, BlipVisionConfig ...@@ -10,15 +10,23 @@ from transformers import Blip2VisionConfig, BlipVisionConfig
from transformers.models.blip.modeling_blip import BlipAttention from transformers.models.blip.modeling_blip import BlipAttention
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0 assert image_size % patch_size == 0
...@@ -154,6 +162,77 @@ class BlipVisionEmbeddings(nn.Module): ...@@ -154,6 +162,77 @@ class BlipVisionEmbeddings(nn.Module):
return embeddings return embeddings
class BlipParallelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Variant of the BLIP attention layer built from ``QKVParallelLinear`` and
    ``RowParallelLinear`` with the attention itself computed by xformers'
    ``memory_efficient_attention_forward``. The module-level ``xops`` name
    must therefore be importable before ``forward`` is called.
    """

    def __init__(
        self,
        config: BlipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the fused QKV and output projections.

        Args:
            config: Supplies ``hidden_size``, ``num_attention_heads``,
                ``attention_dropout`` and ``qkv_bias``.
            quant_config: Forwarded to both linear layers.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
        )
        self.projection = RowParallelLinear(
            self.embed_dim,
            self.embed_dim,
            quant_config=quant_config,
        )

        # Number of attention heads handled by this tensor-parallel rank.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape to ``(bsz, num_heads, seq_len, head_dim)``.

        NOTE(review): uses the full ``num_heads`` rather than
        ``num_heads_per_partition`` and is not called by ``forward`` in this
        class; kept unchanged for interface compatibility.
        """
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel

        Returns:
            ``(attn_output, None)``; the second element is always ``None``.
        """
        bsz, tgt_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        # Split the last dimension into (heads on this rank, head_dim).
        per_head_shape = (bsz, tgt_len, self.num_heads_per_partition,
                          self.head_dim)
        query_states = query_states.view(*per_head_shape)
        key_states = key_states.view(*per_head_shape)
        value_states = value_states.view(*per_head_shape)

        # Dropout must only be active while training. Passing the configured
        # probability unconditionally would randomly drop attention weights
        # at inference whenever `attention_dropout` is non-zero.
        dropout_p = self.dropout if self.training else 0.0

        out = xops.memory_efficient_attention_forward(query_states,
                                                      key_states,
                                                      value_states,
                                                      p=dropout_p,
                                                      scale=self.scale)
        out = out.view(bsz, tgt_len, -1)
        attn_output, _ = self.projection(out)

        return attn_output, None
class BlipMLP(nn.Module): class BlipMLP(nn.Module):
def __init__(self, def __init__(self,
...@@ -188,7 +267,16 @@ class BlipEncoderLayer(nn.Module): ...@@ -188,7 +267,16 @@ class BlipEncoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.self_attn = BlipAttention(config) # fallback to sdpa attention if tp unavailable
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = BlipParallelAttention(config,
quant_config=quant_config)
else:
# Blip doesn't have SDPA attention implemented in transformers
# use eager attention instead for cpu backend
self.self_attn = BlipAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size, self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = BlipMLP(config, quant_config=quant_config) self.mlp = BlipMLP(config, quant_config=quant_config)
......
...@@ -13,13 +13,13 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs ...@@ -13,13 +13,13 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.opt import OPTModel from vllm.model_executor.models.opt import OPTModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData) SequenceData)
from .blip import (BlipVisionModel, dummy_image_for_blip, from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens) get_max_blip_image_tokens)
...@@ -40,13 +40,13 @@ BLIP2_IMAGE_TOKEN_ID = 50265 ...@@ -40,13 +40,13 @@ BLIP2_IMAGE_TOKEN_ID = 50265
class Blip2ImagePixelInputs(TypedDict): class Blip2ImagePixelInputs(TypedDict):
type: Literal["pixel_values"] type: Literal["pixel_values"]
data: torch.Tensor data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)""" """Shape: `(batch_size * num_images, num_channels, height, width)`"""
class Blip2ImageEmbeddingInputs(TypedDict): class Blip2ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
data: torch.Tensor data: torch.Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)` """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone. `hidden_size` must match the hidden size of language model backbone.
""" """
...@@ -555,6 +555,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -555,6 +555,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
raise ValueError("Incorrect type of pixel values. " raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)
return Blip2ImagePixelInputs( return Blip2ImagePixelInputs(
type="pixel_values", type="pixel_values",
data=self._validate_pixel_values(pixel_values), data=self._validate_pixel_values(pixel_values),
...@@ -564,6 +567,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -564,6 +567,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
if not isinstance(image_embeds, torch.Tensor): if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. " raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}") f"Got type: {type(image_embeds)}")
# Remove the N dimension until multiple images are supported.
image_embeds = image_embeds.squeeze(1)
return Blip2ImageEmbeddingInputs( return Blip2ImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
data=image_embeds, data=image_embeds,
...@@ -707,8 +714,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -707,8 +714,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
use_default_weight_loading = False use_default_weight_loading = False
if "vision" in name: if "vision" in name:
if self.vision_model is not None: if self.vision_model is not None:
# We only do sharding for language model and # BlipVisionModel does not need sharding
# not vision model for now.
use_default_weight_loading = True use_default_weight_loading = True
else: else:
for (param_name, weight_name, for (param_name, weight_name,
......
...@@ -36,12 +36,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -36,12 +36,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
......
...@@ -22,7 +22,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -22,7 +22,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -33,7 +33,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -33,7 +33,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData) SequenceData)
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
...@@ -53,7 +53,7 @@ CHAMELEON_SEP_TOKEN_ID = 8710 ...@@ -53,7 +53,7 @@ CHAMELEON_SEP_TOKEN_ID = 8710
class ChameleonImagePixelInputs(TypedDict): class ChameleonImagePixelInputs(TypedDict):
type: Literal["pixel_values"] type: Literal["pixel_values"]
data: torch.Tensor data: torch.Tensor
"""Shape: `(batch_size, num_channels, height, width)`""" """Shape: `(batch_size * num_images, num_channels, height, width)`"""
def get_max_chameleon_image_tokens(ctx: InputContext): def get_max_chameleon_image_tokens(ctx: InputContext):
...@@ -946,6 +946,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -946,6 +946,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
raise ValueError("Incorrect type of pixel values. " raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)
return ChameleonImagePixelInputs( return ChameleonImagePixelInputs(
type="pixel_values", type="pixel_values",
data=self._validate_pixel_values(pixel_values), data=self._validate_pixel_values(pixel_values),
......
...@@ -22,12 +22,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -22,12 +22,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
......
"""Minimal implementation of CLIPVisionModel intended to be only used """Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model.""" within a vision language model."""
from array import array from array import array
from typing import Iterable, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image from PIL import Image
from transformers import CLIPVisionConfig from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPAttention from transformers.models.clip.modeling_clip import CLIPSdpaAttention
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -20,6 +22,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer, ...@@ -20,6 +22,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0 assert image_size % patch_size == 0
...@@ -84,7 +92,7 @@ def input_processor_for_clip( ...@@ -84,7 +92,7 @@ def input_processor_for_clip(
llm_inputs: LLMInputs, llm_inputs: LLMInputs,
*, *,
image_token_id: int, image_token_id: int,
image_feature_size_override: Optional[int] = None, image_feature_size_override: Optional[Union[int, List[int]]] = None,
): ):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
...@@ -160,6 +168,78 @@ class CLIPVisionEmbeddings(nn.Module): ...@@ -160,6 +168,78 @@ class CLIPVisionEmbeddings(nn.Module):
return embeddings return embeddings
class CLIPParallelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Variant of CLIP self-attention built from `QKVParallelLinear` and
    `RowParallelLinear` projections, with the attention itself computed by
    the xformers memory-efficient kernel. Each rank reshapes its projected
    Q/K/V to `num_heads_per_partition` heads of size `head_dim`.

    NOTE(review): this class calls `xops` unconditionally, so it presumably
    must only be instantiated when the xformers import succeeded
    (`USE_XFORMERS_OPS`) -- confirm at the call site.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the fused QKV and output projections.

        Args:
            config: Vision config; `hidden_size`, `num_attention_heads` and
                `attention_dropout` are read from it.
            quant_config: Optional quantization config forwarded to both
                linear layers.

        Raises:
            ValueError: If `hidden_size` is not divisible by
                `num_attention_heads`.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        # Standard 1/sqrt(head_dim) softmax scaling.
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        # Number of attention heads handled by this rank.
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `tensor` to (bsz, num_heads, seq_len, head_dim).

        NOTE(review): this uses the full `num_heads`, not
        `num_heads_per_partition`, and is not called by `forward` in this
        class -- confirm whether it is still needed / correct when
        tp_size > 1.
        """
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel

        Returns:
            A tuple `(attn_output, None)`. The second element is always
            `None` (presumably a placeholder for attention weights so the
            return shape matches the HF attention modules -- TODO confirm).
        """
        bsz, tgt_len, _ = hidden_states.size()

        # The parallel linear layers return a tuple; only the first element
        # (the projected tensor) is used here.
        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        # Split the channel dimension into this rank's heads.
        head_layout = (bsz, tgt_len, self.num_heads_per_partition,
                       self.head_dim)
        query_states = query_states.view(*head_layout)
        key_states = key_states.view(*head_layout)
        value_states = value_states.view(*head_layout)

        # Only apply attention dropout while training, matching the
        # `self.dropout if self.training else 0.0` gating used by the HF
        # CLIP attention modules; dropout must never fire at inference.
        dropout_p = self.dropout if self.training else 0.0

        out = xops.memory_efficient_attention_forward(query_states,
                                                      key_states,
                                                      value_states,
                                                      p=dropout_p,
                                                      scale=self.scale)
        # Merge the heads back into a single channel dimension.
        out = out.view(bsz, tgt_len, -1)
        attn_output, _ = self.out_proj(out)

        return attn_output, None
class CLIPMLP(nn.Module): class CLIPMLP(nn.Module):
def __init__(self, def __init__(self,
...@@ -192,7 +272,13 @@ class CLIPEncoderLayer(nn.Module): ...@@ -192,7 +272,13 @@ class CLIPEncoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.self_attn = CLIPAttention(config) num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = CLIPParallelAttention(config,
quant_config=quant_config)
else:
self.self_attn = CLIPSdpaAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size, self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, quant_config=quant_config) self.mlp = CLIPMLP(config, quant_config=quant_config)
...@@ -217,7 +303,7 @@ class CLIPEncoderLayer(nn.Module): ...@@ -217,7 +303,7 @@ class CLIPEncoderLayer(nn.Module):
class CLIPEncoder(nn.Module): class CLIPEncoder(nn.Module):
""" """
Transformer encoder consisting of `config.num_hidden_layers` self Transformer encoder consisting of `config.num_hidden_layers` self
attention layers. Each layer is a [`CLIPEncoderLayer`]. attention layers. Each layer is a [`CLIPEncoderLayer`].
Args: Args:
...@@ -291,6 +377,10 @@ class CLIPVisionModel(nn.Module): ...@@ -291,6 +377,10 @@ class CLIPVisionModel(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None): num_hidden_layers_override: Optional[int] = None):
super().__init__() super().__init__()
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
self.vision_model = CLIPVisionTransformer( self.vision_model = CLIPVisionTransformer(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
...@@ -304,7 +394,15 @@ class CLIPVisionModel(nn.Module): ...@@ -304,7 +394,15 @@ class CLIPVisionModel(nn.Module):
def device(self): def device(self):
return next(self.parameters()).device return next(self.parameters()).device
# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers) layer_count = len(self.vision_model.encoder.layers)
...@@ -318,7 +416,16 @@ class CLIPVisionModel(nn.Module): ...@@ -318,7 +416,16 @@ class CLIPVisionModel(nn.Module):
if layer_idx >= layer_count: if layer_idx >= layer_count:
continue continue
param = params_dict[name] for (param_name, weight_name, shard_id) in stacked_params_mapping:
weight_loader = getattr(param, "weight_loader", if weight_name not in name:
default_weight_loader) continue
weight_loader(param, loaded_weight)
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
...@@ -38,14 +38,14 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -38,14 +38,14 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, row_parallel_weight_loader) default_weight_loader, row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
@torch.compile @torch.compile
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment