Commit 9c4ecf15 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-ori

parents bfc2d6f7 dc1b4a6f
# SPDX-License-Identifier: Apache-2.0
import math
from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata)
from vllm.attention.backends.xformers import XFormersMetadata
@dataclass
class Mamba2Metadata:
has_prefill: bool
has_initial_states: torch.Tensor
prep_initial_states: bool
chunk_size: int
seq_idx: torch.Tensor
chunk_indices: torch.Tensor
chunk_offsets: torch.Tensor
def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
# convert seq_idx to chunk indices and offsets
# - derive the cu_seqlens
_, cu_seqlens = torch.where(seq_idx.diff())
cu_seqlens += 1
# outputs will have length expansion of chunks that do not divide
# chunk_size
N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
> 0).sum()
chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
p = 0 # num of insertions
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
# if does not divide chunk_size, then there is one chunk insertion
p += (s % chunk_size > 0)
# get the dimensions
# - the + 1 for _e is to shift the boundary by one chunk
# - this shifting is not needed if chunk_size divides e
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
> 0)
# adjust inidces and offsets
chunk_indices[_s:_e] -= p
chunk_offsets[_s] = s % chunk_size
return chunk_indices, chunk_offsets
def prepare_mamba2_metadata(
chunk_size: int,
input_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Mamba2Metadata:
# Need flags to indicate if there are initial states
# currently we really only support the FlashAttention backend
has_initial_states = None
prep_initial_states = False
if (isinstance(attn_metadata, (FlashAttentionMetadata, XFormersMetadata,
PlaceholderAttentionMetadata))
and attn_metadata.context_lens_tensor is not None):
has_initial_states = attn_metadata.context_lens_tensor > 0
# precompute flag to avoid device syncs later in mamba2 forwards
prep_initial_states = torch.any(has_initial_states).item()
has_prefill = attn_metadata.num_prefills > 0
seq_idx = None
chunk_indices, chunk_offsets = None, None
if has_prefill:
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
for i, (srt, end) in enumerate(
zip(
attn_metadata.query_start_loc,
attn_metadata.query_start_loc[1:],
)):
seq_idx[srt:end] = i
seq_idx.unsqueeze_(0)
# compute metadata for chunked prefill.
# actually this is only needed if there are initial states,
# but this is determinable only from attention metadata yet
# unavailable from the top-level model forward. Rather than
# complicating things to extract said metadata, we simply just
# compute them once at the top level model forward and reuse
# them in mamba layers. If not needed, they will be ignored
# inside mamba kernels.
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
seq_idx, chunk_size)
return Mamba2Metadata(has_prefill=has_prefill,
has_initial_states=has_initial_states,
prep_initial_states=prep_initial_states,
chunk_size=chunk_size,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets)
......@@ -6,10 +6,6 @@ import torch
from torch import nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata)
from vllm.attention.backends.xformers import XFormersMetadata
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
......@@ -18,6 +14,7 @@ from vllm.forward_context import get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
......@@ -221,7 +218,6 @@ class MambaMixer2(CustomOp):
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation="silu",
chunk_size: int = 256,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
......@@ -257,7 +253,6 @@ class MambaMixer2(CustomOp):
self.ssm_state_size = ssm_state_size
self.activation = activation
self.chunk_size = chunk_size
self.intermediate_size = intermediate_size
self.head_dim = head_dim
self.num_heads = num_heads
......@@ -388,25 +383,17 @@ class MambaMixer2(CustomOp):
self,
hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor] = None,
mamba2_metadata: Mamba2Metadata,
):
# mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# are the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
seq_len, _ = hidden_states.shape
groups_time_state_size = self.n_groups * self.ssm_state_size
# detect if there are prefills
has_prefill = attn_metadata.num_prefills > 0
# - also need flags to indicate if there are initial states
# - currently we really only support the FlashAttention backend
has_initial_states = None
if (isinstance(attn_metadata,
(FlashAttentionMetadata, XFormersMetadata,
PlaceholderAttentionMetadata))
and attn_metadata.context_lens_tensor is not None):
has_initial_states = attn_metadata.context_lens_tensor > 0
# 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states)
gate, hidden_states_B_C, dt = torch.split(
......@@ -423,7 +410,7 @@ class MambaMixer2(CustomOp):
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if has_prefill:
if mamba2_metadata.has_prefill:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
......@@ -439,7 +426,7 @@ class MambaMixer2(CustomOp):
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=has_initial_states,
has_initial_state=mamba2_metadata.has_initial_states,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc).transpose(
0, 1)[:seq_len]
......@@ -467,16 +454,15 @@ class MambaMixer2(CustomOp):
)
# 3. State Space Model sequence transformation
if has_prefill:
if mamba2_metadata.has_prefill:
initial_states = None
if has_initial_states is not None and torch.any(
has_initial_states):
zero_init_indices = mamba_cache_params.state_indices_tensor[
~has_initial_states]
mamba_cache_params.ssm_state[zero_init_indices] = 0
initial_states = mamba_cache_params.ssm_state[
mamba_cache_params.state_indices_tensor]
if (mamba2_metadata.has_initial_states is not None
and mamba2_metadata.prep_initial_states):
# making a copy of the states
initial_states = torch.where(
mamba2_metadata.has_initial_states[:, None, None, None],
mamba_cache_params.ssm_state[
mamba_cache_params.state_indices_tensor], 0)
scan_output, varlen_state = mamba_chunk_scan_combined(
hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
......@@ -485,11 +471,13 @@ class MambaMixer2(CustomOp):
self.A,
B.view(1, seq_len, self.n_groups // self.tp_size, -1),
C.view(1, seq_len, self.n_groups // self.tp_size, -1),
chunk_size=self.chunk_size,
chunk_size=mamba2_metadata.chunk_size,
D=self.D,
z=None,
dt_bias=self.dt_bias,
seq_idx=sequence_idx,
seq_idx=mamba2_metadata.seq_idx,
chunk_indices=mamba2_metadata.chunk_indices,
chunk_offsets=mamba2_metadata.chunk_offsets,
cu_seqlens=attn_metadata.query_start_loc,
initial_states=initial_states,
return_varlen_states=True,
......
......@@ -5,8 +5,6 @@
# ruff: noqa: E501,SIM102
import math
import torch
import triton
import triton.language as tl
......@@ -442,40 +440,6 @@ def _chunk_scan_fwd_kernel(
(offs_out_n[None, :] < hdim))
def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
# convert seq_idx to chunk indices and offsets
# - derive the cu_seqlens
_, cu_seqlens = torch.where(seq_idx.diff())
cu_seqlens += 1
# outputs will have length expansion of chunks that do not divide
# chunk_size
N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
> 0).sum()
chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
p = 0 # num of insertions
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
# if does not divide chunk_size, then there is one chunk insertion
p += (s % chunk_size > 0)
# get the dimensions
# - the + 1 for _e is to shift the boundary by one chunk
# - this shifting is not needed if chunk_size divides e
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
> 0)
# adjust inidces and offsets
chunk_indices[_s:_e] -= p
chunk_offsets[_s] = s % chunk_size
return chunk_indices, chunk_offsets
def _chunk_scan_fwd(
cb,
x,
......@@ -486,6 +450,8 @@ def _chunk_scan_fwd(
D=None,
z=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
initial_states=None,
):
batch, seqlen, nheads, headdim = x.shape
......@@ -502,7 +468,6 @@ def _chunk_scan_fwd(
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
assert states.shape == (batch, nchunks, nheads, headdim, dstate)
chunk_indices, chunk_offsets = None, None
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
......@@ -510,15 +475,19 @@ def _chunk_scan_fwd(
# with initial states, we need to take care of how
# seq_idx crosses the boundaries
assert batch == 1, "chunk scan only supports initial states with batch 1"
assert initial_states.shape == (seq_idx[0].max() + 1, nheads,
headdim, dstate)
if initial_states.shape[0] == 1:
# no in this case no point to use initial states
initial_states = None
else:
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
seq_idx, chunk_size)
assert chunk_indices is not None and chunk_offsets is not None, \
(
"chunk_indices and chunk_offsets should have been set"
)
else:
chunk_indices, chunk_offsets = None, None
else:
chunk_indices, chunk_offsets = None, None
# Allocates output.
out = torch.empty(batch,
......
......@@ -30,6 +30,8 @@ def _mamba_chunk_scan_combined_fwd(x,
dt_bias=None,
initial_states=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf"))):
......@@ -96,7 +98,7 @@ def _mamba_chunk_scan_combined_fwd(x,
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
# - for handling chunked prefill, this requires i) initial_states
# ii) seq_idx and iii) has_cu_seqlens to be all specified.
# ii) seq_idx and iii) is_cont_batched to be all specified.
# - When a new seq_idx is detected, we will stop passing the prev_state
# and switch accordingly to the init_state corresponding to the new seq_idx.
# - this will ensure that states will be updated with the rightmost flushed seq_idx
......@@ -141,6 +143,8 @@ def _mamba_chunk_scan_combined_fwd(x,
D=D,
z=z,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
initial_states=initial_states,
)
if cu_seqlens is None:
......@@ -170,6 +174,8 @@ def mamba_chunk_scan_combined(x,
dt_bias=None,
initial_states=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
......@@ -210,6 +216,8 @@ def mamba_chunk_scan_combined(x,
dt_bias=dt_bias,
initial_states=initial_states,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
cu_seqlens=cu_seqlens,
dt_softplus=dt_softplus,
dt_limit=dt_limit)
......
......@@ -150,8 +150,6 @@ def _state_passing_fwd(
# are used for continuous batching. In which case we
# require seq_idx to be provided
assert seq_idx is not None, ""
assert initial_states.shape == (seq_idx.max().item() + 1, nheads,
dim)
else:
# - this is the regular batching case, where initial
# states are used are for each example of the batch.
......
......@@ -97,7 +97,7 @@ class SimplePooler(nn.Module):
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data)
pooled_data = self.head(pooled_data, pooling_metadata)
pooled_outputs = [self.build_output(data) for data in pooled_data]
return PoolerOutput(outputs=pooled_outputs)
......@@ -217,14 +217,28 @@ class PoolerHead(nn.Module):
self.normalize = normalize
self.softmax = softmax
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor]):
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
pooling_metadata: PoolingMetadata):
dimensions_list = [
pooling_param.dimensions
for _, pooling_param in pooling_metadata.seq_groups
]
if any(d is not None for d in dimensions_list):
# change the output dimension
assert len(pooled_data) == len(dimensions_list)
pooled_data = [
vecs if d is None else vecs[..., :d]
for vecs, d in zip(pooled_data, dimensions_list)
]
if self.normalize:
if isinstance(pooled_data, list):
pooled_data = [
F.normalize(data, p=2, dim=1) for data in pooled_data
F.normalize(data, p=2, dim=-1) for data in pooled_data
]
else:
pooled_data = F.normalize(pooled_data, p=2, dim=1)
pooled_data = F.normalize(pooled_data, p=2, dim=-1)
if self.softmax:
if isinstance(pooled_data, list):
......
......@@ -31,7 +31,8 @@ QUANTIZATION_METHODS: List[str] = [
"neuron_quant",
"ipex",
"quark",
"moe_wna16"
"moe_wna16",
"torchao",
]
# The customized quantization methods which will be added to this dict.
......@@ -103,6 +104,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .neuron_quant import NeuronQuantConfig
from .ptpc_fp8 import PTPCFp8Config
from .qqq import QQQConfig
from .torchao import TorchAOConfig
from .tpu_int8 import Int8TpuConfig
method_to_config: Dict[str, Type[QuantizationConfig]] = {
......@@ -132,6 +134,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"ipex": IPEXConfig,
"quark": QuarkConfig,
"moe_wna16": MoeWNA16Config,
"torchao": TorchAOConfig,
}
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
......
......@@ -96,8 +96,7 @@ class CompressedTensorsConfig(QuantizationConfig):
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):
return CompressedTensorsMoEMethod.get_moe_method(
self, layer.activation, layer.expert_map)
return CompressedTensorsMoEMethod.get_moe_method(self, layer)
return None
@classmethod
......
......@@ -6,7 +6,8 @@ from typing import Callable, List, Optional
import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization import (ActivationOrdering,
QuantizationStrategy)
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
......@@ -30,9 +31,11 @@ class GPTQMarlinState(Enum):
__all__ = [
"CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsMoEMethod",
"CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsW8A8Fp8MoECutlassMethod",
"CompressedTensorsWNA16MoEMethod"
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod",
]
......@@ -41,8 +44,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
@staticmethod
def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
activation: str,
expert_map: Optional[torch.Tensor],
layer: torch.nn.Module,
) -> "CompressedTensorsMoEMethod":
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
......@@ -51,9 +53,21 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
"input_activations")
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
return CompressedTensorsWNA16MoEMethod(quant_config)
# Prefer to use the non-marlin kernel when:
# 1. Many experts (MarlinMoE gives poor performance when >= 16)
# 2. Non-FP16 dtype (MarlinMoE only supports FP16)
# 3. Actorder is not group/dynamic (g_idx is unsupported)
# 4. Scaled are grouped (channelwise is unsupported)
if ((layer.local_num_experts >= 16
or layer.params_dtype != torch.float16) and
weight_quant.actorder not in (ActivationOrdering.GROUP,
ActivationOrdering.DYNAMIC)
and weight_quant.strategy in QuantizationStrategy.GROUP):
return CompressedTensorsWNA16MoEMethod(quant_config)
else:
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
and activation == "silu" and expert_map is None):
and layer.activation == "silu" and layer.expert_map is None):
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
......@@ -74,14 +88,23 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations")
if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy == QuantizationStrategy.TENSOR):
per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy
== QuantizationStrategy.TENSOR)
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN)
if not (per_tensor or per_channel):
raise ValueError(
"For FP8 Fused MoE layers, only per-tensor scales "
"for weights and activations are supported. Found "
"For FP8 Fused MoE layers, we require per tensor "
"or channelwise, dynamic per token quantization. Found "
f"{self.weight_quant}, {self.input_quant}")
self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales and per_channel:
raise ValueError(
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization.")
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
......@@ -109,24 +132,40 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
2,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
# Allocate 2 scales for w1 and w3 respectively.
# They are combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, 2, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-TENSOR quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
if self.static_input_scales:
......@@ -149,6 +188,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self.static_input_scales:
assert self.input_quant.strategy == QuantizationStrategy.TENSOR
if (layer.w13_input_scale is None or layer.w2_input_scale is None):
raise ValueError(
"QuantConfig has static quantization, but found "
......@@ -190,24 +230,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
requires_grad=False)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id])
layer.w13_weight[expert_id][
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
# For Per-TENSOR case, Fp8 moe kernel needs single weight scale
# for w13 per expert. Use max then dequant and requant each expert.
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start +
shard_size, :],
layer.w13_weight_scale[expert_id][shard_id])
layer.w13_weight[expert_id][
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
def apply(
self,
......@@ -251,6 +292,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=self.weight_quant.strategy ==
QuantizationStrategy.CHANNEL,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
......@@ -482,7 +525,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
)
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
......@@ -823,3 +866,215 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.num_bits,
is_k_full=self.is_k_full)
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
config = self.quant_config.target_scheme_map["Linear"].get("weights")
self.num_bits = config.num_bits
self.packed_factor = 32 // config.num_bits
self.strategy = config.strategy
# channelwise is not supported by this kernel
assert config.strategy == "group"
self.group_size = config.group_size
# grouped actorder isn't supported by this kernel
assert config.actorder != "group"
assert config.symmetric, (
"Only symmetric quantization is supported for MoE")
if not (self.quant_config.quant_format
== CompressionFormat.pack_quantized.value
and self.num_bits in WNA16_SUPPORTED_BITS):
raise ValueError("For Fused MoE layers, only ",
f"{CompressionFormat.pack_quantized.value} ",
"is supported for the following bits: ",
f"{WNA16_SUPPORTED_BITS}")
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
# Will transpose the loaded weight along the
# intermediate and hidden dim sizes. Will
# shard for TP along the transposed dims
extra_weight_attrs.update({
"is_transposed": True,
"quant_method": self.strategy
})
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size // self.packed_factor,
2 * intermediate_size_per_partition,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_weight_packed", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
intermediate_size_per_partition // self.packed_factor,
hidden_size,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w2_weight_packed", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w2_scales_size = intermediate_size_per_partition
if self.strategy == "channel":
num_groups_w2 = num_groups_w13 = 1
self.group_size = -1
else:
num_groups_w2 = w2_scales_size // self.group_size
num_groups_w13 = hidden_size // self.group_size
w13_scale = torch.nn.Parameter(torch.ones(
num_experts,
num_groups_w13,
2 * intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_scale)
set_weight_attrs(w13_scale, extra_weight_attrs)
w2_scale = torch.nn.Parameter(torch.ones(num_experts,
num_groups_w2,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_scale)
set_weight_attrs(w2_scale, extra_weight_attrs)
set_weight_attrs(w2_scale, {"load_full_w2": False})
w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
requires_grad=False)
layer.register_parameter("w2_weight_shape", w2_weight_shape)
set_weight_attrs(w2_weight_shape, extra_weight_attrs)
w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
requires_grad=False)
layer.register_parameter("w13_weight_shape", w13_weight_shape)
set_weight_attrs(w13_weight_shape, extra_weight_attrs)
w13_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_g_idx", w13_g_idx)
set_weight_attrs(w13_g_idx, extra_weight_attrs)
w2_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_g_idx", w2_g_idx)
set_weight_attrs(w2_g_idx, extra_weight_attrs)
w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx_sort_indices",
w13_g_idx_sort_indices)
set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx_sort_indices",
w2_g_idx_sort_indices)
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
layer.a13_scale = None
layer.a2_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Reconfigure packed weights and scales to match moe_wna16 format
layer.w13_weight_packed = torch.nn.Parameter(
layer.w13_weight_packed.transpose(1, 2).contiguous().view(
torch.uint8),
requires_grad=False)
layer.w2_weight_packed = torch.nn.Parameter(
layer.w2_weight_packed.transpose(1,
2).contiguous().view(torch.uint8),
requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(
layer.w13_weight_scale.transpose(1, 2).contiguous(),
requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(
layer.w2_weight_scale.transpose(1, 2).contiguous(),
requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(
x,
layer.w13_weight_packed,
layer.w2_weight_packed,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_int4_w4a16=self.num_bits == 4,
use_int8_w8a16=self.num_bits == 8,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_zp=None,
w2_zp=None,
block_shape=[0, self.group_size])
......@@ -116,7 +116,9 @@ class Fp8Config(QuantizationConfig):
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
if is_layer_skipped(prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
......@@ -252,6 +254,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "weight_scale"})
layer.register_parameter("weight_scale", scale)
else:
assert self.quant_config.activation_scheme == "dynamic"
......@@ -266,6 +269,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "weight_scale"})
# The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale)
......@@ -276,6 +280,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "input_scale"})
layer.register_parameter("input_scale", scale)
else:
layer.register_parameter("input_scale", None)
......
......@@ -29,7 +29,7 @@ def choose_scaled_mm_linear_kernel(
compute_capability: Optional[int] = None
) -> Type[ScaledMMLinearKernel]:
"""
Choose an ScalledMMLinearKernel that can implement the given config for the
Choose an ScaledMMLinearKernel that can implement the given config for the
given compute capability. Attempts to choose the best kernel in terms of
performance.
......
......@@ -21,7 +21,7 @@ class QuarkW8A8Fp8(QuarkScheme):
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
self.qscheme = qscheme
self.is_static_input_scheme = is_static_input_scheme
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
self.out_dtype = torch.get_default_dtype()
@classmethod
......@@ -41,10 +41,11 @@ class QuarkW8A8Fp8(QuarkScheme):
)
if current_platform.is_fp8_fnuz():
input_scale = getattr(layer, 'input_scale', None)
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=max_w_scale,
input_scale=layer.input_scale)
input_scale=input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
......@@ -57,11 +58,12 @@ class QuarkW8A8Fp8(QuarkScheme):
weight = layer.weight
if current_platform.is_fp8_fnuz():
input_scale = getattr(layer, 'input_scale', None)
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale)
input_scale=input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
......@@ -105,7 +107,7 @@ class QuarkW8A8Fp8(QuarkScheme):
# the newly added parameters
if self.qscheme == "per_channel":
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
data=torch.empty((sum(output_partition_sizes)),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
......
......@@ -35,7 +35,7 @@ class QuarkW8A8Int8(QuarkScheme):
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
self.logical_widths = output_partition_sizes
layer.logical_widths = output_partition_sizes
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
is_channelwise=(self.qscheme == "per_channel"),
......@@ -63,16 +63,28 @@ class QuarkW8A8Int8(QuarkScheme):
# WEIGHT SCALE
if self.qscheme == "per_channel":
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
data=torch.empty((sum(output_partition_sizes)),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
ChannelQuantZPParameter = ChannelQuantScaleParameter
weight_zero_point = ChannelQuantZPParameter(
data=torch.empty((sum(output_partition_sizes)),
dtype=torch.int8),
output_dim=0,
weight_loader=weight_loader)
else:
assert self.qscheme == "per_tensor"
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
PerTensorZPParameter = PerTensorScaleParameter
weight_zero_point = PerTensorZPParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.int8),
weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_zero_point", weight_zero_point)
# INPUT SCALE
if self.is_static_input_scheme:
......@@ -81,14 +93,10 @@ class QuarkW8A8Int8(QuarkScheme):
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
if not self.input_symmetric:
# Note: quark stores the zp using the same dtype
# as the weights
# AZP loaded as int8 but used as int32
input_zero_point = BasevLLMParameter(
data=torch.empty(1, dtype=torch.int8),
weight_loader=weight_loader)
layer.register_parameter("input_zero_point", input_zero_point)
input_zero_point = BasevLLMParameter(data=torch.empty(
1, dtype=torch.int8),
weight_loader=weight_loader)
layer.register_parameter("input_zero_point", input_zero_point)
self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
w_q_param_name="weight",
......@@ -100,6 +108,12 @@ class QuarkW8A8Int8(QuarkScheme):
# Checkpoints are serialized in quark format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.register_parameter("weight_zero_point", None)
delattr(layer, 'weight_zero_point')
if self.input_symmetric:
layer.register_parameter("input_zero_point", None)
delattr(layer, 'input_zero_point')
self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class TorchAOConfig(QuantizationConfig):
"""Config class for torchao."""
def __init__(self, torchao_config) -> None:
self.torchao_config = torchao_config
def __repr__(self) -> str:
return f"TorchAOConfig({self.torchao_config})"
def get_name(self) -> str:
return "torchao"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.float32, torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 75
@staticmethod
def get_config_filenames() -> List[str]:
return ["config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "TorchAOConfig":
"""Create the quant config from an hf model config"""
try:
from torchao.core.config import config_from_dict
except ImportError as err:
raise ImportError(
"Please install torchao>=0.10.0 via "
"`pip install torchao>=0.10.0` to use torchao quantization."
) from err
hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
assert hf_config is not None, "quant_type must be specified"
assert (len(hf_config) == 1 and "default" in hf_config
), "Expected only one key 'default' in quant_type dictionary"
quant_type = hf_config["default"]
ao_config = config_from_dict(quant_type)
return cls(ao_config)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["TorchAOLinearMethod"]:
if isinstance(layer, LinearBase):
return TorchAOLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
def torchao_quantize_param_data(param: torch.Tensor,
torchao_config: Any) -> torch.nn.Parameter:
"""Quantize a Tensor with torchao quantization specified by torchao_config
Args:
`param`: weight parameter of the linear module
`torchao_config`: type of quantization and their arguments we want to
use to quantize the Tensor
"""
from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
assert isinstance(torchao_config, AOBaseConfig)
dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
dummy_linear.weight = param
quantize_(dummy_linear, torchao_config)
return dummy_linear.weight
class TorchAOLinearMethod(LinearMethodBase):
"""Linear method for torchao.
Args:
torchao_config: The torchao quantization config, a string
that encodes the type of quantization and all relevant arguments.
"""
def __init__(self, quant_config: TorchAOConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
weight = Parameter(
torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
weight = torchao_quantize_param_data(weight,
self.quant_config.torchao_config)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return F.linear(x, layer.weight, bias)
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/sgl-project/sglang/blob/4cb53ecd0cffceb6dee5c011a58f65997a86f151/python/sglang/srt/layers/quantization/int8_kernel.py
import functools
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm.platforms import current_platform
logger = logging.getLogger(__name__)
def apply_w8a8_block_int8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale = per_token_group_quant_int8(input_2d, block_size[1])
output = w8a8_block_int8_matmul(q_input,
weight,
x_scale,
weight_scale,
block_size,
output_dtype=input.dtype)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
def input_to_int8(
x: torch.Tensor,
dtype: torch.dtype = torch.int8) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to int8 values with
tensor-wise quantization."""
iinfo = torch.iinfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
int8_min, int8_max = iinfo.min, iinfo.max
scale = int8_max / amax
x_scl_sat = (x * scale).clamp(min=int8_min, max=int8_max)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
def block_dequant(
x_q_block: torch.Tensor,
x_s: torch.Tensor,
block_size: List[int],
) -> torch.Tensor:
"""This function conducts block-wise dequantization.
The inputs are block-wise quantization tensor `x_q_block`,
block-wise quantization scale and the block size.
The outputs are dequantized tensor.
"""
block_n, block_k = block_size[0], block_size[1]
n, k = x_q_block.shape
n_tiles = (n + block_n - 1) // block_n
k_tiles = (k + block_k - 1) // block_k
assert n_tiles == x_s.shape[0]
assert k_tiles == x_s.shape[1]
x_dq_block = x_q_block.to(torch.float32)
for i in range(k_tiles):
for j in range(n_tiles):
x_dq_block[
j * block_n:min((j + 1) * block_n, n),
i * block_k:min((i + 1) * block_k, k),
] *= x_s[j][i]
return x_dq_block
@triton.jit
def _per_token_quant_int8(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
BLOCK: tl.constexpr,
):
# Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
row_id = tl.program_id(0)
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + row_id, scale_x)
def per_token_quant_int8(x):
M = x.numel() // x.shape[-1]
N = x.shape[-1]
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty(x.shape[:-1] + (1, ),
device=x.device,
dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
assert x.is_contiguous()
_per_token_quant_int8[(M, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return x_q, scales
@triton.jit
def _per_token_group_quant_int8(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
# Stride of input
y_stride,
# Columns of input
N,
# Avoid to divide zero
eps,
# Information for int8
int8_min,
int8_max,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into int8 values.
"""
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
y_ptr += g_id * y_stride
y_q_ptr += g_id * y_stride
y_s_ptr += g_id
cols = tl.arange(0, BLOCK) # N <= BLOCK
mask = cols < N
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / int8_max
y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
def per_token_group_quant_int8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed int8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tenosr with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor. Note that only `torch.int8`
is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
assert (x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
iinfo = torch.iinfo(dtype)
int8_max = iinfo.max
int8_min = iinfo.min
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size, ),
device=x.device,
dtype=torch.float32,
)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
_per_token_group_quant_int8[(M, )](
x,
x_q,
x_s,
group_size,
N,
eps,
int8_min=int8_min,
int8_max=int8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
return x_q, x_s
@triton.jit
def _w8a8_block_int8_matmul(
# Pointers to inputs and output
A,
B,
C,
As,
Bs,
# Shape for matmul
M,
N,
K,
# Block size for block-wise quantization
group_n,
group_k,
# Stride for inputs and output
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_As_m,
stride_As_k,
stride_Bs_k,
stride_Bs_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization, and
store the result in output tensor `C`.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
As_ptrs = As + offs_am * stride_As_m
offs_bsn = offs_bn // group_n
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
accumulator += tl.dot(a, b).to(tl.float32) * a_s[:,
None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
c = accumulator.to(tl.float16)
else:
c = accumulator.to(tl.float32)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
@functools.lru_cache
def get_w8a8_block_int8_configs(N: int, K: int, block_n: int,
block_k: int) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the w8a8 block fp8 kernel.
The return value will be a dictionary that maps an irregular grid of
batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs
# directory
device_name = current_platform.get_device_name().replace(" ", "_")
json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n}, {block_k}].json" # noqa: E501
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info(
"Using configuration from %s for W8A8 Block INT8 kernel.",
config_file_path,
)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
# If no optimized configuration is available, we will use the default
# configuration
logger.warning(
("Using default W8A8 Block INT8 kernel config. Performance might "
"be sub-optimal! Config file not found at %s"),
config_file_path,
)
return None
def w8a8_block_int8_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise
quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should be
2-dim, e.g., [128, 128].
output_dytpe: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N, )
C = A.new_empty(C_shape, dtype=output_dtype)
configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
}
def grid(META):
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
_w8a8_block_int8_matmul[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
return C
......@@ -305,7 +305,7 @@ def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
# the performance of atomicAdd is better than global reduce
# only when m*n is small and k is large
return max(m, 64) * n < 64 * 2048 and k >= 2048
return n < 2048 and k >= 2048
def apply_gptq_marlin_linear(
......
......@@ -33,11 +33,15 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.layers.linear import (LinearBase,
MergedColumnParallelLinear,
QKVCrossParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
# yapf: enable
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase)
from vllm.model_executor.model_loader.tensorizer import (
......@@ -111,10 +115,12 @@ def _initialize_model(
vllm_config: VllmConfig,
*,
prefix: str = "",
model_class: Optional[type[nn.Module]] = None,
) -> nn.Module:
"""Initialize a model with the given configurations."""
model_config = vllm_config.model_config
model_class, _ = get_model_architecture(model_config)
if model_class is None:
model_class, _ = get_model_architecture(model_config)
if vllm_config.quant_config is not None:
configure_quant_config(vllm_config.quant_config, model_class)
......@@ -158,6 +164,11 @@ def _initialize_model(
def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
target_device: torch.device) -> None:
for _, module in model.named_modules():
if isinstance(module, QKVCrossParallelLinear):
# NOTE(Isotr0py): special case for cross QKV layer because
# q and kv proj aren't registered as submodules intentionally
module.process_weights_after_loading()
continue
quant_method = getattr(module, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase):
# When quant methods need to process weights after loading
......@@ -403,7 +414,7 @@ class DefaultModelLoader(BaseModelLoader):
return ((source.prefix + name, tensor)
for (name, tensor) in weights_iterator)
def _get_all_weights(
def get_all_weights(
self,
model_config: ModelConfig,
model: nn.Module,
......@@ -442,7 +453,7 @@ class DefaultModelLoader(BaseModelLoader):
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(
self._get_all_weights(model_config, model))
self.get_all_weights(model_config, model))
self.counter_after_loading_weights = time.perf_counter()
logger.info(
"Loading weights took %.2f seconds",
......
......@@ -174,8 +174,39 @@ def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
def _get_neuron_config_after_override(default_neuron_config,
overridden_neuron_config):
from transformers_neuronx.config import NeuronConfig
from transformers_neuronx.config import (ContinuousBatchingConfig,
GenerationConfig,
KVCacheQuantizationConfig,
NeuronConfig, QuantizationConfig,
SparseAttnConfig)
overridden_neuron_config = overridden_neuron_config or {}
sparse_attn = overridden_neuron_config.pop("sparse_attn", {})
if sparse_attn:
overridden_neuron_config["sparse_attn"] = SparseAttnConfig(
**sparse_attn)
kv_cache_quant = overridden_neuron_config.pop("kv_cache_quant", {})
if kv_cache_quant:
overridden_neuron_config["kv_cache_quant"] = KVCacheQuantizationConfig(
**kv_cache_quant)
continuous_batching = overridden_neuron_config.pop("continuous_batching",
{})
if continuous_batching:
overridden_neuron_config[
"continuous_batching"] = ContinuousBatchingConfig(
**continuous_batching)
quant = overridden_neuron_config.pop("quant", {})
if quant:
overridden_neuron_config["quant"] = QuantizationConfig(**quant)
on_device_generation = overridden_neuron_config.pop(
"on_device_generation", {})
if on_device_generation:
overridden_neuron_config["on_device_generation"] = GenerationConfig(
**on_device_generation)
default_neuron_config.update(overridden_neuron_config)
return NeuronConfig(**default_neuron_config)
......
......@@ -658,8 +658,21 @@ def initialize_dummy_weights(
for param in model.state_dict().values():
if torch.is_floating_point(param):
if current_platform.is_tpu():
# XLA device does not support torch.Generator()
param.uniform_(low, high)
generator = torch.Generator(device="cpu")
generator.manual_seed(seed)
# Note: The param.uniform_ function cannot be used in this
# context because it demands more TPU HBM than directly copying
# from a CPU tensor.
# Note: We avoid using torch.rank_like as it doesn't currently
# support the generator argument.
param.copy_((high - low) *
torch.rand(*param.shape,
generator=generator,
dtype=param.dtype,
layout=param.layout,
requires_grad=param.requires_grad,
device="cpu") + low)
torch._sync(param)
continue
generator = torch.Generator(device=param.data.device)
......
......@@ -21,12 +21,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
# yapf: disable
......@@ -408,13 +409,6 @@ class AriaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
hf_config = self.get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values())
......@@ -422,31 +416,31 @@ class AriaProcessingInfo(BaseProcessingInfo):
class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token: str = processor.tokenizer.image_token # type: ignore
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
vision_config = self.info.get_vision_config()
max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)
mm_data = {
return {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
num_images=num_images)
}
hf_processor = self.info.get_hf_processor()
image_token: str = hf_processor.tokenizer.image_token # type: ignore
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
......@@ -605,6 +599,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.multi_modal_projector(image_outputs, image_attn_mask)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment