Unverified Commit 1d0c9d6b authored by Jinzhen Lin's avatar Jinzhen Lin Committed by GitHub
Browse files

[Kernel] some optimizations for dense marlin and moe marlin (#16850)


Signed-off-by: default avatarJinzhen Lin <linjinzhen@hotmail.com>
parent f62cad64
...@@ -21,19 +21,21 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -21,19 +21,21 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
prepare_moe_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped) is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, all_close_1d, convert_to_channelwise, Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
cutlass_block_fp8_supported, cutlass_fp8_supported, cutlass_fp8_supported, maybe_create_device_identity,
maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
per_tensor_dequantize, requantize_with_max_scale) requantize_with_max_scale)
from vllm.model_executor.parameter import (BlockQuantScaleParameter, from vllm.model_executor.parameter import (BlockQuantScaleParameter,
ModelWeightParameter, ModelWeightParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
ACTIVATION_SCHEMES = ["static", "dynamic"] ACTIVATION_SCHEMES = ["static", "dynamic"]
...@@ -181,10 +183,6 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -181,10 +183,6 @@ class Fp8LinearMethod(LinearMethodBase):
self.use_marlin = False self.use_marlin = False
self.block_quant = self.quant_config.weight_block_size is not None self.block_quant = self.quant_config.weight_block_size is not None
if self.block_quant:
# Marlin doesn't support block-wise fp8
self.use_marlin = False
self.fp8_linear = Fp8LinearOp( self.fp8_linear = Fp8LinearOp(
# Default to using per_token quantization if cutlass is supported # Default to using per_token quantization if cutlass is supported
use_per_token_if_dynamic=cutlass_fp8_supported()) use_per_token_if_dynamic=cutlass_fp8_supported())
...@@ -203,10 +201,16 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -203,10 +201,16 @@ class Fp8LinearMethod(LinearMethodBase):
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
layer.weight_block_size = None
if self.block_quant: if self.block_quant:
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
assert self.quant_config.weight_block_size is not None assert self.quant_config.weight_block_size is not None
layer.weight_block_size = self.quant_config.weight_block_size
block_n, block_k = ( block_n, block_k = (
self.quant_config.weight_block_size[0], self.quant_config.weight_block_size[0],
self.quant_config.weight_block_size[1], self.quant_config.weight_block_size[1],
...@@ -229,12 +233,6 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -229,12 +233,6 @@ class Fp8LinearMethod(LinearMethodBase):
f"{output_partition_size} is not divisible by " f"{output_partition_size} is not divisible by "
f"weight quantization block_n = {block_n}.") f"weight quantization block_n = {block_n}.")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
# WEIGHT # WEIGHT
weight_dtype = (torch.float8_e4m3fn weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else if self.quant_config.is_checkpoint_fp8_serialized else
...@@ -303,9 +301,11 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -303,9 +301,11 @@ class Fp8LinearMethod(LinearMethodBase):
return weight return weight
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
size_k_first = True
# TODO(rob): refactor block quant into separate class. # TODO(rob): refactor block quant into separate class.
if self.block_quant: if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic" assert self.quant_config.activation_scheme == "dynamic"
size_k_first = False
if current_platform.is_fp8_fnuz(): if current_platform.is_fp8_fnuz():
weight, weight_scale_inv, _ = \ weight, weight_scale_inv, _ = \
normalize_e4m3fn_to_e4m3fnuz( normalize_e4m3fn_to_e4m3fnuz(
...@@ -321,21 +321,12 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -321,21 +321,12 @@ class Fp8LinearMethod(LinearMethodBase):
layer.weight = Parameter(weight, requires_grad=False) layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale_inv, layer.weight_scale_inv = Parameter(weight_scale_inv,
requires_grad=False) requires_grad=False)
return
# If checkpoint not serialized fp8, quantize the weights. # If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized: elif not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
scale=None) scale=None)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
assert weight_scale.numel() == 1
weight_scale = convert_to_channelwise(
weight_scale.expand(len(layer.logical_widths)),
layer.logical_widths)
# Update the layer with the new values. # Update the layer with the new values.
layer.weight = Parameter(qweight.t(), requires_grad=False) layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False)
...@@ -349,20 +340,14 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -349,20 +340,14 @@ class Fp8LinearMethod(LinearMethodBase):
if self.quant_config.activation_scheme == "static": if self.quant_config.activation_scheme == "static":
layer.input_scale = torch.nn.Parameter(layer.input_scale.data, layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False) requires_grad=False)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise. weight = layer.weight
if self.use_marlin: weight_scale = layer.weight_scale
weight = layer.weight
weight_scale = convert_to_channelwise(layer.weight_scale,
layer.logical_widths)
# If using w8a8, torch._scaled_mm needs per tensor, so # If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight. # requantize the logical shards as a single weight.
else: if not self.use_marlin:
# Dequant -> Quant with max scale so we can run per tensor. # Dequant -> Quant with max scale so we can run per tensor.
weight = layer.weight
weight_scale = layer.weight_scale
if current_platform.is_fp8_fnuz(): if current_platform.is_fp8_fnuz():
weight, weight_scale, input_scale = \ weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz( normalize_e4m3fn_to_e4m3fnuz(
...@@ -388,7 +373,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -388,7 +373,7 @@ class Fp8LinearMethod(LinearMethodBase):
requires_grad=False) requires_grad=False)
if self.use_marlin: if self.use_marlin:
prepare_fp8_layer_for_marlin(layer) prepare_fp8_layer_for_marlin(layer, size_k_first)
# Activations not quantized for marlin. # Activations not quantized for marlin.
del layer.input_scale del layer.input_scale
...@@ -444,6 +429,14 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -444,6 +429,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None self.block_quant = self.quant_config.weight_block_size is not None
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
# Check for DeepGemm support. # Check for DeepGemm support.
self.allow_deep_gemm = False self.allow_deep_gemm = False
if envs.VLLM_USE_DEEP_GEMM: if envs.VLLM_USE_DEEP_GEMM:
...@@ -461,10 +454,17 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -461,10 +454,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
intermediate_size_per_partition: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
layer.intermediate_size_per_partition = intermediate_size_per_partition
layer.hidden_size = hidden_size
layer.num_experts = num_experts
layer.orig_dtype = params_dtype
layer.weight_block_size = None
if self.quant_config.is_checkpoint_fp8_serialized: if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn params_dtype = torch.float8_e4m3fn
if self.block_quant: if self.block_quant:
assert self.quant_config.weight_block_size is not None assert self.quant_config.weight_block_size is not None
layer.weight_block_size = self.quant_config.weight_block_size
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
block_n, block_k = ( block_n, block_k = (
self.quant_config.weight_block_size[0], self.quant_config.weight_block_size[0],
...@@ -630,10 +630,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -630,10 +630,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale_inv = \ layer.w2_weight_scale_inv = \
dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous() dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
return
# If checkpoint is fp16, quantize in place. # If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized: elif not self.quant_config.is_checkpoint_fp8_serialized:
fp8_dtype = current_platform.fp8_dtype() fp8_dtype = current_platform.fp8_dtype()
w13_weight = torch.empty_like(layer.w13_weight.data, w13_weight = torch.empty_like(layer.w13_weight.data,
dtype=fp8_dtype) dtype=fp8_dtype)
...@@ -677,8 +675,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -677,8 +675,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
requires_grad=False) requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2, layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False) requires_grad=False)
return
# If checkpoint is fp8, we need to handle that the # If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight # MoE kernels require single activation scale and single weight
# scale for w13 per expert. # scale for w13 per expert.
...@@ -766,7 +762,12 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -766,7 +762,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False) requires_grad=False)
return
if self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
def apply( def apply(
self, self,
...@@ -801,6 +802,20 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -801,6 +802,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
) )
if self.use_marlin:
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
global_num_experts=global_num_experts,
expert_map=expert_map)
return fused_experts( return fused_experts(
x, x,
layer.w13_weight, layer.w13_weight,
......
...@@ -21,8 +21,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import ( ...@@ -21,8 +21,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method) get_linear_quant_method)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, check_moe_marlin_supports_layer, check_marlin_supported, check_moe_marlin_supports_layer,
marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks, marlin_make_workspace_new, marlin_moe_permute_scales,
verify_marlin_supported) marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter, GroupQuantScaleParameter,
PackedColumnParameter, PackedColumnParameter,
...@@ -350,6 +350,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -350,6 +350,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: GPTQMarlinConfig) -> None: def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config self.quant_config = quant_config
if self.quant_config.quant_type.size_bits == 4:
self.quant_type = scalar_types.uint4b8
elif self.quant_config.quant_type.size_bits == 8:
self.quant_type = scalar_types.uint8b128
else:
raise ValueError(
"GPTQMarlinMoEMethod only supports int4 and int8 now.")
def create_weights( def create_weights(
self, self,
...@@ -498,11 +505,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -498,11 +505,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
device = layer.w13_qweight.device device = layer.w13_qweight.device
sms = torch.cuda.get_device_properties(device).multi_processor_count layer.workspace = marlin_make_workspace_new(device, 4)
layer.workspace = torch.zeros((sms * 4, ),
dtype=torch.int,
device=device,
requires_grad=False)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
...@@ -633,12 +636,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -633,12 +636,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
router_logits, router_logits,
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_type_id=self.quant_type.id,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
g_idx1=layer.w13_g_idx, g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx, g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices, sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.quant_config.quant_type.size_bits,
workspace=layer.workspace, workspace=layer.workspace,
is_k_full=self.is_k_full) is_k_full=self.is_k_full)
...@@ -8,7 +8,7 @@ from vllm import _custom_ops as ops ...@@ -8,7 +8,7 @@ from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx, marlin_make_workspace_new, marlin_permute_scales, marlin_sort_g_idx,
marlin_zero_points, query_marlin_supported_quant_types, unpack_cols) marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_) permute_param_layout_)
...@@ -53,8 +53,7 @@ class MarlinLinearKernel(MPLinearKernel): ...@@ -53,8 +53,7 @@ class MarlinLinearKernel(MPLinearKernel):
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
# Allocate marlin workspace. # Allocate marlin workspace.
self.workspace = marlin_make_workspace(c.partition_weight_shape[1], self.workspace = marlin_make_workspace_new(device)
device)
# Default names since marlin requires empty parameters for these, # Default names since marlin requires empty parameters for these,
# TODO: remove this requirement from marlin (allow optional tensors) # TODO: remove this requirement from marlin (allow optional tensors)
...@@ -127,6 +126,5 @@ class MarlinLinearKernel(MPLinearKernel): ...@@ -127,6 +126,5 @@ class MarlinLinearKernel(MPLinearKernel):
wtype=c.weight_type, wtype=c.weight_type,
input_size_per_partition=c.partition_weight_shape[0], input_size_per_partition=c.partition_weight_shape[0],
output_size_per_partition=c.partition_weight_shape[1], output_size_per_partition=c.partition_weight_shape[1],
has_zp=self.config.zero_points,
is_k_full=self.is_k_full, is_k_full=self.is_k_full,
bias=bias) bias=bias)
...@@ -7,12 +7,15 @@ import torch ...@@ -7,12 +7,15 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.linear import LinearBase
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
from .quant_utils import pack_cols, unpack_cols from .quant_utils import pack_cols, unpack_cols
logger = init_logger(__name__)
GPTQ_MARLIN_TILE = 16 GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64 GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128 GPTQ_MARLIN_MIN_THREAD_K = 128
...@@ -29,9 +32,11 @@ USE_FP32_REDUCE_DEFAULT = True ...@@ -29,9 +32,11 @@ USE_FP32_REDUCE_DEFAULT = True
# For binary size and compile time, we don't support the same types for with and # For binary size and compile time, we don't support the same types for with and
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl # TODO: we may want to move this into the C++ so its closer to the actual impl
def query_marlin_supported_quant_types(has_zp: bool, def query_marlin_supported_quant_types(
device_capability: Optional[int] = None has_zp: bool,
): include_fp_type: bool = True,
device_capability: Optional[int] = None,
):
if device_capability is None: if device_capability is None:
capability_tuple = current_platform.get_device_capability() capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else device_capability = (-1 if capability_tuple is None else
...@@ -42,12 +47,13 @@ def query_marlin_supported_quant_types(has_zp: bool, ...@@ -42,12 +47,13 @@ def query_marlin_supported_quant_types(has_zp: bool,
if has_zp: if has_zp:
# AWQ style, unsigned + runtime zero-point # AWQ style, unsigned + runtime zero-point
return [scalar_types.uint4, scalar_types.uint8] return [scalar_types.uint4]
else: else:
# GPTQ style, unsigned + symmetric bias # GPTQ style, unsigned + symmetric bias
# TODO: once fp8_marlin is merged into "gptq_marlin" we should be able res = [scalar_types.uint4b8, scalar_types.uint8b128]
# to add `scalar_types.float8_e4m3fn` here if include_fp_type:
return [scalar_types.uint4b8, scalar_types.uint8b128] res += [scalar_types.float8_e4m3fn]
return res
def _check_marlin_supported( def _check_marlin_supported(
...@@ -62,7 +68,7 @@ def _check_marlin_supported( ...@@ -62,7 +68,7 @@ def _check_marlin_supported(
capability_tuple.to_int()) capability_tuple.to_int())
supported_types = query_marlin_supported_quant_types( supported_types = query_marlin_supported_quant_types(
has_zp, device_capability) has_zp, True, device_capability)
if quant_type not in supported_types: if quant_type not in supported_types:
return (False, f"Marlin does not support weight_bits = {quant_type}. " return (False, f"Marlin does not support weight_bits = {quant_type}. "
...@@ -175,6 +181,17 @@ def marlin_make_workspace(output_size_per_partition: int, ...@@ -175,6 +181,17 @@ def marlin_make_workspace(output_size_per_partition: int,
requires_grad=False) requires_grad=False)
def marlin_make_workspace_new(device: torch.device,
max_blocks_per_sm: int = 1) -> torch.Tensor:
# In the new marlin kernel, we use the num of threadblocks as workspace
# size. The num of threadblocks is is sms_count * max_blocks_per_sm.
sms = torch.cuda.get_device_properties(device).multi_processor_count
return torch.zeros(sms * max_blocks_per_sm,
dtype=torch.int,
device=device,
requires_grad=False)
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
return (not act_order) or (act_order and not is_row_parallel) return (not act_order) or (act_order and not is_row_parallel)
...@@ -304,21 +321,50 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, ...@@ -304,21 +321,50 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
return output return output
def maybe_warn_marlin_atomic_add(device, dtype):
if torch.compiler.is_dynamo_compiling():
return
device_capability = torch.cuda.get_device_capability(device)
if device_capability[0] < 9 and dtype == torch.bfloat16:
logger.info_once(
"You are running Marlin kernel with bf16 on GPUs before SM90. "
"You can consider change to fp16 to achieve better performance "
"if possible.")
def maybe_warn_marlin_atomic_add_env():
if torch.compiler.is_dynamo_compiling():
return
if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
return
logger.info_once(
"Marlin kernel can achieve better performance for small size_n "
"with experimental use_atomic_add feature. "
"You can consider set environment variable "
"VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.")
def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device, def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
dtype: torch.dtype) -> bool: dtype: torch.dtype) -> bool:
# the performance of atomicAdd is better than global reduce
# only when m*n is small and k is large
if n >= 2048 or k < 2048 or device.type != "cuda":
return False
# disable atomicAdd reduce by default, # disable atomicAdd reduce by default,
# one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1 # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
if not envs.VLLM_MARLIN_USE_ATOMIC_ADD or device.type != "cuda": if not envs.VLLM_MARLIN_USE_ATOMIC_ADD:
maybe_warn_marlin_atomic_add_env()
return False return False
# sm8x doesn't support atomicAdd + bfloat16 natively # sm8x doesn't support atomicAdd + bfloat16 natively
device_capability = torch.cuda.get_device_capability(device) device_capability = torch.cuda.get_device_capability(device)
if device_capability[0] < 9 and dtype == torch.bfloat16: if device_capability[0] < 9 and dtype == torch.bfloat16:
maybe_warn_marlin_atomic_add(device, dtype)
return False return False
# the performance of atomicAdd is better than global reduce return True
# only when m*n is small and k is large
return n < 2048 and k >= 2048
def apply_gptq_marlin_linear( def apply_gptq_marlin_linear(
...@@ -332,7 +378,6 @@ def apply_gptq_marlin_linear( ...@@ -332,7 +378,6 @@ def apply_gptq_marlin_linear(
wtype: ScalarType, wtype: ScalarType,
output_size_per_partition: int, output_size_per_partition: int,
input_size_per_partition: int, input_size_per_partition: int,
has_zp: bool,
is_k_full: bool, is_k_full: bool,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
...@@ -346,6 +391,7 @@ def apply_gptq_marlin_linear( ...@@ -346,6 +391,7 @@ def apply_gptq_marlin_linear(
dtype=input.dtype) dtype=input.dtype)
output = ops.gptq_marlin_gemm(reshaped_x, output = ops.gptq_marlin_gemm(reshaped_x,
None,
weight, weight,
weight_scale, weight_scale,
weight_zp, weight_zp,
...@@ -358,7 +404,6 @@ def apply_gptq_marlin_linear( ...@@ -358,7 +404,6 @@ def apply_gptq_marlin_linear(
size_k=input_size_per_partition, size_k=input_size_per_partition,
is_k_full=is_k_full, is_k_full=is_k_full,
use_atomic_add=use_atomic_add, use_atomic_add=use_atomic_add,
has_zp=has_zp,
use_fp32_reduce=use_fp32_reduce, use_fp32_reduce=use_fp32_reduce,
is_zp_float=False) is_zp_float=False)
...@@ -391,6 +436,7 @@ def apply_awq_marlin_linear( ...@@ -391,6 +436,7 @@ def apply_awq_marlin_linear(
dtype=input.dtype) dtype=input.dtype)
output = ops.gptq_marlin_gemm(reshaped_x, output = ops.gptq_marlin_gemm(reshaped_x,
None,
weight, weight,
weight_scale, weight_scale,
weight_zp, weight_zp,
...@@ -401,8 +447,6 @@ def apply_awq_marlin_linear( ...@@ -401,8 +447,6 @@ def apply_awq_marlin_linear(
size_m=reshaped_x.shape[0], size_m=reshaped_x.shape[0],
size_n=output_size_per_partition, size_n=output_size_per_partition,
size_k=input_size_per_partition, size_k=input_size_per_partition,
is_k_full=True,
has_zp=True,
use_atomic_add=use_atomic_add, use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce, use_fp32_reduce=use_fp32_reduce,
is_zp_float=False) is_zp_float=False)
......
...@@ -6,9 +6,11 @@ import torch ...@@ -6,9 +6,11 @@ import torch
import vllm._custom_ops as ops import vllm._custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
should_use_atomic_add_reduce)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from .marlin_utils import marlin_make_workspace, marlin_permute_scales
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -18,30 +20,40 @@ def is_fp8_marlin_supported(): ...@@ -18,30 +20,40 @@ def is_fp8_marlin_supported():
def apply_fp8_marlin_linear( def apply_fp8_marlin_linear(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
weight_scale: torch.Tensor, weight_scale: torch.Tensor,
workspace: torch.Tensor, workspace: torch.Tensor,
size_n: int, size_n: int,
size_k: int, size_k: int,
bias: Optional[torch.Tensor], bias: Optional[torch.Tensor],
) -> torch.Tensor: use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
# For GPUs that lack FP8 hardware support, we can leverage the # For GPUs that lack FP8 hardware support, we can leverage the
# Marlin kernel for fast weight-only FP8 quantization # Marlin kernel for fast weight-only FP8 quantization
reshaped_x = input.reshape(-1, input.shape[-1]) reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (size_n, ) out_shape = input.shape[:-1] + (size_n, )
output = ops.fp8_marlin_gemm( use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
a=reshaped_x, n=size_n,
b_q_weight=weight, k=size_k,
b_scales=weight_scale, device=input.device,
workspace=workspace, dtype=input.dtype)
num_bits=8,
size_m=reshaped_x.shape[0], output = ops.gptq_marlin_gemm(a=reshaped_x,
size_n=size_n, c=None,
size_k=size_k, b_q_weight=weight,
) b_scales=weight_scale,
b_zeros=None,
g_idx=None,
perm=None,
workspace=workspace,
b_q_type=scalar_types.float8_e4m3fn,
size_m=reshaped_x.size(0),
size_n=size_n,
size_k=size_k,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce)
if bias is not None: if bias is not None:
output.add_(bias) # In-place add output.add_(bias) # In-place add
...@@ -50,7 +62,7 @@ def apply_fp8_marlin_linear( ...@@ -50,7 +62,7 @@ def apply_fp8_marlin_linear(
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
strategy: str = "tensor") -> None: size_k_first: bool = True) -> None:
logger.warning_once( logger.warning_once(
"Your GPU does not have native support for FP8 computation but " "Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will " "FP8 quantization is being used. Weight-only FP8 compression will "
...@@ -60,51 +72,234 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, ...@@ -60,51 +72,234 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
part_size_n = layer.output_size_per_partition part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition part_size_k = layer.input_size_per_partition
if size_k_first:
assert layer.weight.shape == (part_size_k, part_size_n)
else:
assert layer.weight.shape == (part_size_n, part_size_k)
device = layer.weight.device device = layer.weight.device
# WORKSPACE # WORKSPACE
layer.workspace = marlin_make_workspace(part_size_n, device) layer.workspace = marlin_make_workspace_new(device)
# WEIGHT # WEIGHT
# Repack weights to marlin format # Repack weights to marlin format
marlin_qweight = ops.gptq_marlin_repack(b_q_weight=pack_fp8_to_int32( perm = torch.empty(0, dtype=torch.int, device=device)
layer.weight), qweight = pack_fp8_to_int32(layer.weight, size_k_first)
perm=torch.empty(0, if not size_k_first:
dtype=torch.int, qweight = qweight.T.contiguous()
device=device),
marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
perm=perm,
size_k=part_size_k, size_k=part_size_k,
size_n=part_size_n, size_n=part_size_n,
num_bits=8) num_bits=8)
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
# WEIGHT SCALES # WEIGHT SCALES
scales = layer.weight_scale.to(layer.orig_dtype)
# Permute scales # Permute scales
if "weight_scale" in dir(layer):
scales = layer.weight_scale.to(layer.orig_dtype)
elif "weight_scale_inv" in dir(layer):
scales = layer.weight_scale_inv.to(layer.orig_dtype)
del layer.weight_scale_inv
if layer.weight_block_size is None:
group_size = -1
else:
group_size = layer.weight_block_size[1]
# marlin kernel only support channel-wise and group-wise quantization
# we need to convert the scales
if layer.weight_block_size is None:
if scales.nelement() == 1:
# tensor-wise quantization -> channel-wise quantization
# (1, 1) =>(repeat)=> (1, size_n)
scales = scales.view(1, 1).repeat_interleave(part_size_n, 1)
elif scales.nelement() > 1 and scales.nelement() != part_size_n:
assert part_size_n % scales.nelement() == 0
s_size = scales.nelement()
# tensor-wise quantization (for gate-up proj)
# -> channel-wise quantization
# (1, s_size) =>(repeat)=> (1, size_n)
scales = scales.view(1, s_size)
scales = scales.repeat_interleave(part_size_n // s_size, 1)
else:
# channel-wise quantization
# (1, size_n)
scales = scales.view(1, part_size_n)
else:
# block-wise quantization -> group-wise quantization
# (size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (size_k // block_size[1], size_n)
block_n = layer.weight_block_size[0]
scales = scales.T.repeat_interleave(block_n, 1)
# size_n may not divisible by block_size[0]
scales = scales[:, :part_size_n]
marlin_scales = marlin_permute_scales(s=scales, marlin_scales = marlin_permute_scales(s=scales,
size_k=part_size_k, size_k=part_size_k,
size_n=part_size_n, size_n=part_size_n,
group_size=-1) group_size=group_size)
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
size_k_first: bool = True) -> None:
logger.warning_once(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads.")
e = layer.num_experts
k = layer.hidden_size
n = layer.intermediate_size_per_partition
# WORKSPACE
device = layer.w13_weight.device
layer.workspace = marlin_make_workspace_new(device, 4)
perm = torch.empty(0, dtype=torch.int, device=device)
# WEIGHT
# Repack weights to marlin format
for name in ["w13_weight", "w2_weight"]:
weight = getattr(layer, name)
tensor_list = []
if "w13" in name:
size_n, size_k = n * 2, k
else:
size_n, size_k = k, n
if size_k_first:
assert weight.shape == (e, size_k, size_n)
else:
assert weight.shape == (e, size_n, size_k)
for i in range(e):
qweight = pack_fp8_to_int32(weight[i], size_k_first)
if not size_k_first:
qweight = qweight.T.contiguous()
marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
perm=perm,
size_k=size_k,
size_n=size_n,
num_bits=8)
tensor_list.append(marlin_qweight)
weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
weight = torch.nn.Parameter(weight, requires_grad=False)
setattr(layer, name, weight)
# WEIGHT SCALES
# Permute scales
if layer.weight_block_size is None:
group_size = -1
else:
group_size = layer.weight_block_size[1]
for name in ["w13", "w2"]:
if name + "_weight_scale" in dir(layer):
new_name = name + "_weight_scale"
scales = getattr(layer, new_name).to(layer.orig_dtype)
delattr(layer, new_name)
elif name + "_weight_scale_inv" in dir(layer):
new_name = name + "_weight_scale_inv"
scales = getattr(layer, new_name).to(layer.orig_dtype)
delattr(layer, new_name)
tensor_list = []
if "w13" in name:
size_n, size_k = n * 2, k
else:
size_n, size_k = k, n
# marlin kernel only support channel-wise and group-wise quantization
# we need to convert the scales
if layer.weight_block_size is None:
if scales.nelement() == e:
# tensor-wise quantization -> channel-wise quantization
# (e, 1, 1) =>(repeat)=> (e, 1, size_n)
scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2)
elif scales.nelement() > e and scales.nelement() != e * size_n:
assert (e * size_n) % scales.nelement() == 0
s_size = scales.nelement() // e
# tensor-wise quantization (for gate-up proj)
# -> channel-wise quantization
# (e, 1, s_size) =>(repeat)=> (e, 1, size_n)
scales = scales.view(e, 1, s_size)
scales = scales.repeat_interleave(size_n // s_size, 2)
else:
# channel-wise quantization
# (e, 1, size_n)
scales = scales.view(e, 1, size_n)
else:
# block-wise quantization -> group-wise quantization
# (e, size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (e, size_k // block_size[1], size_n)
block_n = layer.weight_block_size[0]
scales = scales.permute(0, 2, 1).repeat_interleave(block_n, 2)
# size_n may not divisible by block_size[0]
scales = scales[..., :size_n].contiguous()
for i in range(e):
marlin_scales = marlin_permute_scales(s=scales[i],
size_k=size_k,
size_n=size_n,
group_size=group_size)
tensor_list.append(marlin_scales)
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
scales = torch.nn.Parameter(scales, requires_grad=False)
setattr(layer, name + "_weight_scale", scales)
def pack_fp8_to_int32(fp8_tensor: torch.Tensor,
size_k_first: bool = True) -> torch.Tensor:
""" """
Repack FP8 weights to gptq format (packed int32 elements) Repack FP8 weights to gptq format (packed int32 elements)
""" """
assert fp8_tensor.dtype == torch.float8_e4m3fn assert fp8_tensor.dtype == torch.float8_e4m3fn
assert fp8_tensor.shape[0] % 4 == 0 assert fp8_tensor.ndim == 2
fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
fp8_tensor = fp8_tensor.contiguous()
# fp8_tensor is contiguous and have shape (N, K) now
# with `.view(torch.int32)`, it become (N, K // 4)
int32_tensor = fp8_tensor.view(torch.int32)
return int32_tensor.T.contiguous() if size_k_first else int32_tensor
# Reshape to prepare for packing def marlin_quant_fp8_torch(weight, group_size):
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) size_n, size_k = weight.shape
device = weight.device
# Convert fp8 to uint8 (byte) representation if group_size != -1:
byte_tensor = reshaped.view(torch.uint8) scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448
repeated_scales = scales.repeat_interleave(group_size, 1)
fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
else:
scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448
repeated_scales = scales.repeat_interleave(size_k, 1)
fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=packed_weight,
perm=torch.empty(0, dtype=torch.int, device=device),
size_k=size_k,
size_n=size_n,
num_bits=8,
)
# Pack 4 uint8 values into one int32 marlin_scales = marlin_permute_scales(s=scales.T,
packed = (byte_tensor[:, 0].to(torch.int32) | size_k=size_k,
(byte_tensor[:, 1].to(torch.int32) << 8) | size_n=size_n,
(byte_tensor[:, 2].to(torch.int32) << 16) | group_size=group_size)
(byte_tensor[:, 3].to(torch.int32) << 24))
return packed.view(fp8_tensor.shape[0] // 4, return weight_ref.T, marlin_qweight, marlin_scales
*fp8_tensor.shape[1:]).contiguous()
...@@ -6,6 +6,8 @@ from dataclasses import dataclass ...@@ -6,6 +6,8 @@ from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Optional, Union from typing import Optional, Union
_SCALAR_TYPES_ID_MAP = {}
# Mirrors enum in `core/scalar_type.hpp` # Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum): class NanRepr(Enum):
...@@ -158,6 +160,8 @@ class ScalarType: ...@@ -158,6 +160,8 @@ class ScalarType:
assert offset <= 64, \ assert offset <= 64, \
f"ScalarType fields too big {offset} to fit into an int64" f"ScalarType fields too big {offset} to fit into an int64"
_SCALAR_TYPES_ID_MAP[val] = self
return val return val
@property @property
...@@ -295,6 +299,13 @@ class ScalarType: ...@@ -295,6 +299,13 @@ class ScalarType:
ret.id # noqa B018: make sure the id is cached ret.id # noqa B018: make sure the id is cached
return ret return ret
@classmethod
def from_id(cls, scalar_type_id: int):
if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
raise ValueError(
f"scalar_type_id {scalar_type_id} doesn't exists.")
return _SCALAR_TYPES_ID_MAP[scalar_type_id]
# naming generally follows: https://github.com/jax-ml/ml_dtypes # naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is: # for floating point types (leading f) the scheme is:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment