Commit d2b52805 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-ori

parents 9a521c23 5438967f
...@@ -54,6 +54,16 @@ class MambaStateDtypeCalculator: ...@@ -54,6 +54,16 @@ class MambaStateDtypeCalculator:
return (conv_state_dtype, temporal_state_dtype) return (conv_state_dtype, temporal_state_dtype)
@classmethod
def short_conv_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
model_dtype)
return (conv_state_dtype, )
class MambaStateShapeCalculator: class MambaStateShapeCalculator:
...@@ -122,6 +132,20 @@ class MambaStateShapeCalculator: ...@@ -122,6 +132,20 @@ class MambaStateShapeCalculator:
tp_world_size), head_dim, state_size) tp_world_size), head_dim, state_size)
return conv_state_shape, temporal_state_shape return conv_state_shape, temporal_state_shape
@classmethod
def short_conv_state_shape(
cls,
tp_world_size: int,
intermediate_size: int,
conv_kernel: int,
use_v1: bool = True,
) -> tuple[tuple[int, int]]:
conv_dim = divide(intermediate_size, tp_world_size)
conv_state_shape = (conv_kernel - 1, conv_dim)
if not use_v1:
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
return (conv_state_shape, )
@classmethod @classmethod
def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int): def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
"""Compute the increase in group numbers to account for """Compute the increase in group numbers to account for
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import torch
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionMetadata)
@CustomOp.register("short_conv")
class ShortConv(MambaBase, CustomOp):
def __init__(self,
config,
dim: int,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = ""):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.conv_dim = dim
self.L_cache = config.conv_L_cache
self.bias = config.conv_bias
self.conv = ColumnParallelLinear(
input_size=self.L_cache,
output_size=dim,
bias=self.bias,
prefix=f"{prefix}.conv1d",
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `set_weight_attrs`
# doesn't allow to override it
self.conv.weight.data = self.conv.weight.data.unsqueeze(1)
self.in_proj = MergedColumnParallelLinear(
input_size=dim,
output_sizes=[dim] * 3,
bias=self.bias,
prefix=f"{prefix}.in_proj",
)
self.out_proj = RowParallelLinear(
input_size=dim,
output_size=dim,
bias=self.bias,
prefix=f"{prefix}.out_proj",
)
assert envs.VLLM_USE_V1, ("ShortConv layers are only supported in V1")
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
self.kv_cache = [(torch.tensor([]), )]
self.model_config = model_config
self.cache_config = cache_config
self.prefix = prefix
def forward_native(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
conv_metadata: ShortConvAttentionMetadata,
):
return
def forward(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
conv_metadata: ShortConvAttentionMetadata,
):
torch.ops.vllm.short_conv(
hidden_states,
output,
self.prefix,
)
def forward_cuda(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
conv_metadata: ShortConvAttentionMetadata,
):
forward_context = get_forward_context()
# ShortConvAttentionMetadata contains metadata necessary for the
# short_conv triton kernels to operate in continuous batching and in
# chunked prefill modes; they are computed at top-level model forward
# since they stay the same and reused for all mamba layers in the same
# iteration.
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
conv_metadata = attn_metadata
assert isinstance(attn_metadata, ShortConvAttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states
BCx, _ = self.in_proj(hidden_states)
B, C, x = BCx.chunk(3, dim=-1)
conv_weights = self.conv.weight.view(self.conv.weight.size(0),
self.conv.weight.size(2))
if attn_metadata is None:
# V1 profile run
Bx = (B * x).contiguous()
hidden_states = C * Bx
contextualized_states, _ = self.out_proj(hidden_states)
return contextualized_states
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
num_actual_tokens = num_decodes + num_prefill_tokens
# NOTE: V1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
B_d, B_p = torch.split(
B[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
C_d, C_p = torch.split(
C[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
x_d, x_p = torch.split(
x[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
conv_output_list = []
if has_prefill:
Bx_p = (B_p * x_p).transpose(0, 1)
if conv_metadata.cu_seqlen is None:
conv_metadata = update_metadata(Bx_p, query_start_loc_p,
conv_metadata)
Bx = causal_conv1d_fn(Bx_p,
conv_weights,
self.conv.bias,
activation=None,
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
metadata=conv_metadata,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
y = C_p * Bx
conv_output_list.append(y)
if has_decode:
Bx_d = (B_d * x_d).contiguous()
Bx = causal_conv1d_update(
Bx_d,
conv_state,
conv_weights,
self.conv.bias,
activation=None,
conv_state_indices=state_indices_tensor_d)
y = C_d * Bx
conv_output_list.insert(0, y)
# Merge prefill and decode outputs before passing to gated MLP
hidden_states = torch.vstack(conv_output_list)
# Final linear projection
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
def get_state_dtype(self) -> tuple[torch.dtype, ...]:
assert self.model_config is not None
assert self.cache_config is not None
return MambaStateDtypeCalculator.short_conv_state_dtype(
self.model_config.dtype,
self.cache_config.mamba_cache_dtype,
)
def get_state_shape(self) -> tuple[tuple[int, ...]]:
return MambaStateShapeCalculator.short_conv_state_shape(
tp_world_size=get_tensor_model_parallel_world_size(),
intermediate_size=self.conv_dim,
conv_kernel=self.L_cache,
)
@property
def mamba_type(self) -> str:
return "short_conv"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionBackend)
return ShortConvAttentionBackend
def short_conv(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states,
output=output,
conv_metadata=None)
def short_conv_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="short_conv",
op_func=short_conv,
mutates_args=["output"],
fake_impl=short_conv_fake,
dispatch_key=current_platform.dispatch_key,
)
...@@ -5,7 +5,7 @@ from collections.abc import Mapping, Set ...@@ -5,7 +5,7 @@ from collections.abc import Mapping, Set
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum from enum import IntEnum
from itertools import groupby from itertools import groupby
from typing import Callable, Optional, TypeVar, Union from typing import Callable, Optional, TypeVar, Union, cast
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -13,16 +13,15 @@ import torch.nn.functional as F ...@@ -13,16 +13,15 @@ import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import ModelConfig, PoolerConfig from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.pooling_metadata import ( # noqa: E501 from vllm.logger import init_logger
PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.utils import resolve_obj_by_qualname from vllm.utils import current_stream, resolve_obj_by_qualname
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
logger = init_logger(__name__)
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
PoolingFn = Callable[ PoolingFn = Callable[
[Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata], [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
Union[torch.Tensor, list[torch.Tensor]]] Union[torch.Tensor, list[torch.Tensor]]]
...@@ -126,16 +125,11 @@ def get_prompt_lens( ...@@ -126,16 +125,11 @@ def get_prompt_lens(
hidden_states: Union[torch.Tensor, list[torch.Tensor]], hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
if isinstance(pooling_metadata, V1PoolingMetadata):
return pooling_metadata.prompt_lens return pooling_metadata.prompt_lens
return PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states[0].device).prompt_lens
def get_prompt_token_ids( def get_prompt_token_ids(
pooling_metadata: PoolingMetadata) -> list[torch.Tensor]: pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
if isinstance(pooling_metadata, V1PoolingMetadata):
assert pooling_metadata.prompt_token_ids is not None, ( assert pooling_metadata.prompt_token_ids is not None, (
"Please set `requires_token_ids=True` in `get_pooling_updates`") "Please set `requires_token_ids=True` in `get_pooling_updates`")
...@@ -144,17 +138,9 @@ def get_prompt_token_ids( ...@@ -144,17 +138,9 @@ def get_prompt_token_ids(
for i, num in enumerate(pooling_metadata.prompt_lens) for i, num in enumerate(pooling_metadata.prompt_lens)
] ]
return [
torch.tensor(seq_data_i.prompt_token_ids)
for seq_data_i in pooling_metadata.seq_data.values()
]
def get_pooling_params( def get_pooling_params(
pooling_metadata: PoolingMetadata) -> list[PoolingParams]: pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
if isinstance(pooling_metadata, V0PoolingMetadata):
pooling_params = [p for _, p in pooling_metadata.seq_groups]
else:
pooling_params = pooling_metadata.pooling_params pooling_params = pooling_metadata.pooling_params
return pooling_params return pooling_params
...@@ -172,6 +158,15 @@ def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: ...@@ -172,6 +158,15 @@ def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
def get_classification_activation_function(config: PretrainedConfig): def get_classification_activation_function(config: PretrainedConfig):
# Implement alignment with transformers ForSequenceClassificationLoss
# https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
problem_type = getattr(config, "problem_type", "")
if problem_type == "regression":
return PoolerIdentity()
if problem_type == "single_label_classification":
return PoolerClassify()
if problem_type == "multi_label_classification":
return PoolerMultiLabelClassify()
return PoolerClassify() return PoolerClassify()
...@@ -191,11 +186,18 @@ def get_cross_encoder_activation_function(config: PretrainedConfig): ...@@ -191,11 +186,18 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
fn = resolve_obj_by_qualname(function_name)() fn = resolve_obj_by_qualname(function_name)()
return PoolerActivation.wraps(fn) return PoolerActivation.wraps(fn)
return PoolerScore() return PoolerClassify()
def build_output( def build_output(
all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput: all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
# Pooling models D2H & synchronize occurs here
if isinstance(all_data, list):
all_data = [d.to("cpu", non_blocking=True) for d in all_data]
else:
all_data = all_data.to("cpu", non_blocking=True)
current_stream().synchronize()
all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data] all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
return PoolerOutput(outputs=all_outputs) return PoolerOutput(outputs=all_outputs)
...@@ -222,40 +224,21 @@ class PoolingMethod(nn.Module, ABC): ...@@ -222,40 +224,21 @@ class PoolingMethod(nn.Module, ABC):
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return PoolingParamsUpdate() return PoolingParamsUpdate()
@abstractmethod
def forward_one(
self,
hidden_states: torch.Tensor,
prompt_len: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Note:
`prompt_len=None` means `prompt_len=len(hidden_states)`.
"""
raise NotImplementedError
@abstractmethod @abstractmethod
def forward_all( def forward_all(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
prompt_lens: torch.Tensor, pooling_cursor: PoolingCursor,
) -> Union[list[torch.Tensor], torch.Tensor]: ) -> Union[list[torch.Tensor], torch.Tensor]:
raise NotImplementedError raise NotImplementedError
def forward( def forward(
self, self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]], hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]: ) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) pooling_cursor = pooling_metadata.pooling_cursor
return self.forward_all(hidden_states, pooling_cursor)
if isinstance(hidden_states, list):
return [
self.forward_one(h, prompt_len)
for h, prompt_len in zip(hidden_states, prompt_lens)
]
return self.forward_all(hidden_states, prompt_lens)
class CLSPool(PoolingMethod): class CLSPool(PoolingMethod):
...@@ -263,24 +246,15 @@ class CLSPool(PoolingMethod): ...@@ -263,24 +246,15 @@ class CLSPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"} return {"encode", "embed", "classify", "score"}
def forward_one(
self,
hidden_states: torch.Tensor,
prompt_len: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert prompt_len is None or prompt_len == hidden_states.shape[0], \
"partial prefill not supported with CLS pooling"
return hidden_states[0]
def forward_all( def forward_all(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
prompt_lens: torch.Tensor, pooling_cursor: PoolingCursor,
) -> Union[list[torch.Tensor], torch.Tensor]: ) -> Union[list[torch.Tensor], torch.Tensor]:
first_token_flat_indices = torch.zeros_like(prompt_lens) assert not pooling_cursor.is_partial_prefill(), \
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1] "partial prefill not supported with CLS pooling"
return hidden_states[first_token_flat_indices]
return hidden_states[pooling_cursor.first_token_indices_gpu]
class LastPool(PoolingMethod): class LastPool(PoolingMethod):
...@@ -288,20 +262,12 @@ class LastPool(PoolingMethod): ...@@ -288,20 +262,12 @@ class LastPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"} return {"encode", "embed", "classify", "score"}
def forward_one(
self,
hidden_states: torch.Tensor,
prompt_len: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return hidden_states[-1]
def forward_all( def forward_all(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
prompt_lens: torch.Tensor, pooling_cursor: PoolingCursor,
) -> Union[list[torch.Tensor], torch.Tensor]: ) -> Union[list[torch.Tensor], torch.Tensor]:
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 return hidden_states[pooling_cursor.last_token_indices_gpu]
return hidden_states[last_token_flat_indices]
class AllPool(PoolingMethod): class AllPool(PoolingMethod):
...@@ -309,22 +275,19 @@ class AllPool(PoolingMethod): ...@@ -309,22 +275,19 @@ class AllPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode"} return {"encode"}
def forward_one(
self,
hidden_states: torch.Tensor,
prompt_len: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert prompt_len is None or prompt_len == hidden_states.shape[0], \
"partial prefill not supported with ALL pooling"
return hidden_states
def forward_all( def forward_all(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
prompt_lens: torch.Tensor, pooling_cursor: PoolingCursor,
) -> Union[list[torch.Tensor], torch.Tensor]: ) -> Union[list[torch.Tensor], torch.Tensor]:
return list(hidden_states.split_with_sizes(prompt_lens.tolist()))
assert not pooling_cursor.is_partial_prefill(), \
"partial prefill not supported with ALL pooling"
hidden_states_lst = list(
hidden_states.split(
pooling_cursor.num_scheduled_tokens_cpu.tolist()))
return [hidden_states_lst[i] for i in pooling_cursor.index]
class MeanPool(PoolingMethod): class MeanPool(PoolingMethod):
...@@ -332,31 +295,25 @@ class MeanPool(PoolingMethod): ...@@ -332,31 +295,25 @@ class MeanPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"} return {"encode", "embed", "classify", "score"}
def forward_one( def forward_all(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
prompt_len: Optional[torch.Tensor] = None, pooling_cursor: PoolingCursor,
) -> torch.Tensor: ) -> Union[list[torch.Tensor], torch.Tensor]:
assert prompt_len is None or prompt_len == hidden_states.shape[0], \
assert not pooling_cursor.is_partial_prefill(), \
"partial prefill not supported with MEAN pooling" "partial prefill not supported with MEAN pooling"
return hidden_states.mean(dim=0, dtype=torch.float32) prompt_lens = pooling_cursor.prompt_lens_cpu.to(hidden_states.device,
non_blocking=True)
def forward_all(
self,
hidden_states: torch.Tensor,
prompt_lens: torch.Tensor,
) -> Union[list[torch.Tensor], torch.Tensor]:
# Use float32 for torch.cumsum in MeanPool, # Use float32 for torch.cumsum in MeanPool,
# otherwise precision will be lost significantly. # otherwise precision will be lost significantly.
cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32) cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
start_indices = torch.cat([ start_indices = pooling_cursor.first_token_indices_gpu
torch.tensor([0], device=hidden_states.device), end_indices = pooling_cursor.last_token_indices_gpu
torch.cumsum(prompt_lens[:-1], dim=0) return (cumsum[end_indices] - cumsum[start_indices] +
])
end_indices = torch.cumsum(prompt_lens, dim=0)
return (cumsum[end_indices - 1] - cumsum[start_indices] +
hidden_states[start_indices]) / prompt_lens.unsqueeze(1) hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
...@@ -409,24 +366,37 @@ class PoolerNormalize(PoolerActivation): ...@@ -409,24 +366,37 @@ class PoolerNormalize(PoolerActivation):
return x.to(pooled_data.dtype) return x.to(pooled_data.dtype)
class PoolerClassify(PoolerActivation): class PoolerMultiLabelClassify(PoolerActivation):
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
num_labels = pooled_data.shape[-1]
if num_labels < 2:
return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype)
class PoolerClassify(PoolerActivation):
class PoolerScore(PoolerActivation): def __init__(self, *, static_num_labels: bool = True) -> None:
super().__init__()
if static_num_labels:
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
self.num_labels = getattr(vllm_config.model_config.hf_config,
"num_labels", 0)
if self.num_labels == 0:
logger.warning("num_labels should be > 0 for classification"
"models, falling back to softmax. "
"Please check if the configuration is correct.")
else:
self.num_labels = None
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
num_labels = pooled_data.shape[-1] num_labels = (self.num_labels if self.num_labels is not None else
pooled_data.shape[-1])
if num_labels < 2: if num_labels < 2:
return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
return pooled_data return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype)
class LambdaPoolerActivation(PoolerActivation): class LambdaPoolerActivation(PoolerActivation):
...@@ -457,9 +427,33 @@ class EmbeddingPoolerHead(PoolerHead): ...@@ -457,9 +427,33 @@ class EmbeddingPoolerHead(PoolerHead):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(activation=PoolerNormalize()) super().__init__(activation=PoolerNormalize())
# Load ST projector if available
from vllm.config import get_current_vllm_config
from vllm.model_executor.models.adapters import _load_st_projector
vllm_config = get_current_vllm_config()
self.projector = _load_st_projector(
vllm_config.model_config) if vllm_config else None
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
pooling_metadata: PoolingMetadata): pooling_metadata: PoolingMetadata):
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_dimension]
# Apply ST projector
if self.projector is not None:
projector = cast(nn.Module, self.projector)
def _proj(x: torch.Tensor) -> torch.Tensor:
orig_dtype = x.dtype
y = projector(x.to(torch.float32))
return y.to(orig_dtype)
pooled_data = _proj(pooled_data)
# pooled_data shape: [batchsize, embedding_dimension]
pooling_params = get_pooling_params(pooling_metadata) pooling_params = get_pooling_params(pooling_metadata)
# for matryoshka representation # for matryoshka representation
...@@ -491,13 +485,14 @@ class EmbeddingPoolerHead(PoolerHead): ...@@ -491,13 +485,14 @@ class EmbeddingPoolerHead(PoolerHead):
for vecs, f in zip(pooled_data, flags) for vecs, f in zip(pooled_data, flags)
] ]
# pooled_data shape: [batchsize, embedding_dimension]
return pooled_data return pooled_data
class RewardPoolerHead(PoolerHead): class RewardPoolerHead(PoolerHead):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(activation=PoolerClassify()) super().__init__(activation=PoolerClassify(static_num_labels=False))
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
pooling_metadata: PoolingMetadata): pooling_metadata: PoolingMetadata):
...@@ -651,15 +646,13 @@ class ClassifierPooler(Pooler): ...@@ -651,15 +646,13 @@ class ClassifierPooler(Pooler):
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> PoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata) pooled_data = self.pooling(hidden_states, pooling_metadata)
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_size]
if self.classifier is not None: if self.classifier is not None:
# apply classifier once on the full batch if possible
if isinstance(pooled_data, torch.Tensor):
pooled_data = self.classifier(pooled_data) pooled_data = self.classifier(pooled_data)
elif len({data.shape for data in pooled_data}) <= 1: # pooled_data shape: [batchsize, num_labels]
pooled_data = self.classifier(torch.stack(pooled_data))
else:
pooled_data = [self.classifier(data) for data in pooled_data]
pooling_params = get_pooling_params(pooling_metadata) pooling_params = get_pooling_params(pooling_metadata)
flags = [p.activation for p in pooling_params] flags = [p.activation for p in pooling_params]
...@@ -672,6 +665,7 @@ class ClassifierPooler(Pooler): ...@@ -672,6 +665,7 @@ class ClassifierPooler(Pooler):
for vecs, f in zip(pooled_data, flags) for vecs, f in zip(pooled_data, flags)
] ]
# scores shape: [batchsize, num_labels]
return build_output(scores) return build_output(scores)
...@@ -702,12 +696,6 @@ class DispatchPooler(Pooler): ...@@ -702,12 +696,6 @@ class DispatchPooler(Pooler):
) -> PoolerOutput: ) -> PoolerOutput:
poolers_by_task = self.poolers_by_task poolers_by_task = self.poolers_by_task
if isinstance(hidden_states, list):
hidden_states_lst = hidden_states
else:
prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
hidden_states_lst = list(hidden_states.split(prompt_lens.tolist()))
outputs = list[PoolingSequenceGroupOutput]() outputs = list[PoolingSequenceGroupOutput]()
offset = 0 offset = 0
for task, group in groupby(get_tasks(pooling_metadata)): for task, group in groupby(get_tasks(pooling_metadata)):
...@@ -718,7 +706,7 @@ class DispatchPooler(Pooler): ...@@ -718,7 +706,7 @@ class DispatchPooler(Pooler):
num_items = len(list(group)) num_items = len(list(group))
group_output: PoolerOutput = pooler( group_output: PoolerOutput = pooler(
hidden_states_lst[offset:offset + num_items], hidden_states,
pooling_metadata[offset:offset + num_items], pooling_metadata[offset:offset + num_items],
) )
......
...@@ -15,7 +15,6 @@ QuantizationMethods = Literal[ ...@@ -15,7 +15,6 @@ QuantizationMethods = Literal[
"fbgemm_fp8", "fbgemm_fp8",
"modelopt", "modelopt",
"modelopt_fp4", "modelopt_fp4",
"marlin",
"bitblas", "bitblas",
"gguf", "gguf",
"gptq_marlin_24", "gptq_marlin_24",
...@@ -25,7 +24,6 @@ QuantizationMethods = Literal[ ...@@ -25,7 +24,6 @@ QuantizationMethods = Literal[
"gptq", "gptq",
"compressed-tensors", "compressed-tensors",
"bitsandbytes", "bitsandbytes",
"qqq",
"hqq", "hqq",
"experts_int8", "experts_int8",
"neuron_quant", "neuron_quant",
...@@ -37,6 +35,7 @@ QuantizationMethods = Literal[ ...@@ -37,6 +35,7 @@ QuantizationMethods = Literal[
"rtn", "rtn",
"inc", "inc",
"mxfp4", "mxfp4",
"petit_nvfp4",
] ]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
...@@ -106,13 +105,12 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -106,13 +105,12 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .hqq_marlin import HQQMarlinConfig from .hqq_marlin import HQQMarlinConfig
from .inc import INCConfig from .inc import INCConfig
from .ipex_quant import IPEXConfig from .ipex_quant import IPEXConfig
from .marlin import MarlinConfig
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
from .moe_wna16 import MoeWNA16Config from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config from .mxfp4 import Mxfp4Config
from .neuron_quant import NeuronQuantConfig from .neuron_quant import NeuronQuantConfig
from .petit import PetitNvFp4Config
from .ptpc_fp8 import PTPCFp8Config from .ptpc_fp8 import PTPCFp8Config
from .qqq import QQQConfig
from .rtn import RTNConfig from .rtn import RTNConfig
from .torchao import TorchAOConfig from .torchao import TorchAOConfig
from .tpu_int8 import Int8TpuConfig from .tpu_int8 import Int8TpuConfig
...@@ -125,7 +123,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -125,7 +123,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"fbgemm_fp8": FBGEMMFp8Config, "fbgemm_fp8": FBGEMMFp8Config,
"modelopt": ModelOptFp8Config, "modelopt": ModelOptFp8Config,
"modelopt_fp4": ModelOptNvFp4Config, "modelopt_fp4": ModelOptNvFp4Config,
"marlin": MarlinConfig,
"bitblas": BitBLASConfig, "bitblas": BitBLASConfig,
"gguf": GGUFConfig, "gguf": GGUFConfig,
"gptq_marlin_24": GPTQMarlin24Config, "gptq_marlin_24": GPTQMarlin24Config,
...@@ -136,7 +133,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -136,7 +133,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"compressed-tensors": CompressedTensorsConfig, "compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig, "bitsandbytes": BitsAndBytesConfig,
"ptpc_fp8": PTPCFp8Config, "ptpc_fp8": PTPCFp8Config,
"qqq": QQQConfig,
"hqq": HQQMarlinConfig, "hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config, "experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig, "neuron_quant": NeuronQuantConfig,
...@@ -148,6 +144,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -148,6 +144,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"rtn": RTNConfig, "rtn": RTNConfig,
"inc": INCConfig, "inc": INCConfig,
"mxfp4": Mxfp4Config, "mxfp4": Mxfp4Config,
"petit_nvfp4": PetitNvFp4Config,
} }
# Update the `method_to_config` with customized quantization methods. # Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
......
...@@ -497,6 +497,7 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -497,6 +497,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
...@@ -523,6 +524,7 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -523,6 +524,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype) indices_type=self.topk_indices_dtype)
......
...@@ -466,6 +466,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): ...@@ -466,6 +466,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
...@@ -490,6 +491,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): ...@@ -490,6 +491,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype) indices_type=self.topk_indices_dtype)
if self.quant_config.load_in_8bit: if self.quant_config.load_in_8bit:
......
...@@ -11,6 +11,7 @@ from compressed_tensors.config import (CompressionFormat, ...@@ -11,6 +11,7 @@ from compressed_tensors.config import (CompressionFormat,
from compressed_tensors.quantization import (QuantizationArgs, from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy, QuantizationStrategy,
QuantizationType) QuantizationType)
from compressed_tensors.transform import TransformConfig
from pydantic import BaseModel from pydantic import BaseModel
import vllm.envs as envs import vllm.envs as envs
...@@ -26,10 +27,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso ...@@ -26,10 +27,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24, W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
CompressedTensorsScheme, CompressedTensorsW4A4Fp4, CompressedTensorsScheme, CompressedTensorsW4A4Fp4,
CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4, CompressedTensorsW4A8Fp8, CompressedTensorsW4A8Int,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsWNA16) CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
CompressedTensorsLinearTransformMethod, get_linear_transform_schemes)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format, find_matched_target, is_activation_quantization_format,
should_ignore_layer) should_ignore_layer)
...@@ -60,6 +63,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -60,6 +63,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_ignore_list: list[str], sparsity_ignore_list: list[str],
kv_cache_scheme: Optional[dict[str, Any]] = None, kv_cache_scheme: Optional[dict[str, Any]] = None,
config: Optional[dict[str, Any]] = None, config: Optional[dict[str, Any]] = None,
transform_config: Optional[TransformConfig] = None,
): ):
super().__init__() super().__init__()
self.ignore = ignore self.ignore = ignore
...@@ -71,6 +75,12 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -71,6 +75,12 @@ class CompressedTensorsConfig(QuantizationConfig):
self.sparsity_ignore_list = sparsity_ignore_list self.sparsity_ignore_list = sparsity_ignore_list
self.config = config self.config = config
if transform_config is not None:
self.transform_config = TransformConfig.model_validate(
transform_config)
else:
self.transform_config = None
def get_linear_method(self) -> "CompressedTensorsLinearMethod": def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self) return CompressedTensorsLinearMethod(self)
...@@ -103,18 +113,27 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -103,18 +113,27 @@ class CompressedTensorsConfig(QuantizationConfig):
) -> Optional["QuantizeMethodBase"]: ) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import from vllm.attention.layer import Attention # Avoid circular import
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if should_ignore_layer(prefix,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix) # collect schemes
if scheme is None: quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
return UnquantizedLinearMethod() input_tfms, output_tfms = get_linear_transform_schemes(
layer.scheme = scheme layer, prefix, self.transform_config,
return CompressedTensorsLinearMethod(self) self.packed_modules_mapping)
# choose quantization method
quant_method: LinearMethodBase = UnquantizedLinearMethod()
if quant_scheme is not None:
layer.scheme = quant_scheme
quant_method = CompressedTensorsLinearMethod(self)
# choose transform method
if any((input_tfms, output_tfms)):
return CompressedTensorsLinearTransformMethod.from_schemes(
quant_method, input_tfms, output_tfms)
else:
return quant_method
if isinstance(layer, Attention): if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self) return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE): if isinstance(layer, FusedMoE):
...@@ -129,6 +148,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -129,6 +148,7 @@ class CompressedTensorsConfig(QuantizationConfig):
config=config) config=config)
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config( sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
config=config) config=config)
transform_config = config.get("transform_config")
return cls( return cls(
target_scheme_map=target_scheme_map, target_scheme_map=target_scheme_map,
...@@ -137,6 +157,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -137,6 +157,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_scheme_map=sparsity_scheme_map, sparsity_scheme_map=sparsity_scheme_map,
sparsity_ignore_list=sparsity_ignore_list, sparsity_ignore_list=sparsity_ignore_list,
config=config, config=config,
transform_config=transform_config,
) )
@classmethod @classmethod
...@@ -200,8 +221,10 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -200,8 +221,10 @@ class CompressedTensorsConfig(QuantizationConfig):
format format
) if format is not None else is_activation_quantization_format( ) if format is not None else is_activation_quantization_format(
quant_format) quant_format)
if act_quant_format: # TODO(czhu): w4a8fp8 is in packed-quantized format
# but needs input activation quantization
input_activations = quant_config.get("input_activations") input_activations = quant_config.get("input_activations")
if act_quant_format or input_activations:
# The only case where we have activation quant supported # The only case where we have activation quant supported
# but no input_activations provided in the config # but no input_activations provided in the config
# should be w8a16fp8 w8a16fp8 can also run for cases where # should be w8a16fp8 w8a16fp8 can also run for cases where
...@@ -352,6 +375,28 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -352,6 +375,28 @@ class CompressedTensorsConfig(QuantizationConfig):
input_quant.strategy == QuantizationStrategy.TENSOR) input_quant.strategy == QuantizationStrategy.TENSOR)
return is_symmetric_activation and is_per_tensor_activation return is_symmetric_activation and is_per_tensor_activation
def _is_fp8_w4a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
if not weight_quant or not input_quant:
return False
is_weight_4_bits = weight_quant.num_bits == 4
is_activation_8_bits = input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.GROUP.value)
is_token = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TOKEN.value)
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
is_symmetric = weight_quant.symmetric and input_quant.symmetric
# Only per-group symmetric weight (4bit)
# + per-tok symmetric activation (8bit) quantization supported.
return (is_weight_4_bits and is_activation_8_bits and is_token
and is_symmetric and is_dynamic)
def _is_fp8_w4a8_sm90(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
return (self._check_scheme_supported(90, error=False, match_exact=True)
and self._is_fp8_w4a8(weight_quant, input_quant))
def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool: input_quant: BaseModel) -> bool:
return (self._check_scheme_supported(90, error=False, match_exact=True) return (self._check_scheme_supported(90, error=False, match_exact=True)
...@@ -401,19 +446,30 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -401,19 +446,30 @@ class CompressedTensorsConfig(QuantizationConfig):
weight_quant: BaseModel, weight_quant: BaseModel,
input_quant: BaseModel, input_quant: BaseModel,
format: Optional[str] = None) -> "CompressedTensorsScheme": format: Optional[str] = None) -> "CompressedTensorsScheme":
# use the per-layer format if defined, otherwise, use global format
format = format if format is not None else self.quant_format
# Detect If Mixed Precision # Detect If Mixed Precision
if self._is_fp4a16_nvfp4(weight_quant, input_quant): if self._is_fp4a16_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A16Fp4() return CompressedTensorsW4A16Fp4()
if self._is_fp8_w4a8_sm90(weight_quant, input_quant):
return CompressedTensorsW4A8Fp8(num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy,
symmetric=weight_quant.symmetric,
group_size=weight_quant.group_size,
actorder=weight_quant.actorder)
if self._is_wNa16_group_channel(weight_quant, input_quant): if self._is_wNa16_group_channel(weight_quant, input_quant):
if (self.quant_format == CompressionFormat.marlin_24.value if (format == CompressionFormat.marlin_24.value
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
assert weight_quant.symmetric assert weight_quant.symmetric
return CompressedTensorsW4A16Sparse24( return CompressedTensorsW4A16Sparse24(
strategy=weight_quant.strategy, strategy=weight_quant.strategy,
num_bits=weight_quant.num_bits, num_bits=weight_quant.num_bits,
group_size=weight_quant.group_size) group_size=weight_quant.group_size)
if (self.quant_format == CompressionFormat.pack_quantized.value if (format == CompressionFormat.pack_quantized.value
and weight_quant.num_bits in WNA16_SUPPORTED_BITS): and weight_quant.num_bits in WNA16_SUPPORTED_BITS):
return CompressedTensorsWNA16( return CompressedTensorsWNA16(
num_bits=weight_quant.num_bits, num_bits=weight_quant.num_bits,
...@@ -422,10 +478,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -422,10 +478,7 @@ class CompressedTensorsConfig(QuantizationConfig):
group_size=weight_quant.group_size, group_size=weight_quant.group_size,
actorder=weight_quant.actorder) actorder=weight_quant.actorder)
act_quant_format = is_activation_quantization_format( act_quant_format = is_activation_quantization_format(format)
format
) if format is not None else is_activation_quantization_format(
self.quant_format)
if act_quant_format: if act_quant_format:
if self._is_fp4a4_nvfp4(weight_quant, input_quant): if self._is_fp4a4_nvfp4(weight_quant, input_quant):
if cutlass_fp4_supported( if cutlass_fp4_supported(
...@@ -505,9 +558,11 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -505,9 +558,11 @@ class CompressedTensorsConfig(QuantizationConfig):
# Find the "target" in the compressed-tensors config # Find the "target" in the compressed-tensors config
# that our layer conforms to. # that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep # TODO (@kylesayrs): support ignore module names with ct matching utils
# so we do not have to re-write these functions if should_ignore_layer(layer_name,
# need to make accelerate optional in ct to do this ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return None
# Will be empty for models with only sparsity # Will be empty for models with only sparsity
weight_quant = input_quant = None weight_quant = input_quant = None
...@@ -524,7 +579,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -524,7 +579,7 @@ class CompressedTensorsConfig(QuantizationConfig):
format = scheme_dict.get("format") format = scheme_dict.get("format")
# Find the sparsity scheme of the layer # Find the sparsity scheme of the layer
# assume that fused layers inerhit first component's sparsity scheme # assume that fused layers inherit first component's sparsity scheme
sparsity_targets = (self.sparsity_scheme_map.keys() - sparsity_targets = (self.sparsity_scheme_map.keys() -
set(self.sparsity_ignore_list)) set(self.sparsity_ignore_list))
sparsity_scheme: Optional[SparsityCompressionConfig] = None sparsity_scheme: Optional[SparsityCompressionConfig] = None
...@@ -690,7 +745,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -690,7 +745,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
layer input. See LinearMethodBase for param details layer input. See LinearMethodBase for param details
""" """
scheme = layer.scheme scheme = layer.scheme
if scheme is None: if scheme is None:
raise ValueError("A scheme must be defined for each layer") raise ValueError("A scheme must be defined for each layer")
......
...@@ -22,6 +22,8 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( ...@@ -22,6 +22,8 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe) is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP) WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target)
from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
...@@ -65,12 +67,40 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -65,12 +67,40 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
@staticmethod @staticmethod
def get_moe_method( def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
layer: torch.nn.Module, layer: torch.nn.Module
) -> "CompressedTensorsMoEMethod": ) -> "CompressedTensorsMoEMethod":
# TODO: @dsikka: refactor this to use schemes as other kernels # TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored. # are supported + check if the layer is being ignored.
weight_quant = quant_config.target_scheme_map["Linear"].get("weights") # Check if a using "Linear" to select schemes
input_quant = quant_config.target_scheme_map["Linear"].get( if "Linear" in quant_config.target_scheme_map:
matched_target = "Linear"
else:
# May have instead defined the linear layers in the fused model
fused_layers = [
"re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*"
]
current_scheme = None
for fused_layer in fused_layers:
# Check if one of the fused layers are defined in quant_config
matched_target = find_matched_target(
layer_name=fused_layer,
module=layer,
targets=quant_config.target_scheme_map.keys(),
fused_mapping=quant_config.packed_modules_mapping)
# Only valid if down_proj, gate_proj, and up_proj
# are mapped to the same quant scheme in the quant_config
if current_scheme is None:
current_scheme = quant_config.target_scheme_map.get(
matched_target)
else:
assert current_scheme == quant_config.target_scheme_map.get(
matched_target)
weight_quant = quant_config.target_scheme_map[matched_target].get(
"weights")
input_quant = quant_config.target_scheme_map[matched_target].get(
"input_activations") "input_activations")
if quant_config._is_wNa16_group_channel(weight_quant, input_quant): if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
...@@ -246,11 +276,11 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -246,11 +276,11 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
return return
# swizzle weight scales # swizzle weight scales
layer.w13_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale( layer.w13_weight_scale = torch.nn.Parameter(swizzle_blockscale(
layer.w13_weight_scale), layer.w13_weight_scale),
requires_grad=False) requires_grad=False)
layer.w2_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale( layer.w2_weight_scale = torch.nn.Parameter(swizzle_blockscale(
layer.w2_weight_scale), layer.w2_weight_scale),
requires_grad=False) requires_grad=False)
...@@ -292,6 +322,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -292,6 +322,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig, moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return the appropriate GEMM experts implementation.""" """Return the appropriate GEMM experts implementation."""
experts = select_nvfp4_gemm_impl( experts = select_nvfp4_gemm_impl(
...@@ -319,6 +350,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -319,6 +350,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
...@@ -344,6 +376,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -344,6 +376,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype, indices_type=self.topk_indices_dtype,
) )
...@@ -383,8 +416,35 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -383,8 +416,35 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_blockscale_swizzled, w2_scale=layer.w2_weight_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif self.allow_flashinfer:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4)
assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight, layer.w2_weight), (
"Flashinfer CUTLASS Fused MoE not applicable!")
return flashinfer_cutlass_moe_fp4(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
...@@ -400,8 +460,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -400,8 +460,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
a=x, a=x,
w1_fp4=layer.w13_weight, w1_fp4=layer.w13_weight,
w2_fp4=layer.w2_weight, w2_fp4=layer.w2_weight,
w1_blockscale=layer.w13_blockscale_swizzled, w1_blockscale=layer.w13_weight_scale,
w2_blockscale=layer.w2_blockscale_swizzled, w2_blockscale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas, g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas, g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant, a1_gscale=layer.w13_input_scale_quant,
...@@ -642,11 +702,29 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -642,11 +702,29 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
self.fused_experts_func = fused_experts self.fused_experts_func = fused_experts
if self.use_cutlass:
device = layer.w13_weight.device
# ab_strides1 and c_strides2 are the same
self.ab_strides1_c_strides2 = torch.full(
(layer.local_num_experts, ),
layer.hidden_size,
device=device,
dtype=torch.int64)
self.ab_strides2 = torch.full(
(layer.local_num_experts, ),
layer.intermediate_size_per_partition,
device=device,
dtype=torch.int64)
self.c_strides1 = torch.full(
(layer.local_num_experts, ),
2 * layer.intermediate_size_per_partition,
device=device,
dtype=torch.int64)
def select_gemm_impl( def select_gemm_impl(
self, self, prepare_finalize: FusedMoEPrepareAndFinalize,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig, moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute: layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute:
# cutlass path # cutlass path
if self.use_cutlass: if self.use_cutlass:
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
...@@ -666,6 +744,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -666,6 +744,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
moe.in_dtype, moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN, self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL, self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
) )
else: else:
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
...@@ -673,6 +755,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -673,6 +755,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
moe.in_dtype, moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN, self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL, self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
) )
self.disable_expert_map = (num_dispatchers > 1 self.disable_expert_map = (num_dispatchers > 1
...@@ -725,6 +811,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -725,6 +811,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
...@@ -748,6 +835,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -748,6 +835,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype, indices_type=self.topk_indices_dtype,
) )
...@@ -795,6 +883,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -795,6 +883,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_map=None if self.disable_expert_map else expert_map, expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
) )
...@@ -969,6 +1061,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -969,6 +1061,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
...@@ -996,6 +1089,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -996,6 +1089,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype) indices_type=self.topk_indices_dtype)
...@@ -1273,6 +1367,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1273,6 +1367,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
...@@ -1301,6 +1396,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1301,6 +1396,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype) indices_type=self.topk_indices_dtype)
...@@ -1504,6 +1600,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1504,6 +1600,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
...@@ -1530,6 +1627,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1530,6 +1627,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype) indices_type=self.topk_indices_dtype)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from .compressed_tensors_scheme import CompressedTensorsScheme from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8
from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS, from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
CompressedTensorsW4A16Sparse24) CompressedTensorsW4A16Sparse24)
...@@ -21,5 +22,6 @@ __all__ = [ ...@@ -21,5 +22,6 @@ __all__ = [
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8", "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
"WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS", "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
"CompressedTensors24", "CompressedTensorsW4A16Fp4", "CompressedTensors24", "CompressedTensorsW4A16Fp4",
"CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int" "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int",
"CompressedTensorsW4A8Fp8"
] ]
...@@ -12,6 +12,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( ...@@ -12,6 +12,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme) CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
run_nvfp4_emulations) run_nvfp4_emulations)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
swizzle_blockscale)
from vllm.model_executor.parameter import (GroupQuantScaleParameter, from vllm.model_executor.parameter import (GroupQuantScaleParameter,
ModelWeightParameter, ModelWeightParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
...@@ -83,29 +85,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -83,29 +85,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
weight_loader=weight_loader) weight_loader=weight_loader)
layer.register_parameter("input_global_scale", input_global_scale) layer.register_parameter("input_global_scale", input_global_scale)
def swizzle_blockscale(self, scale: torch.tensor):
assert (scale.dtype == torch.float8_e4m3fn)
# Pad and blockwise interleave weight_scale
scale_ndim = scale.ndim
if scale.ndim == 2:
scale = scale.unsqueeze(0)
assert scale.ndim == 3
B, M, K = scale.shape
round_up_multiple = lambda x, m: (x + m - 1) // m * m
M_padded = round_up_multiple(M, 128)
K_padded = round_up_multiple(K, 4)
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
padded_scale[:B, :M, :K] = scale
batches, rows, cols = padded_scale.shape
assert rows % 128 == 0
assert cols % 4 == 0
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
cols // 4, 4)
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
swizzled_scale = swizzled_scale.contiguous().cuda()
return (swizzled_scale.reshape(M, K)
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
def process_weights_after_loading(self, layer) -> None: def process_weights_after_loading(self, layer) -> None:
global_input_scale = layer.input_global_scale.max().to(torch.float32) global_input_scale = layer.input_global_scale.max().to(torch.float32)
...@@ -133,12 +112,11 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -133,12 +112,11 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
torch.uint8), epilogue_tile_m).reshape( torch.uint8), epilogue_tile_m).reshape(
weight_scale.shape).view(torch.float8_e4m3fn)) weight_scale.shape).view(torch.float8_e4m3fn))
layer.weight_scale_swizzled = Parameter(weight_scale, layer.weight_scale = Parameter(weight_scale, requires_grad=False)
requires_grad=False)
layer.weight_packed = Parameter(weight, requires_grad=False) layer.weight_packed = Parameter(weight, requires_grad=False)
else: else:
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, layer.weight_scale = Parameter(swizzled_weight_scale,
requires_grad=False) requires_grad=False)
layer.weight_packed = Parameter(layer.weight_packed.data, layer.weight_packed = Parameter(layer.weight_packed.data,
requires_grad=False) requires_grad=False)
...@@ -157,7 +135,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -157,7 +135,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
x=x, x=x,
input_global_scale=layer.input_global_scale, input_global_scale=layer.input_global_scale,
weight=layer.weight_packed, weight=layer.weight_packed,
weight_scale_swizzled=layer.weight_scale_swizzled, weight_scale_swizzled=layer.weight_scale,
weight_global_scale=layer.weight_global_scale) weight_global_scale=layer.weight_global_scale)
if bias is not None: if bias is not None:
out = out + bias out = out + bias
...@@ -170,7 +148,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -170,7 +148,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale) x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
mm_args = (x_fp4, layer.weight_packed, x_blockscale, mm_args = (x_fp4, layer.weight_packed, x_blockscale,
layer.weight_scale_swizzled, layer.alpha, output_dtype) layer.weight_scale, layer.alpha, output_dtype)
if self.backend == "flashinfer-trtllm": if self.backend == "flashinfer-trtllm":
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm") out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
elif self.backend == "flashinfer-cutlass": elif self.backend == "flashinfer-cutlass":
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import ActivationOrdering
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
# yapf: enable
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
__all__ = ["CompressedTensorsW4A8Fp8"]
W4A8_SUPPORTED_TYPES_MAP = {
4: scalar_types.int4,
}
W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
_kernel_backends_being_used: set[str] = set()
def __init__(self,
strategy: str,
num_bits: int,
group_size: Optional[int] = None,
symmetric: Optional[bool] = True,
actorder: Optional[ActivationOrdering] = None):
self.pack_factor = 32 // num_bits
self.strategy = strategy
self.symmetric = symmetric
self.group_size = -1 if group_size is None else group_size
self.has_g_idx = actorder == ActivationOrdering.GROUP
if self.group_size != 128 or self.strategy != "group":
raise ValueError("W4A8 kernels require group quantization " \
"with group size 128")
if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
raise ValueError(
f"Unsupported num_bits = {num_bits}. "
f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}")
self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]
@classmethod
def get_min_capability(cls) -> int:
# hopper
return 90
def create_weights(self, layer: torch.nn.Module, output_size: int,
input_size: int, output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_type,
act_type=torch.float8_e4m3fn, # always use fp8(e4m3)
group_size=self.group_size,
zero_points=not self.symmetric,
has_g_idx=self.has_g_idx,
out_type=params_dtype
)
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsW4A8Fp8",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# If group_size is -1, we are in channelwise case.
group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = (input_size != input_size_per_partition)
partition_scales = not marlin_repeat_scales_on_all_ranks(
self.has_g_idx, self.group_size, row_parallel)
scales_and_zp_size = input_size // group_size
if partition_scales:
assert input_size_per_partition % group_size == 0
scales_and_zp_size = input_size_per_partition // group_size
weight = PackedvLLMParameter(input_dim=1,
output_dim=0,
weight_loader=weight_loader,
packed_factor=self.pack_factor,
packed_dim=1,
data=torch.empty(
output_size_per_partition,
input_size_per_partition //
self.pack_factor,
dtype=torch.int32,
))
# TODO(czhu): allocate the packed fp8 scales memory here?
# the scales will be expanded by 8x via `cutlass_pack_scale_fp8`
weight_scale_args = {
"weight_loader":
weight_loader,
"data":
torch.empty(
output_size_per_partition,
scales_and_zp_size,
dtype=torch.float8_e4m3fn,
)
}
if not partition_scales:
weight_scale = ChannelQuantScaleParameter(output_dim=0,
**weight_scale_args)
else:
weight_scale = GroupQuantScaleParameter(output_dim=0,
input_dim=1,
**weight_scale_args)
# A 2D array defining the original shape of the weights
# before packing
weight_shape = BasevLLMParameter(data=torch.empty(2,
dtype=torch.int64),
weight_loader=weight_loader)
# per-channel scales
weight_chan_scale = ChannelQuantScaleParameter(
data=torch.empty((output_size_per_partition, 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight_packed", weight)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape)
layer.register_parameter("weight_chan_scale", weight_chan_scale)
self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="weight_packed",
w_s_param_name="weight_scale",
w_zp_param_name="weight_zero_point",
w_gidx_param_name="weight_g_idx")
# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment