Commit 410d3008 (unverified), authored by Matthias Gehre, committed by GitHub
Browse files

[ROCm][Refactor] Enable AWQMarlinConfig on ROCm to use choose_mp_linear_kernel (#36505)


Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
parent d3fe8571
...@@ -113,6 +113,8 @@ class ConchLinearKernel(MPLinearKernel): ...@@ -113,6 +113,8 @@ class ConchLinearKernel(MPLinearKernel):
self._transform_param(layer, self.w_s_name, transform_w_s) self._transform_param(layer, self.w_s_name, transform_w_s)
if self.config.zero_points: if self.config.zero_points:
self._transform_param(layer, self.w_zp_name, transform_w_zp) self._transform_param(layer, self.w_zp_name, transform_w_zp)
elif self.w_zp_name is not None:
layer.register_parameter(self.w_zp_name, None)
def apply_weights( def apply_weights(
self, self,
......
...@@ -10,6 +10,10 @@ from torch.nn import Parameter ...@@ -10,6 +10,10 @@ from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
MPLinearLayerConfig,
choose_mp_linear_kernel,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -34,21 +38,16 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -34,21 +38,16 @@ from vllm.model_executor.layers.quantization.base_config import (
) )
from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear,
awq_to_marlin_zero_points,
check_marlin_supported, check_marlin_supported,
check_marlin_supports_layer, check_marlin_supports_layer,
check_moe_marlin_supports_layer, check_moe_marlin_supports_layer,
get_marlin_input_dtype, get_marlin_input_dtype,
marlin_act_int8_process_scales, marlin_act_int8_process_scales,
marlin_make_empty_g_idx,
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_moe_permute_scales, marlin_moe_permute_scales,
marlin_permute_bias, marlin_permute_bias,
marlin_permute_scales,
moe_awq_to_marlin_zero_points, moe_awq_to_marlin_zero_points,
verify_marlin_supported, verify_marlin_supported,
verify_marlin_supports_shape,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
...@@ -63,6 +62,90 @@ if TYPE_CHECKING: ...@@ -63,6 +62,90 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
# AWQ uses a non-standard packing order within int32 values.
# For 4-bit, slot s of an int32 occupies bits [4*s, 4*s + 4).  Standard packing
# stores logical value i in slot i; AWQ stores logical values
# [0, 2, 4, 6, 1, 3, 5, 7] in slots [0, 1, 2, 3, 4, 5, 6, 7].
# Each table below is the inverse permutation: entry i is the slot holding
# logical value i, so ``slots[..., table]`` restores the natural order.
_REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]

# Inverse AWQ interleave keyed by weight bit-width.  The 8-bit interleave
# [0, 2, 1, 3] is its own inverse.
_REVERSE_AWQ_PACK_ORDER_BY_BITS = {
    4: _REVERSE_AWQ_PACK_ORDER,
    8: [0, 2, 1, 3],
}


def _awq_unpack_columns(
    packed: torch.Tensor,
    shifts: torch.Tensor,
    mask: int,
    reverse_order: torch.Tensor,
) -> torch.Tensor:
    """Unpack AWQ int32 columns to one value per element, in natural order."""
    rows, packed_cols = packed.shape
    # (rows, packed_cols, pack_factor): one entry per bit slot of every int32.
    slots = (packed.unsqueeze(-1) >> shifts) & mask
    # Undo the AWQ interleave, then flatten the slots into the column dim.
    return slots[:, :, reverse_order].reshape(rows, packed_cols * shifts.numel())


def _pack_rows_int32(values: torch.Tensor, shifts: torch.Tensor) -> torch.Tensor:
    """Pack each run of ``pack_factor`` consecutive rows into one int32 row."""
    pack_factor = shifts.numel()
    rows, cols = values.shape
    grouped = values.reshape(rows // pack_factor, pack_factor, cols)
    # The shifted fields occupy disjoint bits, so an int32 sum acts as an OR.
    return (grouped.to(torch.int32) << shifts[None, :, None]).sum(
        dim=1, dtype=torch.int32
    )


def _convert_awq_to_standard_format(
    layer: torch.nn.Module,
    w_q_name: str,
    w_zp_name: str,
    size_bits: int,
) -> None:
    """Convert AWQ weight and zero-point tensors to standard GPTQ-like format.

    AWQ packs qweight along the output dim with a non-standard bit order.
    This converts to standard bit order and repacks qweight along the input
    dim, matching the format expected by the MPLinearKernel framework.

    Layout change performed in place on ``layer``:

    * qweight: ``(K, N // pack)`` packed along dim 1, AWQ order
      -> ``(K // pack, N)`` packed along dim 0, standard order.
    * qzeros: ``(G, N // pack)`` packed along dim 1, AWQ order
      -> ``(N // pack, G)`` packed along dim 0, standard order.

    Args:
        layer: Module holding the two parameters to convert.
        w_q_name: Attribute name of the packed weight parameter.
        w_zp_name: Attribute name of the packed zero-point parameter.
        size_bits: Bit-width of one quantized value (4 or 8).

    Raises:
        ValueError: If ``size_bits`` has no known AWQ interleave, or the
            tensor shapes cannot be repacked.  Nothing on ``layer`` is
            modified in that case.
    """
    order = _REVERSE_AWQ_PACK_ORDER_BY_BITS.get(size_bits)
    if order is None:
        raise ValueError(
            f"Unsupported AWQ weight bit-width {size_bits}; expected one of "
            f"{sorted(_REVERSE_AWQ_PACK_ORDER_BY_BITS)}"
        )

    pack_factor = 32 // size_bits
    mask = (1 << size_bits) - 1

    qw = getattr(layer, w_q_name).data
    qz = getattr(layer, w_zp_name).data
    K, N_packed = qw.shape
    G, zp_packed = qz.shape

    # Validate both tensors up front so a bad input never leaves the layer
    # with a converted qweight next to an unconverted qzeros.
    if K % pack_factor != 0:
        raise ValueError(
            f"Input size {K} of '{w_q_name}' is not divisible by the pack "
            f"factor {pack_factor}"
        )
    if zp_packed != N_packed:
        raise ValueError(
            f"'{w_zp_name}' has {zp_packed} packed columns but '{w_q_name}' "
            f"has {N_packed}"
        )

    device = qw.device
    reverse_order = torch.tensor(order, dtype=torch.long, device=device)
    shifts = torch.arange(0, 32, size_bits, dtype=torch.int32, device=device)

    # --- qweight: (K, N // pack) -> (K, N) values -> (K // pack, N) packed.
    new_qw = _pack_rows_int32(
        _awq_unpack_columns(qw, shifts, mask, reverse_order), shifts
    )

    # --- qzeros: (G, N // pack) -> (G, N) values -> transpose to (N, G)
    # -> (N // pack, G) packed along the output dim.
    new_qz = _pack_rows_int32(
        _awq_unpack_columns(qz, shifts, mask, reverse_order).T, shifts
    )

    def _noop_loader(*args, **kwargs):
        # The tensors are already loaded; the new parameters only need a
        # placeholder loader.
        pass

    setattr(
        layer,
        w_q_name,
        PackedvLLMParameter(
            data=new_qw.contiguous(),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=pack_factor,
            weight_loader=_noop_loader,
        ),
    )
    setattr(
        layer,
        w_zp_name,
        PackedvLLMParameter(
            data=new_qz.contiguous(),
            output_dim=0,
            input_dim=1,
            packed_dim=0,
            packed_factor=pack_factor,
            weight_loader=_noop_loader,
        ),
    )
class AWQMarlinConfig(QuantizationConfig): class AWQMarlinConfig(QuantizationConfig):
"""Config class for AWQ Marlin""" """Config class for AWQ Marlin"""
...@@ -226,7 +309,7 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -226,7 +309,7 @@ class AWQMarlinConfig(QuantizationConfig):
group_size = quant_config.get("group_size") group_size = quant_config.get("group_size")
zero_point = quant_config.get("zero_point") zero_point = quant_config.get("zero_point")
if not current_platform.is_cuda(): if not (current_platform.is_cuda_alike() or current_platform.is_cpu()):
return False return False
if quant_method != "awq": if quant_method != "awq":
...@@ -268,15 +351,26 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -268,15 +351,26 @@ class AWQMarlinConfig(QuantizationConfig):
class AWQMarlinLinearMethod(LinearMethodBase): class AWQMarlinLinearMethod(LinearMethodBase):
"""Linear method for AWQ Marlin. """Linear method for AWQ Marlin.
Uses choose_mp_linear_kernel to select the best available kernel
(Conch, Exllama, or Marlin) for the current platform.
Args: Args:
quant_config: The AWQ Marlin quantization config. quant_config: The AWQ Marlin quantization config.
""" """
_kernel_backends_being_used: set[str] = set()
def __init__(self, quant_config: AWQMarlinConfig) -> None: def __init__(self, quant_config: AWQMarlinConfig) -> None:
self.quant_config = quant_config self.quant_config = quant_config
self.quant_type = scalar_types.uint4 self.quant_type = scalar_types.uint4
self.input_dtype = None self.input_dtype = None
verify_marlin_supported(
quant_type=self.quant_config.quant_type,
group_size=self.quant_config.group_size,
has_zp=self.quant_config.zero_point,
)
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -287,23 +381,35 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -287,23 +381,35 @@ class AWQMarlinLinearMethod(LinearMethodBase):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
) -> None: ) -> None:
del output_size
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
# Normalize group_size
if self.quant_config.group_size != -1: if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size group_size = self.quant_config.group_size
else: else:
group_size = input_size group_size = input_size
verify_marlin_supports_shape( mp_linear_kernel_config = MPLinearLayerConfig(
output_size_per_partition=output_size_per_partition, full_weight_shape=(input_size, output_size),
input_size_per_partition=input_size_per_partition, partition_weight_shape=(
input_size=input_size, input_size_per_partition,
group_size=group_size, output_size_per_partition,
),
weight_type=self.quant_config.quant_type,
act_type=params_dtype if self.input_dtype is None else self.input_dtype,
group_size=self.quant_config.group_size,
zero_points=self.quant_config.zero_point,
has_g_idx=False,
) )
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for AWQMarlinLinearMethod", kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# Weights are loaded in AWQ checkpoint format (packed along output dim).
# Conversion to GPTQ-like format happens in process_weights_after_loading.
qweight = PackedvLLMParameter( qweight = PackedvLLMParameter(
data=torch.empty( data=torch.empty(
input_size_per_partition, input_size_per_partition,
...@@ -318,7 +424,6 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -318,7 +424,6 @@ class AWQMarlinLinearMethod(LinearMethodBase):
) )
num_groups = input_size_per_partition // group_size num_groups = input_size_per_partition // group_size
layer.num_groups = num_groups
qzeros = PackedvLLMParameter( qzeros = PackedvLLMParameter(
data=torch.empty( data=torch.empty(
...@@ -348,73 +453,22 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -348,73 +453,22 @@ class AWQMarlinLinearMethod(LinearMethodBase):
layer.register_parameter("qzeros", qzeros) layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales) layer.register_parameter("scales", scales)
layer.input_size_per_partition = input_size_per_partition self.kernel = kernel_type(
layer.output_size_per_partition = output_size_per_partition mp_linear_kernel_config,
layer.num_groups = num_groups w_q_param_name="qweight",
w_s_param_name="scales",
# TODO: Update this docs w_zp_param_name="qzeros",
# Checkpoints are serialized in AutoAWQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device
layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
# Allocate marlin workspace
layer.workspace = marlin_make_workspace_new(device)
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if self.input_dtype == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(layer.qweight, layer.qzeros, inplace=True)
layer.scales.data = layer.scales.data * 512
# Repack weights from AWQ format to marlin format.
marlin_qweight = ops.awq_marlin_repack(
layer.qweight,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "qweight", marlin_qweight)
# Permute scales from AWQ format to marlin format. def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
marlin_scales = marlin_permute_scales( # AWQ checkpoints use a non-standard packing order and pack qweight
layer.scales, # along the output dimension. Convert to the standard format
size_k=layer.input_size_per_partition, # (GPTQ-like: standard bit order, qweight packed along input dim)
size_n=layer.output_size_per_partition, # before handing off to the kernel.
group_size=self.quant_config.group_size, _convert_awq_to_standard_format(
is_a_8bit=is_a_8bit, layer, "qweight", "qzeros", self.quant_config.quant_type.size_bits
)
if self.input_dtype == torch.int8 and layer.num_groups > 1:
marlin_scales, input_global_scale = marlin_act_int8_process_scales(
marlin_scales
)
layer.register_parameter(
"input_global_scale", Parameter(input_global_scale, requires_grad=False)
)
replace_parameter(layer, "scales", marlin_scales)
# Permute zero-points from AWQ format to marlin format.
marlin_zp = awq_to_marlin_zero_points(
layer.qzeros,
size_k=layer.num_groups,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "qzeros", marlin_zp) self.kernel.process_weights_after_loading(layer)
# Not-used
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
if hasattr(layer, "bias") and layer.bias is not None:
layer.bias.data = marlin_permute_bias(layer.bias)
def apply( def apply(
self, self,
...@@ -422,21 +476,7 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -422,21 +476,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return apply_awq_marlin_linear( return self.kernel.apply_weights(layer, x, bias)
input=x,
weight=layer.qweight,
weight_scale=layer.scales,
weight_zp=layer.qzeros,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
quant_type=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
input_global_scale=getattr(layer, "input_global_scale", None),
bias=bias,
input_dtype=self.input_dtype,
)
class AWQMarlinMoEMethod(FusedMoEMethodBase): class AWQMarlinMoEMethod(FusedMoEMethodBase):
......
...@@ -46,6 +46,7 @@ def query_marlin_supported_quant_types( ...@@ -46,6 +46,7 @@ def query_marlin_supported_quant_types(
if current_platform.is_cpu(): if current_platform.is_cpu():
return _query_cpu_marlin_supported_quant_types(has_zp, include_fp_type) return _query_cpu_marlin_supported_quant_types(has_zp, include_fp_type)
if not current_platform.is_rocm():
if device_capability is None: if device_capability is None:
capability_tuple = current_platform.get_device_capability() capability_tuple = current_platform.get_device_capability()
device_capability = ( device_capability = (
...@@ -210,8 +211,6 @@ def check_marlin_supports_shape( ...@@ -210,8 +211,6 @@ def check_marlin_supports_shape(
def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
if current_platform.is_rocm():
return False
output_size_per_partition = ( output_size_per_partition = (
getattr(layer, "output_size_per_partition", None) or layer.output_size getattr(layer, "output_size_per_partition", None) or layer.output_size
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment