Unverified Commit 89342ce4 authored by Alex Kogan's avatar Alex Kogan Committed by GitHub
Browse files

[Quantization] [Performance] Enable Marlin GEMM kernels for the...


[Quantization] [Performance] Enable Marlin GEMM kernels for the calibration-free RTN-based quantization (#26051)
Signed-off-by: default avatarAlex Kogan <alex.kogan@oracle.com>
Signed-off-by: default avatarAlex Kogan <82225080+sakogan@users.noreply.github.com>
parent f89f5993
...@@ -6,21 +6,16 @@ import os ...@@ -6,21 +6,16 @@ import os
from collections.abc import Callable from collections.abc import Callable
from typing import Any, Optional from typing import Any, Optional
import numpy as np
import torch import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
LinearBase, LinearBase,
LinearMethodBase, LinearMethodBase,
...@@ -31,6 +26,12 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -31,6 +26,12 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_rtn_marlin_linear,
marlin_make_workspace_new,
)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__) logger = init_logger(__name__)
"""By default, use 8 bit as target precision, but it can be """By default, use 8 bit as target precision, but it can be
...@@ -41,6 +42,9 @@ NUM_BITS = os.getenv("RTN_NUM_BITS", "8") ...@@ -41,6 +42,9 @@ NUM_BITS = os.getenv("RTN_NUM_BITS", "8")
overridden by setting the RTN_GROUP_SIZE envvar overridden by setting the RTN_GROUP_SIZE envvar
""" """
GROUP_SIZE = os.getenv("RTN_GROUP_SIZE", "128") GROUP_SIZE = os.getenv("RTN_GROUP_SIZE", "128")
"""Global Marlin workspace shared by all modules
"""
workspace = None
class RTNConfig(QuantizationConfig): class RTNConfig(QuantizationConfig):
...@@ -60,6 +64,10 @@ class RTNConfig(QuantizationConfig): ...@@ -60,6 +64,10 @@ class RTNConfig(QuantizationConfig):
f"supported for RTN, but got {self.weight_bits} bits." f"supported for RTN, but got {self.weight_bits} bits."
) )
self.quant_type = (
scalar_types.uint8b128 if self.weight_bits == 8 else scalar_types.uint4b8
)
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"RTNConfig(weight_bits={self.weight_bits}, group_size={self.group_size})" f"RTNConfig(weight_bits={self.weight_bits}, group_size={self.group_size})"
...@@ -221,7 +229,15 @@ class RTNLinearMethod(LinearMethodBase): ...@@ -221,7 +229,15 @@ class RTNLinearMethod(LinearMethodBase):
layer.output_size_per_partition = output_size_per_partition layer.output_size_per_partition = output_size_per_partition
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
fix_weights(layer, "weight") """Repack weights and scales for Marlin kernels."""
weight_bits = self.quant_config.weight_bits
weight, scale = repack_weights(layer.weight, layer.scale, weight_bits)
replace_parameter(layer, "weight", weight)
replace_parameter(layer, "scale", scale)
init_workspace(layer.weight.device)
def apply( def apply(
self, self,
...@@ -229,16 +245,16 @@ class RTNLinearMethod(LinearMethodBase): ...@@ -229,16 +245,16 @@ class RTNLinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
qweight = layer.weight return apply_rtn_marlin_linear(
scale = layer.scale input=x,
weight=layer.weight,
weight = rtn_dequantize(qweight, scale) weight_scale=layer.scale,
out = F.linear(x, weight) workspace=workspace,
del weight quant_type=self.quant_config.quant_type,
if bias is not None: output_size_per_partition=layer.output_size_per_partition,
out.add_(bias) input_size_per_partition=layer.input_size_per_partition,
bias=bias,
return out )
class RTNMoEMethod(FusedMoEMethodBase): class RTNMoEMethod(FusedMoEMethodBase):
...@@ -315,28 +331,27 @@ class RTNMoEMethod(FusedMoEMethodBase): ...@@ -315,28 +331,27 @@ class RTNMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Repack weights and scales for Marlin kernels."""
weight_bits = self.quant_config.weight_bits weight_bits = self.quant_config.weight_bits
fix_weights(layer, "w13_weight", weight_bits == 4)
fix_weights(layer, "w2_weight", weight_bits == 4) w13_weight, w13_scale = repack_weights(
layer.w13_weight, layer.w13_scale, weight_bits
)
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w13_scale", w13_scale)
w2_weight, w2_scale = repack_weights(
layer.w2_weight, layer.w2_scale, weight_bits
)
replace_parameter(layer, "w2_weight", w2_weight)
replace_parameter(layer, "w2_scale", w2_scale)
init_workspace(layer.w13_weight.device)
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig | None:
weight_bits = self.quant_config.weight_bits return None
group_size = self.quant_config.group_size
assert weight_bits == 4 or weight_bits == 8
config_builder = (
int4_w4a16_moe_quant_config
if weight_bits == 4
else int8_w8a16_moe_quant_config
)
return config_builder(
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
w1_zp=None,
w2_zp=None,
block_shape=[0, group_size],
)
def apply( def apply(
self, self,
...@@ -366,8 +381,6 @@ class RTNMoEMethod(FusedMoEMethodBase): ...@@ -366,8 +381,6 @@ class RTNMoEMethod(FusedMoEMethodBase):
if enable_eplb: if enable_eplb:
raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.") raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids, _ = FusedMoE.select_experts( topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
...@@ -383,18 +396,22 @@ class RTNMoEMethod(FusedMoEMethodBase): ...@@ -383,18 +396,22 @@ class RTNMoEMethod(FusedMoEMethodBase):
indices_type=self.topk_indices_dtype, indices_type=self.topk_indices_dtype,
) )
return fused_experts( return torch.ops.vllm.fused_marlin_moe(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
topk_weights=topk_weights, getattr(layer, "w13_bias", None),
topk_ids=topk_ids, getattr(layer, "w2_bias", None),
inplace=True, layer.w13_scale,
activation=activation, layer.w2_scale,
router_logits,
topk_weights,
topk_ids,
quant_type_id=self.quant_config.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
quant_config=self.moe_quant_config, workspace=workspace,
) )
...@@ -504,18 +521,133 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: ...@@ -504,18 +521,133 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
return input_deq return input_deq
def fix_weights(layer: torch.nn.Module, param_name: str, reshape: bool = False): def _get_perms():
"""torch.compile does not know how to deal with a Parameter subclass perm = []
(aka RTNParameter). As we don't really need RTNParameters for the for i in range(32):
forward pass, we replace them with equivalent instances of Parameters. perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm.extend([p + 256 * j for p in perm1])
perm_arr = np.array(perm)
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
perm_arr = perm_arr.reshape((-1, 8))[:, interleave].ravel()
perm_tensor = torch.from_numpy(perm_arr)
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return perm_tensor, scale_perm, scale_perm_single
_perm, _scale_perm, _scale_perm_single = _get_perms()
def pack_for_marlin(weight, scale, qbits):
batch = weight.shape[0]
n = weight.size(1)
k = weight.size(2)
groupsize = k // scale.size(2)
tile = 16
s = scale.permute(0, 2, 1) # transpose
w = weight.permute(0, 2, 1) # transpose
if groupsize != k:
w = w.reshape((batch, -1, groupsize, n))
w = w.permute(0, 2, 1, 3)
w = w.reshape((batch, groupsize, -1))
s = s.reshape((batch, 1, -1))
if groupsize != k:
w = w.reshape((batch, groupsize, -1, n))
w = w.permute(0, 2, 1, 3)
w = w.reshape((batch, k, n)).contiguous()
s = s.reshape((batch, -1, len(_scale_perm)))[:, :, _scale_perm]
else:
s = s.reshape((batch, -1, len(_scale_perm_single)))[:, :, _scale_perm_single]
s = s.reshape((batch, -1, n)).contiguous()
w = w.reshape((batch, k // tile, tile, n // tile, tile))
w = w.permute((0, 1, 3, 2, 4))
w = w.reshape((batch, k // tile, n * tile))
res = w
res = res.reshape((batch, -1, _perm.numel()))[:, :, _perm].reshape(res.shape)
if qbits == 4:
q = torch.zeros(
(batch, res.shape[1], res.shape[2] // 2), dtype=torch.int8, device=w.device
)
for i in range(2):
q |= res[:, :, i::2] << 4 * i
q = q.reshape(batch, -1, n).contiguous()
else:
q = res.clone()
q[:, :, 2::8] = res[:, :, 4::8]
q[:, :, 3::8] = res[:, :, 5::8]
q[:, :, 4::8] = res[:, :, 2::8]
q[:, :, 5::8] = res[:, :, 3::8]
q = q.reshape(batch, -1, n).to(torch.int8).contiguous()
return q, s
def repack_8bit_into_32bit(input):
output = torch.zeros(
(input.shape[0], input.shape[1], input.shape[2] // 4),
dtype=torch.int32,
device=input.device,
)
for i in range(4):
output |= (input[:, :, i::4] & 0xFF).to(torch.int32) << 8 * i
return output
def repack_weights(qweight, scale, weight_bits):
batch_present = len(qweight.shape) == 3
if not batch_present:
qweight = qweight.unsqueeze(0)
scale = scale.unsqueeze(0)
if weight_bits == 4:
"""Unpack two 4-bit values from each byte.
"""
qweight_unpacked = torch.empty(
(qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2]),
dtype=torch.uint8,
device=qweight.device,
)
for i in range(2):
qweight_unpacked[:, :, i::2] = ((qweight << 4 * (1 - i)) >> 4).reshape(
qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2] // 2
)
else:
qweight_unpacked = qweight
qweight_packed, scale_packed = pack_for_marlin(qweight_unpacked, scale, weight_bits)
"""Marlin kernels expect tensors in int32 format in a certain shape
""" """
old_weight = getattr(layer, param_name) qweight_repacked = repack_8bit_into_32bit(qweight_packed.to(torch.uint8))
assert isinstance(old_weight, RTNParameter) qweight_reshaped = qweight_repacked.reshape(
data = old_weight.data.data qweight.shape[0], qweight.shape[2] // 16, -1
)
if not batch_present:
qweight_reshaped = qweight_reshaped.squeeze(0)
scale_packed = scale_packed.squeeze(0)
return qweight_reshaped, scale_packed
delattr(layer, param_name)
if reshape: def init_workspace(device):
data = data.reshape(old_weight.shape[0], old_weight.shape[1] * 2, -1) global workspace
new_weight = Parameter(data=data, requires_grad=False) if workspace is None:
layer.register_parameter(param_name, new_weight) workspace = marlin_make_workspace_new(device, 4)
...@@ -528,3 +528,48 @@ def apply_awq_marlin_linear( ...@@ -528,3 +528,48 @@ def apply_awq_marlin_linear(
) )
return output.reshape(out_shape) return output.reshape(out_shape)
def apply_rtn_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
workspace: torch.Tensor,
quant_type: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)
use_atomic_add = should_use_atomic_add_reduce(
m=reshaped_x.size(0),
n=output_size_per_partition,
k=reshaped_x.size(1),
device=input.device,
dtype=input.dtype,
)
output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
bias,
weight_scale,
None,
None,
None,
None,
workspace,
quant_type,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
return output.reshape(out_shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment