Commit 31330101 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-dev

parents e8933c34 dc1b4a6f
......@@ -461,7 +461,7 @@ class FusedMoE(torch.nn.Module):
# Use expert parallelism instead of tensor parallelism?
vllm_config = get_current_vllm_config()
use_ep = (vllm_config.parallel_config.enable_expert_parallel
and self.tp_size > 1)
and self.tp_size * self.dp_size > 1)
# For smuggling this layer into the fused moe custom op
self.use_direct_call = self.dp_size == 1
......@@ -545,7 +545,9 @@ class FusedMoE(torch.nn.Module):
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
in ("GPTQMarlinMoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size
if (self.quant_method.__class__.__name__ in ("BlockInt8MoEMethod")):
......@@ -697,9 +699,10 @@ class FusedMoE(torch.nn.Module):
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
loaded_weight = loaded_weight.t().contiguous() if (
self.quant_method.__class__.__name__
== "CompressedTensorsWNA16MoEMethod") else loaded_weight
if self.quant_method.__class__.__name__ in (
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod"):
loaded_weight = loaded_weight.t().contiguous()
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
......
......@@ -1376,6 +1376,7 @@ class QKVCrossParallelLinear(LinearBase):
prefix=f"{prefix}.kv_proj_encoder")
# `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
self.q_size = self.q_proj_decoder.output_size_per_partition
self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
if bias:
......@@ -1387,20 +1388,31 @@ class QKVCrossParallelLinear(LinearBase):
else:
self.bias = None
def process_weights_after_loading(self):
    """Run the quant method's post-load hook on each entry of ``self.proj``.

    Entries are visited in ``self.proj.values()`` order; nothing is called
    for an entry while ``self.quant_method`` is ``None``.
    """
    for sub_layer in self.proj.values():
        if self.quant_method is None:
            continue
        self.quant_method.process_weights_after_loading(sub_layer)
@property
def q_proj_decoder(self) -> ColumnParallelLinear:
layer = self.proj["q_proj_decoder"]
for name, param in self.named_parameters():
target_param = getattr(layer, name)
self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
target_param = getattr(layer, name, None)
if target_param is not None:
self.sync_weight_attrs(param,
target_param,
mode="q_proj_decoder")
return layer
@property
def kv_proj_encoder(self) -> QKVParallelLinear:
layer = self.proj["kv_proj_encoder"]
for name, param in self.named_parameters():
target_param = getattr(layer, name)
self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
target_param = getattr(layer, name, None)
if target_param is not None:
self.sync_weight_attrs(param,
target_param,
mode="kv_proj_encoder")
return layer
def sync_weight_attrs(
......@@ -1489,11 +1501,14 @@ class QKVCrossParallelLinear(LinearBase):
if loaded_shard_id == "q" else self.kv_proj_encoder)
target_param = self.select_proj_params(layer, param)
shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
layer.weight_loader(target_param, loaded_weight, *shard_id_args)
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
else:
layer.weight_loader(target_param, loaded_weight, *shard_id_args)
def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", q_size={self.q_proj_decoder.output_size_per_partition}"
s += f", q_size={self.q_size}"
s += f", kv_size={self.kv_size}"
s += f", bias={self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}"
......
# SPDX-License-Identifier: Apache-2.0
import math
from dataclasses import dataclass
from typing import Optional

import torch

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.placeholder_attn import (
    PlaceholderAttentionMetadata)
from vllm.attention.backends.xformers import XFormersMetadata
@dataclass
class Mamba2Metadata:
    """Metadata bundle handed to the Mamba2 layers for one forward pass.

    Instances are produced by ``prepare_mamba2_metadata``. The tensor fields
    are optional: that builder leaves them as ``None`` when the attention
    metadata carries no context lengths or when the batch has no prefill.
    """

    # True when the attention metadata reports at least one prefill.
    has_prefill: bool
    # Boolean mask ``context_lens_tensor > 0``; None when it is unavailable.
    has_initial_states: Optional[torch.Tensor]
    # Python bool: whether any entry of ``has_initial_states`` is set.
    prep_initial_states: bool
    # Chunk size forwarded to the chunked scan.
    chunk_size: int
    # Sequence id of every token, with a leading batch dim of 1; None
    # without prefill.
    seq_idx: Optional[torch.Tensor]
    # Outputs of ``_seq_idx_to_chunk_indices_offsets``; None without prefill.
    chunk_indices: Optional[torch.Tensor]
    chunk_offsets: Optional[torch.Tensor]
def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
    """Convert per-token sequence ids into chunk indices and chunk offsets.

    ``seq_idx`` is indexed as a 2-D tensor whose last dimension runs over
    the tokens; a change between neighbouring entries marks the first token
    of a new sequence. Each sequence start that is not a multiple of
    ``chunk_size`` inserts one extra entry into the outputs.

    Args:
        seq_idx: integer tensor of sequence ids, 2-D with tokens last.
        chunk_size: number of tokens per chunk.

    Returns:
        ``(chunk_indices, chunk_offsets)``, two int32 tensors on
        ``seq_idx.device`` of equal length. ``chunk_indices`` starts as
        ``0..N-1`` and is shifted down by the number of insertions made so
        far; ``chunk_offsets`` holds ``start % chunk_size`` at the entry
        where each later sequence begins and zero elsewhere.
    """
    total_len = seq_idx.shape[-1]

    # Derive the sequence start positions (the "cu_seqlens") and bring them
    # to the host once; everything below is plain integer bookkeeping, so
    # the output length never has to round-trip through a 0-dim tensor.
    _, boundaries = torch.where(seq_idx.diff())
    starts = (boundaries + 1).tolist()

    # Outputs grow by one entry for every start that does not divide
    # chunk_size.
    num_entries = math.ceil(total_len / chunk_size) + sum(
        start % chunk_size > 0 for start in starts)
    chunk_indices = torch.arange(num_entries,
                                 dtype=torch.int,
                                 device=seq_idx.device)
    chunk_offsets = torch.zeros((num_entries, ),
                                dtype=torch.int,
                                device=seq_idx.device)

    inserted = 0  # number of insertions so far
    for start, end in zip(starts, starts[1:] + [total_len]):
        # A start that does not divide chunk_size causes one insertion.
        if start % chunk_size > 0:
            inserted += 1

        # Range of output entries covered by this sequence:
        # - the extra 1 on the upper bound shifts the boundary by one chunk
        # - that shift is not needed when chunk_size divides `end`
        lo = start // chunk_size + inserted
        hi = end // chunk_size + inserted + (end % chunk_size > 0)

        # Adjust indices and offsets.
        chunk_indices[lo:hi] -= inserted
        chunk_offsets[lo] = start % chunk_size

    return chunk_indices, chunk_offsets
def prepare_mamba2_metadata(
    chunk_size: int,
    input_ids: torch.Tensor,
    attn_metadata: AttentionMetadata,
) -> Mamba2Metadata:
    """Build the ``Mamba2Metadata`` for the current forward pass.

    Args:
        chunk_size: chunk size stored in the result and used to derive the
            chunk indices/offsets.
        input_ids: token tensor; only its shape and device are used, as the
            template for ``seq_idx``.
        attn_metadata: attention metadata of the current iteration.

    Returns:
        A ``Mamba2Metadata`` whose ``seq_idx``, ``chunk_indices`` and
        ``chunk_offsets`` are ``None`` when there is no prefill.
    """
    # Initial-state flags are only derived for these metadata classes, and
    # only when they actually carry context lengths.
    backends_with_context_lens = (FlashAttentionMetadata, XFormersMetadata,
                                  PlaceholderAttentionMetadata)
    has_initial_states = None
    prep_initial_states = False
    if (isinstance(attn_metadata, backends_with_context_lens)
            and attn_metadata.context_lens_tensor is not None):
        has_initial_states = attn_metadata.context_lens_tensor > 0
        # Resolve the "any initial state" flag once here so the mamba2
        # forwards do not need a device sync later.
        prep_initial_states = torch.any(has_initial_states).item()

    has_prefill = attn_metadata.num_prefills > 0

    seq_idx = chunk_indices = chunk_offsets = None
    if has_prefill:
        # Label every token with the position of its sequence in the batch.
        seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
        starts = attn_metadata.query_start_loc
        for seq_no, (begin, stop) in enumerate(zip(starts, starts[1:])):
            seq_idx[begin:stop] = seq_no
        seq_idx.unsqueeze_(0)

        # Chunked-prefill metadata. Strictly it is only required when there
        # are initial states, but that is only knowable from the attention
        # metadata, which the top-level model forward does not have at
        # hand. So it is computed once here and reused by every mamba
        # layer; the mamba kernels ignore it when it is not needed.
        chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
            seq_idx, chunk_size)

    return Mamba2Metadata(
        has_prefill=has_prefill,
        has_initial_states=has_initial_states,
        prep_initial_states=prep_initial_states,
        chunk_size=chunk_size,
        seq_idx=seq_idx,
        chunk_indices=chunk_indices,
        chunk_offsets=chunk_offsets,
    )
......@@ -6,10 +6,6 @@ import torch
from torch import nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata)
from vllm.attention.backends.xformers import XFormersMetadata
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
......@@ -18,6 +14,7 @@ from vllm.forward_context import get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
......@@ -221,7 +218,6 @@ class MambaMixer2(CustomOp):
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation="silu",
chunk_size: int = 256,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
......@@ -257,7 +253,6 @@ class MambaMixer2(CustomOp):
self.ssm_state_size = ssm_state_size
self.activation = activation
self.chunk_size = chunk_size
self.intermediate_size = intermediate_size
self.head_dim = head_dim
self.num_heads = num_heads
......@@ -388,25 +383,17 @@ class MambaMixer2(CustomOp):
self,
hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor] = None,
mamba2_metadata: Mamba2Metadata,
):
# mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# are the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
seq_len, _ = hidden_states.shape
groups_time_state_size = self.n_groups * self.ssm_state_size
# detect if there are prefills
has_prefill = attn_metadata.num_prefills > 0
# - also need flags to indicate if there are initial states
# - currently we really only support the FlashAttention backend
has_initial_states = None
if (isinstance(attn_metadata,
(FlashAttentionMetadata, XFormersMetadata,
PlaceholderAttentionMetadata))
and attn_metadata.context_lens_tensor is not None):
has_initial_states = attn_metadata.context_lens_tensor > 0
# 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states)
gate, hidden_states_B_C, dt = torch.split(
......@@ -423,7 +410,7 @@ class MambaMixer2(CustomOp):
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if has_prefill:
if mamba2_metadata.has_prefill:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
......@@ -439,7 +426,7 @@ class MambaMixer2(CustomOp):
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=has_initial_states,
has_initial_state=mamba2_metadata.has_initial_states,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc).transpose(
0, 1)[:seq_len]
......@@ -467,16 +454,15 @@ class MambaMixer2(CustomOp):
)
# 3. State Space Model sequence transformation
if has_prefill:
if mamba2_metadata.has_prefill:
initial_states = None
if has_initial_states is not None and torch.any(
has_initial_states):
zero_init_indices = mamba_cache_params.state_indices_tensor[
~has_initial_states]
mamba_cache_params.ssm_state[zero_init_indices] = 0
initial_states = mamba_cache_params.ssm_state[
mamba_cache_params.state_indices_tensor]
if (mamba2_metadata.has_initial_states is not None
and mamba2_metadata.prep_initial_states):
# making a copy of the states
initial_states = torch.where(
mamba2_metadata.has_initial_states[:, None, None, None],
mamba_cache_params.ssm_state[
mamba_cache_params.state_indices_tensor], 0)
scan_output, varlen_state = mamba_chunk_scan_combined(
hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
......@@ -485,11 +471,13 @@ class MambaMixer2(CustomOp):
self.A,
B.view(1, seq_len, self.n_groups // self.tp_size, -1),
C.view(1, seq_len, self.n_groups // self.tp_size, -1),
chunk_size=self.chunk_size,
chunk_size=mamba2_metadata.chunk_size,
D=self.D,
z=None,
dt_bias=self.dt_bias,
seq_idx=sequence_idx,
seq_idx=mamba2_metadata.seq_idx,
chunk_indices=mamba2_metadata.chunk_indices,
chunk_offsets=mamba2_metadata.chunk_offsets,
cu_seqlens=attn_metadata.query_start_loc,
initial_states=initial_states,
return_varlen_states=True,
......
......@@ -5,8 +5,6 @@
# ruff: noqa: E501,SIM102
import math
import torch
import triton
import triton.language as tl
......@@ -442,40 +440,6 @@ def _chunk_scan_fwd_kernel(
(offs_out_n[None, :] < hdim))
def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
    """Convert per-token sequence ids into chunk indices and chunk offsets.

    ``seq_idx`` is indexed as a 2-D tensor whose last dimension runs over
    the tokens; a change between neighbouring entries marks the first token
    of a new sequence. Each sequence start that is not a multiple of
    ``chunk_size`` inserts one extra entry into the outputs.

    Args:
        seq_idx: integer tensor of sequence ids, 2-D with tokens last.
        chunk_size: number of tokens per chunk.

    Returns:
        ``(chunk_indices, chunk_offsets)``, two int32 tensors on
        ``seq_idx.device`` of equal length. ``chunk_indices`` starts as
        ``0..N-1`` and is shifted down by the number of insertions made so
        far; ``chunk_offsets`` holds ``start % chunk_size`` at the entry
        where each later sequence begins and zero elsewhere.
    """
    total_len = seq_idx.shape[-1]

    # Derive the sequence start positions (the "cu_seqlens") and bring them
    # to the host once; everything below is plain integer bookkeeping, so
    # the output length never has to round-trip through a 0-dim tensor.
    _, boundaries = torch.where(seq_idx.diff())
    starts = (boundaries + 1).tolist()

    # Outputs grow by one entry for every start that does not divide
    # chunk_size.
    num_entries = math.ceil(total_len / chunk_size) + sum(
        start % chunk_size > 0 for start in starts)
    chunk_indices = torch.arange(num_entries,
                                 dtype=torch.int,
                                 device=seq_idx.device)
    chunk_offsets = torch.zeros((num_entries, ),
                                dtype=torch.int,
                                device=seq_idx.device)

    inserted = 0  # number of insertions so far
    for start, end in zip(starts, starts[1:] + [total_len]):
        # A start that does not divide chunk_size causes one insertion.
        if start % chunk_size > 0:
            inserted += 1

        # Range of output entries covered by this sequence:
        # - the extra 1 on the upper bound shifts the boundary by one chunk
        # - that shift is not needed when chunk_size divides `end`
        lo = start // chunk_size + inserted
        hi = end // chunk_size + inserted + (end % chunk_size > 0)

        # Adjust indices and offsets.
        chunk_indices[lo:hi] -= inserted
        chunk_offsets[lo] = start % chunk_size

    return chunk_indices, chunk_offsets
def _chunk_scan_fwd(
cb,
x,
......@@ -486,6 +450,8 @@ def _chunk_scan_fwd(
D=None,
z=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
initial_states=None,
):
batch, seqlen, nheads, headdim = x.shape
......@@ -502,7 +468,6 @@ def _chunk_scan_fwd(
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
assert states.shape == (batch, nchunks, nheads, headdim, dstate)
chunk_indices, chunk_offsets = None, None
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
......@@ -510,15 +475,19 @@ def _chunk_scan_fwd(
# with initial states, we need to take care of how
# seq_idx crosses the boundaries
assert batch == 1, "chunk scan only supports initial states with batch 1"
assert initial_states.shape == (seq_idx[0].max() + 1, nheads,
headdim, dstate)
if initial_states.shape[0] == 1:
# in this case there is no point in using initial states
initial_states = None
else:
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
seq_idx, chunk_size)
assert chunk_indices is not None and chunk_offsets is not None, \
(
"chunk_indices and chunk_offsets should have been set"
)
else:
chunk_indices, chunk_offsets = None, None
else:
chunk_indices, chunk_offsets = None, None
# Allocates output.
out = torch.empty(batch,
......
......@@ -30,6 +30,8 @@ def _mamba_chunk_scan_combined_fwd(x,
dt_bias=None,
initial_states=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf"))):
......@@ -96,7 +98,7 @@ def _mamba_chunk_scan_combined_fwd(x,
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
# - for handling chunked prefill, this requires i) initial_states
# ii) seq_idx and iii) has_cu_seqlens to be all specified.
# ii) seq_idx and iii) is_cont_batched to be all specified.
# - When a new seq_idx is detected, we will stop passing the prev_state
# and switch accordingly to the init_state corresponding to the new seq_idx.
# - this will ensure that states will be updated with the rightmost flushed seq_idx
......@@ -141,6 +143,8 @@ def _mamba_chunk_scan_combined_fwd(x,
D=D,
z=z,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
initial_states=initial_states,
)
if cu_seqlens is None:
......@@ -170,6 +174,8 @@ def mamba_chunk_scan_combined(x,
dt_bias=None,
initial_states=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
......@@ -210,6 +216,8 @@ def mamba_chunk_scan_combined(x,
dt_bias=dt_bias,
initial_states=initial_states,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
cu_seqlens=cu_seqlens,
dt_softplus=dt_softplus,
dt_limit=dt_limit)
......
......@@ -150,8 +150,6 @@ def _state_passing_fwd(
# are used for continuous batching. In which case we
# require seq_idx to be provided
assert seq_idx is not None, ""
assert initial_states.shape == (seq_idx.max().item() + 1, nheads,
dim)
else:
# - this is the regular batching case, where initial
# states are used are for each example of the batch.
......
......@@ -97,7 +97,7 @@ class SimplePooler(nn.Module):
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data)
pooled_data = self.head(pooled_data, pooling_metadata)
pooled_outputs = [self.build_output(data) for data in pooled_data]
return PoolerOutput(outputs=pooled_outputs)
......@@ -217,14 +217,28 @@ class PoolerHead(nn.Module):
self.normalize = normalize
self.softmax = softmax
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor]):
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
pooling_metadata: PoolingMetadata):
dimensions_list = [
pooling_param.dimensions
for _, pooling_param in pooling_metadata.seq_groups
]
if any(d is not None for d in dimensions_list):
# change the output dimension
assert len(pooled_data) == len(dimensions_list)
pooled_data = [
vecs if d is None else vecs[..., :d]
for vecs, d in zip(pooled_data, dimensions_list)
]
if self.normalize:
if isinstance(pooled_data, list):
pooled_data = [
F.normalize(data, p=2, dim=1) for data in pooled_data
F.normalize(data, p=2, dim=-1) for data in pooled_data
]
else:
pooled_data = F.normalize(pooled_data, p=2, dim=1)
pooled_data = F.normalize(pooled_data, p=2, dim=-1)
if self.softmax:
if isinstance(pooled_data, list):
......
......@@ -32,6 +32,7 @@ QUANTIZATION_METHODS: List[str] = [
"ipex",
"quark",
"moe_wna16",
"torchao",
"blockwise_int8"
]
......@@ -104,6 +105,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .neuron_quant import NeuronQuantConfig
from .ptpc_fp8 import PTPCFp8Config
from .qqq import QQQConfig
from .torchao import TorchAOConfig
from .tpu_int8 import Int8TpuConfig
from .blockwise_int8 import BlockInt8Config
......@@ -134,6 +136,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"ipex": IPEXConfig,
"quark": QuarkConfig,
"moe_wna16": MoeWNA16Config,
"torchao": TorchAOConfig,
"blockwise_int8": BlockInt8Config,
}
# Update the `method_to_config` with customized quantization methods.
......
......@@ -380,10 +380,7 @@ class BlockInt8MoEMethod:
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -417,8 +414,5 @@ class BlockInt8MoEMethod:
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
use_nn_moe=use_nn_moe,
moe_ep_size=moe_ep_size,
start_expert=start_expert,
end_expert=end_expert
use_nn_moe=use_nn_moe
)
\ No newline at end of file
......@@ -97,8 +97,7 @@ class CompressedTensorsConfig(QuantizationConfig):
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):
return CompressedTensorsMoEMethod.get_moe_method(
self, layer.activation, layer.expert_map)
return CompressedTensorsMoEMethod.get_moe_method(self, layer)
return None
@classmethod
......
......@@ -6,7 +6,8 @@ from typing import Callable, List, Optional
import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization import (ActivationOrdering,
QuantizationStrategy)
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
......@@ -30,9 +31,11 @@ class GPTQMarlinState(Enum):
__all__ = [
"CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsMoEMethod",
"CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsW8A8Fp8MoECutlassMethod",
"CompressedTensorsWNA16MoEMethod"
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod",
]
......@@ -41,8 +44,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
@staticmethod
def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
activation: str,
expert_map: Optional[torch.Tensor],
layer: torch.nn.Module,
) -> "CompressedTensorsMoEMethod":
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
......@@ -51,9 +53,21 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
"input_activations")
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
return CompressedTensorsWNA16MoEMethod(quant_config)
# Prefer to use the non-marlin kernel when:
# 1. Many experts (MarlinMoE gives poor performance when >= 16)
# 2. Non-FP16 dtype (MarlinMoE only supports FP16)
# 3. Actorder is not group/dynamic (g_idx is unsupported)
# 4. Scales are grouped (channelwise is unsupported)
if ((layer.local_num_experts >= 16
or layer.params_dtype != torch.float16) and
weight_quant.actorder not in (ActivationOrdering.GROUP,
ActivationOrdering.DYNAMIC)
and weight_quant.strategy in QuantizationStrategy.GROUP):
return CompressedTensorsWNA16MoEMethod(quant_config)
else:
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
and activation == "silu" and expert_map is None):
and layer.activation == "silu" and layer.expert_map is None):
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
......@@ -74,14 +88,23 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations")
if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy == QuantizationStrategy.TENSOR):
per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy
== QuantizationStrategy.TENSOR)
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN)
if not (per_tensor or per_channel):
raise ValueError(
"For FP8 Fused MoE layers, only per-tensor scales "
"for weights and activations are supported. Found "
"For FP8 Fused MoE layers, we require per tensor "
"or channelwise, dynamic per token quantization. Found "
f"{self.weight_quant}, {self.input_quant}")
self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales and per_channel:
raise ValueError(
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization.")
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
......@@ -109,24 +132,40 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
2,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
# Allocate 2 scales for w1 and w3 respectively.
# They are combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, 2, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-TENSOR quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
if self.static_input_scales:
......@@ -149,6 +188,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self.static_input_scales:
assert self.input_quant.strategy == QuantizationStrategy.TENSOR
if (layer.w13_input_scale is None or layer.w2_input_scale is None):
raise ValueError(
"QuantConfig has static quantization, but found "
......@@ -190,24 +230,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
requires_grad=False)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id])
layer.w13_weight[expert_id][
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
# For Per-TENSOR case, Fp8 moe kernel needs single weight scale
# for w13 per expert. Use max then dequant and requant each expert.
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start +
shard_size, :],
layer.w13_weight_scale[expert_id][shard_id])
layer.w13_weight[expert_id][
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
def apply(
self,
......@@ -251,6 +292,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=self.weight_quant.strategy ==
QuantizationStrategy.CHANNEL,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
......@@ -482,7 +525,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
)
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
......@@ -823,3 +866,215 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.num_bits,
is_k_full=self.is_k_full)
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__(
    self,
    quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
):
    """Read the weight-quantization scheme for the WNA16 MoE method.

    The "weights" entry of the "Linear" target scheme supplies the bit
    width, strategy, group size, activation ordering and symmetry.

    Raises:
        AssertionError: if the strategy is not "group", the activation
            ordering is "group", or the scheme is not symmetric.
        ValueError: if the checkpoint is not in the pack-quantized format
            or the bit width is not in ``WNA16_SUPPORTED_BITS``.
    """
    self.quant_config = quant_config
    # TODO: @dsikka: refactor this to use schemes as other kernels
    # are supported + check if the layer is being ignored.
    config = self.quant_config.target_scheme_map["Linear"].get("weights")
    self.num_bits = config.num_bits
    # Number of quantized values packed into one 32-bit word.
    self.packed_factor = 32 // config.num_bits
    self.strategy = config.strategy
    # channelwise is not supported by this kernel
    assert config.strategy == "group"
    self.group_size = config.group_size
    # grouped actorder isn't supported by this kernel
    assert config.actorder != "group"
    assert config.symmetric, (
        "Only symmetric quantization is supported for MoE")

    if not (self.quant_config.quant_format
            == CompressionFormat.pack_quantized.value
            and self.num_bits in WNA16_SUPPORTED_BITS):
        # Build ONE message string; passing the pieces as separate
        # arguments would make the exception print as a tuple.
        raise ValueError("For Fused MoE layers, only "
                         f"{CompressionFormat.pack_quantized.value} "
                         "is supported for the following bits: "
                         f"{WNA16_SUPPORTED_BITS}")
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Allocate and register the WNA16 MoE parameters on ``layer``.

    Registers, in this order: packed int32 weights for w13 and w2, their
    scales in ``params_dtype``, the two weight-shape tensors, and the int32
    g_idx / g_idx-sort-index tensors. Every parameter is frozen and tagged
    with ``extra_weight_attrs``. Finally ``layer.a13_scale`` and
    ``layer.a2_scale`` are set to ``None``.
    """
    # The loaded weight is transposed along the intermediate and hidden
    # dims, and TP sharding happens along the transposed dims.
    extra_weight_attrs.update({
        "is_transposed": True,
        "quant_method": self.strategy
    })

    def _add_param(name: str, data: torch.Tensor) -> torch.nn.Parameter:
        # Wrap as a frozen parameter, register it and attach the loader
        # attributes.
        param = torch.nn.Parameter(data, requires_grad=False)
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)
        return param

    inter_size = intermediate_size_per_partition
    pack = self.packed_factor

    # Packed weights: `pack` quantized values per int32 along dim 1.
    _add_param(
        "w13_weight_packed",
        torch.empty(num_experts,
                    hidden_size // pack,
                    2 * inter_size,
                    dtype=torch.int32))
    _add_param(
        "w2_weight_packed",
        torch.empty(num_experts,
                    inter_size // pack,
                    hidden_size,
                    dtype=torch.int32))

    # Number of scale groups along the (transposed) input dimension.
    if self.strategy == "channel":
        num_groups_w2 = num_groups_w13 = 1
        self.group_size = -1
    else:
        num_groups_w2 = inter_size // self.group_size
        num_groups_w13 = hidden_size // self.group_size

    _add_param(
        "w13_weight_scale",
        torch.ones(num_experts,
                   num_groups_w13,
                   2 * inter_size,
                   dtype=params_dtype))
    w2_scale = _add_param(
        "w2_weight_scale",
        torch.ones(num_experts,
                   num_groups_w2,
                   hidden_size,
                   dtype=params_dtype))
    set_weight_attrs(w2_scale, {"load_full_w2": False})

    # Two-element shape tensors per expert.
    for shape_name in ("w2_weight_shape", "w13_weight_shape"):
        _add_param(shape_name, torch.empty(num_experts, 2))

    # g_idx tensors and their sort indices, one row per expert.
    g_idx_specs = (
        ("w13_weight_g_idx", hidden_size),
        ("w2_weight_g_idx", inter_size),
        ("w13_g_idx_sort_indices", hidden_size),
        ("w2_g_idx_sort_indices", inter_size),
    )
    for g_idx_name, width in g_idx_specs:
        _add_param(g_idx_name,
                   torch.empty(num_experts, width, dtype=torch.int32))

    layer.a13_scale = None
    layer.a2_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Rearrange the loaded packed weights and scales for the moe_wna16 path.

    Every affected parameter on ``layer`` is replaced by a new
    non-trainable ``torch.nn.Parameter`` whose dims 1 and 2 are swapped
    and made contiguous. The packed weights are additionally
    reinterpreted as ``torch.uint8``.

    Args:
        layer: Module holding ``w13_weight_packed``, ``w2_weight_packed``,
            ``w13_weight_scale`` and ``w2_weight_scale``.
    """
    # Packed weights: swap dims 1/2, then view the same bytes as uint8.
    for packed_name in ("w13_weight_packed", "w2_weight_packed"):
        swapped = getattr(layer, packed_name).transpose(1, 2).contiguous()
        setattr(
            layer, packed_name,
            torch.nn.Parameter(swapped.view(torch.uint8),
                               requires_grad=False))

    # Scales: only the dim swap is needed, the dtype is kept.
    for scale_name in ("w13_weight_scale", "w2_weight_scale"):
        swapped = getattr(layer, scale_name).transpose(1, 2).contiguous()
        setattr(layer, scale_name,
                torch.nn.Parameter(swapped, requires_grad=False))
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
) -> torch.Tensor:
    """Route tokens to experts and run the fused WNA16 expert computation.

    Expert selection is delegated to ``FusedMoE.select_experts``; the
    resulting top-k weights and ids are passed to ``fused_experts``
    together with the packed weights and scales stored on ``layer``.

    Args:
        layer: Module holding ``w13_weight_packed``, ``w2_weight_packed``,
            ``w13_weight_scale`` and ``w2_weight_scale``.
        x: Hidden states handed to routing and to the fused kernel.
        router_logits: Router output used for expert selection.
        top_k: Number of experts selected per token.
        renormalize: Forwarded to expert selection.
        use_grouped_topk, topk_group, num_expert_group,
        custom_routing_function, scoring_func, e_score_correction_bias:
            Forwarded unchanged to ``FusedMoE.select_experts``.
        global_num_experts, expert_map, apply_router_weight_on_input:
            Forwarded unchanged to ``fused_experts``.
        activation: Must be ``"silu"``.

    Returns:
        The value returned by ``fused_experts``.

    Raises:
        AssertionError: If ``activation`` is not ``"silu"``.
    """
    # Imported lazily, exactly where it is needed.
    from vllm.model_executor.layers.fused_moe import fused_experts

    assert activation == "silu", "Only SiLU activation is supported."

    routing_kwargs = dict(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
    )
    topk_weights, topk_ids = FusedMoE.select_experts(**routing_kwargs)

    # The configured bit width selects which weight-only mode is enabled.
    is_int4 = self.num_bits == 4
    is_int8 = self.num_bits == 8

    return fused_experts(
        x,
        layer.w13_weight_packed,
        layer.w2_weight_packed,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=True,
        use_int4_w4a16=is_int4,
        use_int8_w8a16=is_int8,
        global_num_experts=global_num_experts,
        apply_router_weight_on_input=apply_router_weight_on_input,
        expert_map=expert_map,
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        w1_zp=None,
        w2_zp=None,
        block_shape=[0, self.group_size],
    )
......@@ -116,7 +116,9 @@ class Fp8Config(QuantizationConfig):
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
if is_layer_skipped(prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
......@@ -252,6 +254,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "weight_scale"})
layer.register_parameter("weight_scale", scale)
else:
assert self.quant_config.activation_scheme == "dynamic"
......@@ -266,6 +269,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "weight_scale"})
# The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale)
......@@ -276,6 +280,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "input_scale"})
layer.register_parameter("input_scale", scale)
else:
layer.register_parameter("input_scale", None)
......
......@@ -29,7 +29,7 @@ def choose_scaled_mm_linear_kernel(
compute_capability: Optional[int] = None
) -> Type[ScaledMMLinearKernel]:
"""
Choose an ScalledMMLinearKernel that can implement the given config for the
Choose an ScaledMMLinearKernel that can implement the given config for the
given compute capability. Attempts to choose the best kernel in terms of
performance.
......
......@@ -296,9 +296,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = None,
start_expert: Optional[int] = None,
end_expert: Optional[int] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported."
......
......@@ -21,7 +21,7 @@ class QuarkW8A8Fp8(QuarkScheme):
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
self.qscheme = qscheme
self.is_static_input_scheme = is_static_input_scheme
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
self.out_dtype = torch.get_default_dtype()
@classmethod
......@@ -41,10 +41,11 @@ class QuarkW8A8Fp8(QuarkScheme):
)
if current_platform.is_fp8_fnuz():
input_scale = getattr(layer, 'input_scale', None)
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=max_w_scale,
input_scale=layer.input_scale)
input_scale=input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
......@@ -57,11 +58,12 @@ class QuarkW8A8Fp8(QuarkScheme):
weight = layer.weight
if current_platform.is_fp8_fnuz():
input_scale = getattr(layer, 'input_scale', None)
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale)
input_scale=input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
......@@ -105,7 +107,7 @@ class QuarkW8A8Fp8(QuarkScheme):
# the newly added parameters
if self.qscheme == "per_channel":
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
data=torch.empty((sum(output_partition_sizes)),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
......
......@@ -35,7 +35,7 @@ class QuarkW8A8Int8(QuarkScheme):
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
self.logical_widths = output_partition_sizes
layer.logical_widths = output_partition_sizes
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
is_channelwise=(self.qscheme == "per_channel"),
......@@ -63,16 +63,28 @@ class QuarkW8A8Int8(QuarkScheme):
# WEIGHT SCALE
if self.qscheme == "per_channel":
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
data=torch.empty((sum(output_partition_sizes)),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
ChannelQuantZPParameter = ChannelQuantScaleParameter
weight_zero_point = ChannelQuantZPParameter(
data=torch.empty((sum(output_partition_sizes)),
dtype=torch.int8),
output_dim=0,
weight_loader=weight_loader)
else:
assert self.qscheme == "per_tensor"
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
PerTensorZPParameter = PerTensorScaleParameter
weight_zero_point = PerTensorZPParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.int8),
weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_zero_point", weight_zero_point)
# INPUT SCALE
if self.is_static_input_scheme:
......@@ -81,14 +93,10 @@ class QuarkW8A8Int8(QuarkScheme):
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
if not self.input_symmetric:
# Note: quark stores the zp using the same dtype
# as the weights
# AZP loaded as int8 but used as int32
input_zero_point = BasevLLMParameter(
data=torch.empty(1, dtype=torch.int8),
weight_loader=weight_loader)
layer.register_parameter("input_zero_point", input_zero_point)
input_zero_point = BasevLLMParameter(data=torch.empty(
1, dtype=torch.int8),
weight_loader=weight_loader)
layer.register_parameter("input_zero_point", input_zero_point)
self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
w_q_param_name="weight",
......@@ -100,6 +108,12 @@ class QuarkW8A8Int8(QuarkScheme):
# Checkpoints are serialized in quark format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.register_parameter("weight_zero_point", None)
delattr(layer, 'weight_zero_point')
if self.input_symmetric:
layer.register_parameter("input_zero_point", None)
delattr(layer, 'input_zero_point')
self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class TorchAOConfig(QuantizationConfig):
    """Quantization config that wraps a torchao config object."""

    def __init__(self, torchao_config) -> None:
        """Store the torchao config.

        Args:
            torchao_config: The torchao config object; it is later handed
                to ``torchao_quantize_param_data`` by the linear method.
        """
        # Run the base-class initializer so any state it sets up exists
        # on this instance as well.
        super().__init__()
        self.torchao_config = torchao_config

    def __repr__(self) -> str:
        return f"TorchAOConfig({self.torchao_config})"

    def get_name(self) -> str:
        """Return the name of this quantization method."""
        return "torchao"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """Return the activation dtypes accepted by this config."""
        return [torch.float32, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required (75)."""
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return the file names searched for the quantization config."""
        return ["config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "TorchAOConfig":
        """Create the quant config from an hf model config.

        Args:
            config: Dictionary expected to contain a ``"quant_type"`` entry
                that is itself a dictionary with the single key
                ``"default"``.

        Returns:
            A ``TorchAOConfig`` wrapping the object produced by torchao's
            ``config_from_dict``.

        Raises:
            ImportError: If ``torchao.core.config`` cannot be imported.
            AssertionError: If ``"quant_type"`` is missing or does not
                consist of exactly the key ``"default"``.
        """
        try:
            from torchao.core.config import config_from_dict
        except ImportError as err:
            # The requirement is quoted so the suggested command can be
            # pasted into a shell without `>` acting as a redirection.
            raise ImportError(
                "Please install torchao>=0.10.0 via "
                "`pip install \"torchao>=0.10.0\"` to use torchao "
                "quantization.") from err

        hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
        assert hf_config is not None, "quant_type must be specified"
        assert (len(hf_config) == 1 and "default" in hf_config
                ), "Expected only one key 'default' in quant_type dictionary"
        quant_type = hf_config["default"]
        ao_config = config_from_dict(quant_type)
        return cls(ao_config)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["TorchAOLinearMethod"]:
        """Return a ``TorchAOLinearMethod`` for ``LinearBase`` layers.

        Any other layer type yields ``None``. ``prefix`` is not used.
        """
        if isinstance(layer, LinearBase):
            return TorchAOLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return an empty list."""
        return []
def torchao_quantize_param_data(param: torch.Tensor,
                                torchao_config: Any) -> torch.nn.Parameter:
    """Quantize a Tensor with torchao quantization specified by torchao_config

    The tensor is attached as the weight of a temporary ``torch.nn.Linear``
    so that torchao's module-level ``quantize_`` can be applied to it.

    Args:
        `param`: weight parameter of the linear module; indexed as
                 ``(out_features, in_features)`` when building the
                 temporary module
        `torchao_config`: type of quantization and their arguments we want to
                          use to quantize the Tensor

    Returns:
        The weight of the temporary linear module after ``quantize_`` ran.

    Raises:
        AssertionError: If ``torchao_config`` is not an ``AOBaseConfig``.
    """
    from torchao.core.config import AOBaseConfig
    from torchao.quantization import quantize_

    assert isinstance(torchao_config, AOBaseConfig)

    # `Linear.weight` only accepts Parameters; wrap plain tensors so the
    # function works for every input its annotation allows.
    if not isinstance(param, torch.nn.Parameter):
        param = torch.nn.Parameter(param, requires_grad=False)

    # Build the placeholder on the meta device: its initial weight is
    # replaced right away, so allocating and initialising real storage for
    # it would be wasted work.
    with torch.device("meta"):
        dummy_linear = torch.nn.Linear(param.shape[1],
                                       param.shape[0],
                                       bias=False)
    dummy_linear.weight = param
    quantize_(dummy_linear, torchao_config)
    return dummy_linear.weight
class TorchAOLinearMethod(LinearMethodBase):
    """Linear method for torchao.

    Args:
        quant_config: The ``TorchAOConfig`` whose ``torchao_config`` is
            applied to the weight created in ``create_weights``.
    """

    def __init__(self, quant_config: TorchAOConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create, quantize and register the ``weight`` parameter on ``layer``.

        The weight has ``sum(output_partition_sizes)`` rows and
        ``input_size_per_partition`` columns. It is passed through
        ``torchao_quantize_param_data`` before being registered.
        ``input_size`` and ``output_size`` are not used.
        """
        out_rows = sum(output_partition_sizes)
        storage = torch.empty(out_rows,
                              input_size_per_partition,
                              dtype=params_dtype)
        plain_weight = Parameter(storage, requires_grad=False)

        weight = torchao_quantize_param_data(
            plain_weight, self.quant_config.torchao_config)

        # Dimension metadata first, then registration, then caller attrs.
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return ``F.linear(x, layer.weight, bias)``."""
        return F.linear(x, layer.weight, bias)
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/sgl-project/sglang/blob/4cb53ecd0cffceb6dee5c011a58f65997a86f151/python/sglang/srt/layers/quantization/int8_kernel.py
import functools
import json
import logging
......@@ -9,9 +10,8 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm.utils import W8a8GetCacheJSON
# from sglang.srt.utils import get_device_name
from vllm.utils import W8a8GetCacheJSON
from vllm.platforms import current_platform
logger = logging.getLogger(__name__)
......@@ -19,6 +19,75 @@ logger = logging.getLogger(__name__)
W8A8_TRITONJSON=W8a8GetCacheJSON()
def apply_w8a8_block_int8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Apply a linear layer using block-wise int8 weights.

    The input is quantized per token group (group size ``block_size[1]``)
    with ``per_token_group_quant_int8`` and multiplied with ``weight`` by
    ``w8a8_block_int8_matmul``.

    Args:
        input: Activations; the last dimension is the contraction dim.
        weight: Quantized weight; ``weight.shape[0]`` is the output width.
        block_size: Two-element block shape passed to the matmul.
        weight_scale: Scales for ``weight``.
        input_scale: Must be ``None``.
        bias: Optional bias added to the matmul result.

    Returns:
        Tensor with ``input``'s leading dimensions, ``weight.shape[0]`` as
        the last dimension, and ``input``'s dtype.
    """
    assert input_scale is None

    result_shape = [*input.shape[:-1], weight.shape[0]]
    # Collapse all leading dims so the matmul sees a 2D matrix.
    flat_input = input.view(-1, input.shape[-1])

    group_size = block_size[1]
    quantized, act_scale = per_token_group_quant_int8(flat_input, group_size)
    result = w8a8_block_int8_matmul(quantized,
                                    weight,
                                    act_scale,
                                    weight_scale,
                                    block_size,
                                    output_dtype=input.dtype)
    if bias is not None:
        result = result + bias
    return result.to(dtype=input.dtype).view(*result_shape)
def input_to_int8(
        x: torch.Tensor,
        dtype: torch.dtype = torch.int8) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to integers using a single tensor-wide scale.

    The scale maps the largest absolute value of ``x`` (floored at 1e-12)
    onto the maximum of ``dtype``. Scaled values are clamped to the range
    of ``dtype`` and then cast to it.

    Returns:
        A tuple of the contiguous quantized tensor and the float32
        reciprocal of the scale that was applied.
    """
    limits = torch.iinfo(dtype)
    lowest, highest = x.aminmax()
    # Floor the magnitude so an all-zero input does not divide by zero.
    abs_peak = torch.maximum(lowest.abs(), highest.abs()).clamp(min=1e-12)
    quant_scale = limits.max / abs_peak

    saturated = (x * quant_scale).clamp(min=limits.min, max=limits.max)
    quantized = saturated.to(dtype).contiguous()
    dequant_scale = quant_scale.float().reciprocal()
    return quantized, dequant_scale
def block_dequant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
) -> torch.Tensor:
    """This function conducts block-wise dequantization.

    Every ``block_size[0] x block_size[1]`` tile of ``x_q_block`` is
    multiplied by the matching entry of ``x_s``. Tiles on the bottom and
    right edges may be smaller than a full block.

    Args:
        x_q_block: 2D block-wise quantized tensor of shape ``(n, k)``.
        x_s: Per-tile scales of shape
            ``(ceil(n / block_n), ceil(k / block_k))``.
        block_size: ``[block_n, block_k]``.

    Returns:
        A new float32 tensor of shape ``(n, k)``. The input is never
        modified, even when it is already float32.

    Raises:
        AssertionError: If the shape of ``x_s`` does not match the number
            of tiles.
    """
    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    assert n_tiles == x_s.shape[0]
    assert k_tiles == x_s.shape[1]

    # Expand every per-tile scale over its tile in one shot instead of
    # looping over tiles in Python, then trim the ragged last row/column
    # of tiles back to the real (n, k) extent.
    scales = x_s.to(device=x_q_block.device, dtype=torch.float32)
    scales = scales.repeat_interleave(block_n, dim=0)[:n]
    scales = scales.repeat_interleave(block_k, dim=1)[:, :k]

    # Out-of-place multiply: `.to(float32)` may alias a float32 input, and
    # an in-place update would then overwrite the caller's tensor.
    return x_q_block.to(torch.float32) * scales
@triton.jit
def _per_token_quant_int8(
x_ptr,
......@@ -29,13 +98,14 @@ def _per_token_quant_int8(
N,
BLOCK: tl.constexpr,
):
# Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
row_id = tl.program_id(0)
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + row_id * stride_x + cols,
mask=mask, other=0.0).to(tl.float32)
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
......@@ -49,14 +119,15 @@ def per_token_quant_int8(x):
M = x.numel() // x.shape[-1]
N = x.shape[-1]
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty(x.shape[:-1] + (1,),
device=x.device, dtype=torch.float32)
scales = torch.empty(x.shape[:-1] + (1, ),
device=x.device,
dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
assert x.is_contiguous()
_per_token_quant_int8[(M,)](
_per_token_quant_int8[(M, )](
x,
x_q,
scales,
......@@ -79,7 +150,7 @@ def _per_token_group_quant_int8(
y_s_ptr,
# Stride of input
y_stride,
# Collums of input
# Columns of input
N,
# Avoid to divide zero
eps,
......@@ -89,8 +160,9 @@ def _per_token_group_quant_int8(
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform
per-token-group quantization on a tensor.
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into int8 values.
"""
# Map the program id to the row of X and Y it should compute.
......@@ -119,21 +191,23 @@ def per_token_group_quant_int8(
dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed int8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dtype of output tensor. Note that only `torch.int8`
is supported for now.
is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
scaling factor for quantization.
"""
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert (x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
iinfo = torch.iinfo(dtype)
......@@ -144,7 +218,7 @@ def per_token_group_quant_int8(
M = x.numel() // group_size
N = group_size
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),
x.shape[:-1] + (x.shape[-1] // group_size, ),
device=x.device,
dtype=torch.float32,
)
......@@ -153,7 +227,7 @@ def per_token_group_quant_int8(
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
_per_token_group_quant_int8[(M,)](
_per_token_group_quant_int8[(M, )](
x,
x_q,
x_s,
......@@ -203,8 +277,8 @@ def _w8a8_block_int8_matmul(
GROUP_SIZE_M: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization,
and store the result in output tensor `C`.
product) on input tensors `A` and `B` with block-wise quantization, and
store the result in output tensor `C`.
"""
pid = tl.program_id(axis=0)
......@@ -241,7 +315,8 @@ def _w8a8_block_int8_matmul(
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
accumulator += tl.dot(a, b).to(tl.float32) * a_s[:,
None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
......@@ -260,9 +335,8 @@ def _w8a8_block_int8_matmul(
@functools.lru_cache
def get_w8a8_block_int8_configs(
N: int, K: int, block_n: int, block_k: int
) -> Optional[Dict[int, Any]]:
def get_w8a8_block_int8_configs(N: int, K: int, block_n: int,
block_k: int) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the w8a8 block int8 kernel.
The return value will be a dictionary that maps an irregular grid of
......@@ -274,11 +348,10 @@ def get_w8a8_block_int8_configs(
# First look up if an optimized configuration is available in the configs
# directory
device_name = current_platform.get_device_name().replace(" ", "_")
json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n}, {block_k}].json" # noqa: E501
json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n}, {block_k}].json" # noqa: E501
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info(
......@@ -291,10 +364,8 @@ def get_w8a8_block_int8_configs(
# If no optimized configuration is available, we will use the default
# configuration
logger.warning(
(
"Using default W8A8 Block INT8 kernel config. Performance might "
"be sub-optimal! Config file not found at %s"
),
("Using default W8A8 Block INT8 kernel config. Performance might "
"be sub-optimal! Config file not found at %s"),
config_file_path,
)
return None
......@@ -308,17 +379,21 @@ def w8a8_block_int8_matmul(
block_size: List[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""matrix multiplication with block-wise quantization.
"""This function performs matrix multiplication with block-wise
quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should
be 2-dim, e.g., [128, 128].
block_size: The block size for per-block quantization. It should be
2-dim, e.g., [128, 128].
output_dtype: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
......@@ -334,27 +409,26 @@ def w8a8_block_int8_matmul(
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N,)
C_shape = A.shape[:-1] + (N, )
C = A.new_empty(C_shape, dtype=output_dtype)
# configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
# if configs:
# # If an optimal configuration map has been found, look up the
# # optimal config
# # If an optimal configuration map has been found, look up the
# # optimal config
# config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
# else:
# #Default config
# #Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
# #print("block_size[0]:{},block_size[1]:{}".format(block_size[0],block_size[1]))
# # Default config
# # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
# config = {
# "BLOCK_SIZE_M": 32, #64
# "BLOCK_SIZE_N": block_size[0],
# "BLOCK_SIZE_K": block_size[1],
# "GROUP_SIZE_M": 32,
# "num_warps": 4,
# "num_stages": 3,
# }
# "BLOCK_SIZE_M": 64,
# "BLOCK_SIZE_N": block_size[0],
# "BLOCK_SIZE_K": block_size[1],
# "GROUP_SIZE_M": 32,
# "num_warps": 4,
# "num_stages": 3,
# }
#print("W8A8_TRITONJSON.triton_json_dict[0]:",W8A8_TRITONJSON.triton_json_dict[0])
if len(W8A8_TRITONJSON.triton_json_dict)==0:
......@@ -429,12 +503,9 @@ def w8a8_block_int8_matmul(
"num_stages": 0,
}
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
_w8a8_block_int8_matmul[grid](
A,
......@@ -462,6 +533,7 @@ def w8a8_block_int8_matmul(
return C
def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
"""matrix multiplication with block-wise quantization using native torch.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
......@@ -585,4 +657,4 @@ def block_dequant(
i * block_k : min((i + 1) * block_k, k),
] *= x_s[j][i]
return x_dq_block
\ No newline at end of file
return x_dq_block
......@@ -305,7 +305,7 @@ def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
# the performance of atomicAdd is better than global reduce
# only when m*n is small and k is large
return max(m, 64) * n < 64 * 2048 and k >= 2048
return n < 2048 and k >= 2048
def apply_gptq_marlin_linear(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment