Commit a99300bd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-dev

parents cc3e01c7 5438967f
......@@ -518,6 +518,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
......@@ -547,6 +548,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
routed_scaling_factor=routed_scaling_factor,
......
......@@ -466,6 +466,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
......@@ -490,6 +491,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
if self.quant_config.load_in_8bit:
......
......@@ -11,6 +11,7 @@ from compressed_tensors.config import (CompressionFormat,
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from compressed_tensors.transform import TransformConfig
from pydantic import BaseModel
import vllm.envs as envs
......@@ -29,10 +30,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
CompressedTensorsScheme, CompressedTensorsW4A4Fp4,
CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16)
CompressedTensorsW4A8Fp8, CompressedTensorsW4A8Int,
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
CompressedTensorsLinearTransformMethod, get_linear_transform_schemes)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format,
should_ignore_layer)
......@@ -67,6 +70,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_ignore_list: list[str],
kv_cache_scheme: Optional[dict[str, Any]] = None,
config: Optional[dict[str, Any]] = None,
transform_config: Optional[TransformConfig] = None,
):
super().__init__()
self.ignore = ignore
......@@ -78,6 +82,12 @@ class CompressedTensorsConfig(QuantizationConfig):
self.sparsity_ignore_list = sparsity_ignore_list
self.config = config
if transform_config is not None:
self.transform_config = TransformConfig.model_validate(
transform_config)
else:
self.transform_config = None
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)
......@@ -110,18 +120,26 @@ class CompressedTensorsConfig(QuantizationConfig):
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if should_ignore_layer(prefix,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return UnquantizedEmbeddingMethod()#UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
if scheme is None:
return UnquantizedEmbeddingMethod()#UnquantizedLinearMethod()
layer.scheme = scheme
return CompressedTensorsLinearMethod(self)
# collect schemes
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
input_tfms, output_tfms = get_linear_transform_schemes(
layer, prefix, self.transform_config,
self.packed_modules_mapping)
# choose quantization method
quant_method: LinearMethodBase = UnquantizedLinearMethod()
if quant_scheme is not None:
layer.scheme = quant_scheme
quant_method = CompressedTensorsLinearMethod(self)
# choose transform method
if any((input_tfms, output_tfms)):
return CompressedTensorsLinearTransformMethod.from_schemes(
quant_method, input_tfms, output_tfms)
else:
return quant_method
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):
......@@ -136,6 +154,7 @@ class CompressedTensorsConfig(QuantizationConfig):
config=config)
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
config=config)
transform_config = config.get("transform_config")
return cls(
target_scheme_map=target_scheme_map,
......@@ -144,6 +163,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_scheme_map=sparsity_scheme_map,
sparsity_ignore_list=sparsity_ignore_list,
config=config,
transform_config=transform_config,
)
@classmethod
......@@ -207,8 +227,10 @@ class CompressedTensorsConfig(QuantizationConfig):
format
) if format is not None else is_activation_quantization_format(
quant_format)
if act_quant_format:
input_activations = quant_config.get("input_activations")
# TODO(czhu): w4a8fp8 is in packed-quantized format
# but needs input activation quantization
input_activations = quant_config.get("input_activations")
if act_quant_format or input_activations:
# The only case where we have activation quant supported
# but no input_activations provided in the config
# should be w8a16fp8 w8a16fp8 can also run for cases where
......@@ -359,6 +381,28 @@ class CompressedTensorsConfig(QuantizationConfig):
input_quant.strategy == QuantizationStrategy.TENSOR)
return is_symmetric_activation and is_per_tensor_activation
def _is_fp8_w4a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
if not weight_quant or not input_quant:
return False
is_weight_4_bits = weight_quant.num_bits == 4
is_activation_8_bits = input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.GROUP.value)
is_token = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TOKEN.value)
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
is_symmetric = weight_quant.symmetric and input_quant.symmetric
# Only per-group symmetric weight (4bit)
# + per-tok symmetric activation (8bit) quantization supported.
return (is_weight_4_bits and is_activation_8_bits and is_token
and is_symmetric and is_dynamic)
def _is_fp8_w4a8_sm90(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
return (self._check_scheme_supported(90, error=False, match_exact=True)
and self._is_fp8_w4a8(weight_quant, input_quant))
def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
return (self._check_scheme_supported(90, error=False, match_exact=True)
......@@ -408,19 +452,30 @@ class CompressedTensorsConfig(QuantizationConfig):
weight_quant: BaseModel,
input_quant: BaseModel,
format: Optional[str] = None) -> "CompressedTensorsScheme":
# use the per-layer format if defined, otherwise, use global format
format = format if format is not None else self.quant_format
# Detect If Mixed Precision
if self._is_fp4a16_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A16Fp4()
if self._is_fp8_w4a8_sm90(weight_quant, input_quant):
return CompressedTensorsW4A8Fp8(num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy,
symmetric=weight_quant.symmetric,
group_size=weight_quant.group_size,
actorder=weight_quant.actorder)
if self._is_wNa16_group_channel(weight_quant, input_quant):
if (self.quant_format == CompressionFormat.marlin_24.value
if (format == CompressionFormat.marlin_24.value
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
assert weight_quant.symmetric
return CompressedTensorsW4A16Sparse24(
strategy=weight_quant.strategy,
num_bits=weight_quant.num_bits,
group_size=weight_quant.group_size)
if (self.quant_format == CompressionFormat.pack_quantized.value
if (format == CompressionFormat.pack_quantized.value
and weight_quant.num_bits in WNA16_SUPPORTED_BITS):
return CompressedTensorsWNA16(
num_bits=weight_quant.num_bits,
......@@ -429,10 +484,7 @@ class CompressedTensorsConfig(QuantizationConfig):
group_size=weight_quant.group_size,
actorder=weight_quant.actorder)
act_quant_format = is_activation_quantization_format(
format
) if format is not None else is_activation_quantization_format(
self.quant_format)
act_quant_format = is_activation_quantization_format(format)
if act_quant_format:
if self._is_fp4a4_nvfp4(weight_quant, input_quant):
if cutlass_fp4_supported(
......@@ -512,9 +564,11 @@ class CompressedTensorsConfig(QuantizationConfig):
# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
# need to make accelerate optional in ct to do this
# TODO (@kylesayrs): support ignore module names with ct matching utils
if should_ignore_layer(layer_name,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return None
# Will be empty for models with only sparsity
weight_quant = input_quant = None
......@@ -531,7 +585,7 @@ class CompressedTensorsConfig(QuantizationConfig):
format = scheme_dict.get("format")
# Find the sparsity scheme of the layer
# assume that fused layers inerhit first component's sparsity scheme
# assume that fused layers inherit first component's sparsity scheme
sparsity_targets = (self.sparsity_scheme_map.keys() -
set(self.sparsity_ignore_list))
sparsity_scheme: Optional[SparsityCompressionConfig] = None
......@@ -720,7 +774,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
layer input. See LinearMethodBase for param details
"""
scheme = layer.scheme
if scheme is None:
raise ValueError("A scheme must be defined for each layer")
......
......@@ -22,6 +22,8 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
......@@ -66,12 +68,40 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
@staticmethod
def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
layer: torch.nn.Module,
layer: torch.nn.Module
) -> "CompressedTensorsMoEMethod":
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
input_quant = quant_config.target_scheme_map["Linear"].get(
# Check if a using "Linear" to select schemes
if "Linear" in quant_config.target_scheme_map:
matched_target = "Linear"
else:
# May have instead defined the linear layers in the fused model
fused_layers = [
"re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*"
]
current_scheme = None
for fused_layer in fused_layers:
# Check if one of the fused layers are defined in quant_config
matched_target = find_matched_target(
layer_name=fused_layer,
module=layer,
targets=quant_config.target_scheme_map.keys(),
fused_mapping=quant_config.packed_modules_mapping)
# Only valid if down_proj, gate_proj, and up_proj
# are mapped to the same quant scheme in the quant_config
if current_scheme is None:
current_scheme = quant_config.target_scheme_map.get(
matched_target)
else:
assert current_scheme == quant_config.target_scheme_map.get(
matched_target)
weight_quant = quant_config.target_scheme_map[matched_target].get(
"weights")
input_quant = quant_config.target_scheme_map[matched_target].get(
"input_activations")
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
......@@ -247,13 +277,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
return
# swizzle weight scales
layer.w13_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
layer.w13_weight_scale = torch.nn.Parameter(swizzle_blockscale(
layer.w13_weight_scale),
requires_grad=False)
requires_grad=False)
layer.w2_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
layer.w2_weight_scale = torch.nn.Parameter(swizzle_blockscale(
layer.w2_weight_scale),
requires_grad=False)
requires_grad=False)
# w13
w13_input_global_scale = layer.w13_input_global_scale.max(
......@@ -293,6 +323,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return the appropriate GEMM experts implementation."""
experts = select_nvfp4_gemm_impl(
......@@ -320,6 +351,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
......@@ -345,6 +377,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
......@@ -384,8 +417,35 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif self.allow_flashinfer:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4)
assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight, layer.w2_weight), (
"Flashinfer CUTLASS Fused MoE not applicable!")
return flashinfer_cutlass_moe_fp4(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......@@ -401,8 +461,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
a=x,
w1_fp4=layer.w13_weight,
w2_fp4=layer.w2_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w2_blockscale=layer.w2_blockscale_swizzled,
w1_blockscale=layer.w13_weight_scale,
w2_blockscale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
......@@ -663,11 +723,29 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
pass
if self.use_cutlass:
device = layer.w13_weight.device
# ab_strides1 and c_strides2 are the same
self.ab_strides1_c_strides2 = torch.full(
(layer.local_num_experts, ),
layer.hidden_size,
device=device,
dtype=torch.int64)
self.ab_strides2 = torch.full(
(layer.local_num_experts, ),
layer.intermediate_size_per_partition,
device=device,
dtype=torch.int64)
self.c_strides1 = torch.full(
(layer.local_num_experts, ),
2 * layer.intermediate_size_per_partition,
device=device,
dtype=torch.int64)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
self, prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute:
# cutlass path
if self.use_cutlass:
from vllm.model_executor.layers.fused_moe import (
......@@ -687,6 +765,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
)
else:
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
......@@ -694,6 +776,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
)
self.disable_expert_map = (num_dispatchers > 1
......@@ -746,6 +832,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
......@@ -773,9 +860,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
use_fused_gate=use_fused_gate,
)
# cutlass path
......@@ -821,6 +908,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
......@@ -1015,6 +1106,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
......@@ -1046,9 +1138,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
indices_type=self.topk_indices_dtype,
use_fused_gate=use_fused_gate)
return fused_experts(
hidden_states=x,
......@@ -1325,6 +1417,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
......@@ -1354,6 +1447,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
......@@ -1557,6 +1651,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
......@@ -1583,6 +1678,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
......
......@@ -3,6 +3,7 @@
from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8
from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
CompressedTensorsW4A16Sparse24)
......@@ -21,5 +22,6 @@ __all__ = [
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
"WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
"CompressedTensors24", "CompressedTensorsW4A16Fp4",
"CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int"
"CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int",
"CompressedTensorsW4A8Fp8"
]
......@@ -12,6 +12,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
run_nvfp4_emulations)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
swizzle_blockscale)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
......@@ -83,29 +85,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
weight_loader=weight_loader)
layer.register_parameter("input_global_scale", input_global_scale)
def swizzle_blockscale(self, scale: torch.tensor):
assert (scale.dtype == torch.float8_e4m3fn)
# Pad and blockwise interleave weight_scale
scale_ndim = scale.ndim
if scale.ndim == 2:
scale = scale.unsqueeze(0)
assert scale.ndim == 3
B, M, K = scale.shape
round_up_multiple = lambda x, m: (x + m - 1) // m * m
M_padded = round_up_multiple(M, 128)
K_padded = round_up_multiple(K, 4)
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
padded_scale[:B, :M, :K] = scale
batches, rows, cols = padded_scale.shape
assert rows % 128 == 0
assert cols % 4 == 0
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
cols // 4, 4)
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
swizzled_scale = swizzled_scale.contiguous().cuda()
return (swizzled_scale.reshape(M, K)
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
def process_weights_after_loading(self, layer) -> None:
global_input_scale = layer.input_global_scale.max().to(torch.float32)
......@@ -133,13 +112,12 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
torch.uint8), epilogue_tile_m).reshape(
weight_scale.shape).view(torch.float8_e4m3fn))
layer.weight_scale_swizzled = Parameter(weight_scale,
requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.weight_packed = Parameter(weight, requires_grad=False)
else:
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
requires_grad=False)
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
layer.weight_scale = Parameter(swizzled_weight_scale,
requires_grad=False)
layer.weight_packed = Parameter(layer.weight_packed.data,
requires_grad=False)
......@@ -157,7 +135,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
x=x,
input_global_scale=layer.input_global_scale,
weight=layer.weight_packed,
weight_scale_swizzled=layer.weight_scale_swizzled,
weight_scale_swizzled=layer.weight_scale,
weight_global_scale=layer.weight_global_scale)
if bias is not None:
out = out + bias
......@@ -170,7 +148,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
mm_args = (x_fp4, layer.weight_packed, x_blockscale,
layer.weight_scale_swizzled, layer.alpha, output_dtype)
layer.weight_scale, layer.alpha, output_dtype)
if self.backend == "flashinfer-trtllm":
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
elif self.backend == "flashinfer-cutlass":
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import ActivationOrdering
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
# yapf: enable
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
__all__ = ["CompressedTensorsW4A8Fp8"]
W4A8_SUPPORTED_TYPES_MAP = {
4: scalar_types.int4,
}
W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
_kernel_backends_being_used: set[str] = set()
def __init__(self,
strategy: str,
num_bits: int,
group_size: Optional[int] = None,
symmetric: Optional[bool] = True,
actorder: Optional[ActivationOrdering] = None):
self.pack_factor = 32 // num_bits
self.strategy = strategy
self.symmetric = symmetric
self.group_size = -1 if group_size is None else group_size
self.has_g_idx = actorder == ActivationOrdering.GROUP
if self.group_size != 128 or self.strategy != "group":
raise ValueError("W4A8 kernels require group quantization " \
"with group size 128")
if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
raise ValueError(
f"Unsupported num_bits = {num_bits}. "
f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}")
self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]
@classmethod
def get_min_capability(cls) -> int:
# hopper
return 90
def create_weights(self, layer: torch.nn.Module, output_size: int,
input_size: int, output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_type,
act_type=torch.float8_e4m3fn, # always use fp8(e4m3)
group_size=self.group_size,
zero_points=not self.symmetric,
has_g_idx=self.has_g_idx,
out_type=params_dtype
)
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsW4A8Fp8",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# If group_size is -1, we are in channelwise case.
group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = (input_size != input_size_per_partition)
partition_scales = not marlin_repeat_scales_on_all_ranks(
self.has_g_idx, self.group_size, row_parallel)
scales_and_zp_size = input_size // group_size
if partition_scales:
assert input_size_per_partition % group_size == 0
scales_and_zp_size = input_size_per_partition // group_size
weight = PackedvLLMParameter(input_dim=1,
output_dim=0,
weight_loader=weight_loader,
packed_factor=self.pack_factor,
packed_dim=1,
data=torch.empty(
output_size_per_partition,
input_size_per_partition //
self.pack_factor,
dtype=torch.int32,
))
# TODO(czhu): allocate the packed fp8 scales memory here?
# the scales will be expanded by 8x via `cutlass_pack_scale_fp8`
weight_scale_args = {
"weight_loader":
weight_loader,
"data":
torch.empty(
output_size_per_partition,
scales_and_zp_size,
dtype=torch.float8_e4m3fn,
)
}
if not partition_scales:
weight_scale = ChannelQuantScaleParameter(output_dim=0,
**weight_scale_args)
else:
weight_scale = GroupQuantScaleParameter(output_dim=0,
input_dim=1,
**weight_scale_args)
# A 2D array defining the original shape of the weights
# before packing
weight_shape = BasevLLMParameter(data=torch.empty(2,
dtype=torch.int64),
weight_loader=weight_loader)
# per-channel scales
weight_chan_scale = ChannelQuantScaleParameter(
data=torch.empty((output_size_per_partition, 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight_packed", weight)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape)
layer.register_parameter("weight_chan_scale", weight_chan_scale)
self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="weight_packed",
w_s_param_name="weight_scale",
w_zp_param_name="weight_zero_point",
w_gidx_param_name="weight_g_idx")
# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Generator
from itertools import accumulate
from typing import Callable, Optional
import torch
from compressed_tensors.transform import (TransformArgs, TransformConfig,
TransformLocation, TransformScheme)
from compressed_tensors.utils import is_match
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
LinearMethodBase,
QKVCrossParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import ( # noqa: E501
HadamardTransform)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
TransformTuple)
class CompressedTensorsLinearTransformMethod(LinearMethodBase):
"""
Wraps `CompressedTensorsLinearMethod` or `UnquantizedLinearMethod` and adds
input and output transforms to either side of the original apply method
"""
@classmethod
def from_schemes(
cls, quant_method: LinearMethodBase, input_tfms: dict[int,
TransformTuple],
output_tfms: dict[int, TransformTuple]
) -> "CompressedTensorsLinearTransformMethod":
assert input_tfms or output_tfms
# TODO (@ksayers): implement QutlassLinearMethodNvFP4
# hadacore and fwht can be selected by Transform module
return cls(quant_method, input_tfms, output_tfms)
def __init__(self, quant_method: LinearMethodBase,
input_tfms: dict[int, TransformTuple],
output_tfms: dict[int, TransformTuple]):
self.quant_method = quant_method
self.input_tfms = input_tfms
self.output_tfms = output_tfms
self.input_transform: Optional[HadamardTransform] = None
self.output_transform: Optional[HadamardTransform] = None
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
# get weight loader for transforms
weight_loader: Callable = extra_weight_attrs.get(
"weight_loader") # type: ignore[assignment]
# HACK: UnquantizedLinearMethod does not support weight loader v2, but
# transforms (specifically SharedWeightParameter) requires
# weight loader v2. Until UnquantizedLinearMethod supports v2, we must
# hack around this by getting weight loader v1 so ULM can load correctly
quant_method_name = self.quant_method.__class__.__name__
if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED:
if isinstance(layer, QKVCrossParallelLinear):
weight_loader_v1 = layer.weight_loader_v1
else:
weight_loader_v1 = layer.weight_loader
extra_weight_attrs["weight_loader"] = weight_loader_v1
self.quant_method.create_weights(
layer=layer,
input_size_per_partition=input_size_per_partition,
output_partition_sizes=output_partition_sizes,
input_size=input_size,
output_size=output_size,
params_dtype=params_dtype,
**extra_weight_attrs)
# validate schemes
num_partitions = len(output_partition_sizes)
self._validate_tfm_schemes(num_partitions)
# create submodules for weight loading
if len(self.input_tfms) > 0:
scheme_name = list(self.input_tfms.values())[0].scheme_name
location = list(self.input_tfms.values())[0].args.location
transform_name = f"{scheme_name}_{location}"
transform = HadamardTransform(self.input_tfms, layer,
weight_loader,
input_size_per_partition,
output_partition_sizes)
layer.register_module(transform_name, transform)
self.input_transform = transform
if len(self.output_tfms) > 0:
scheme_name = list(self.output_tfms.values())[0].scheme_name
location = list(self.output_tfms.values())[0].args.location
transform_name = f"{scheme_name}_{location}"
transform = HadamardTransform(self.output_tfms, layer,
weight_loader,
input_size_per_partition,
output_partition_sizes)
layer.register_module(transform_name, transform)
self.output_transform = transform
# compute partition ranges for slicing activations
starts = [0] + list(accumulate(output_partition_sizes))[:-1]
self.partition_ranges = list(zip(starts, output_partition_sizes))
def process_weights_after_loading(self, layer):
self.quant_method.process_weights_after_loading(layer)
for submodule in layer.children():
if isinstance(submodule, HadamardTransform):
submodule.process_weights_after_loading()
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.input_transform is not None:
x = self.input_transform(x)
assert bias is None
x = self.quant_method.apply(layer, x, bias)
# TODO (@ksayers): Write a triton kernel to do this in parallel
if self.output_transform is not None:
for part_id, (start, length) in enumerate(self.partition_ranges):
x[:, start:start + length] = self.output_transform(
x[:, start:start + length], part_id=part_id)
return x
def _validate_tfm_schemes(self, num_partitions: int):
if len(self.input_tfms) > 0:
if 0 not in self.input_tfms:
raise ValueError("Must have same input")
for part_index in range(num_partitions):
if self.input_tfms[part_index] != self.input_tfms[0]:
raise ValueError("Must have same input")
if len(self.output_tfms) > 0:
scheme_name = list(self.output_tfms.values())[0].scheme_name
location = list(self.output_tfms.values())[0].args.location
for tfm in self.output_tfms.values():
if tfm.scheme_name != scheme_name:
raise ValueError("Must have same scheme name")
if tfm.args.location != location:
raise ValueError("Must have same location")
return self.input_tfms, self.output_tfms
def get_linear_transform_schemes(
layer: torch.nn.Module, layer_name: str,
transform_config: Optional[TransformConfig],
packed_modules_mapping: dict[str, list[str]]
) -> tuple[dict[int, TransformTuple], dict[
int, TransformTuple]]: # [input_transform, [output_transform, ...]]
# there can only be one transform input scheme per (fused) module
input_tfms = {}
output_tfms = {}
partition_names = get_layer_partition_names(layer_name,
packed_modules_mapping)
for scheme_name, scheme, args in get_schemes_args(transform_config):
for part_index, part_name in enumerate(partition_names):
if is_match(part_name, layer, args.targets,
args.ignore) and args.is_online():
if args.location == TransformLocation.INPUT:
input_tfms[part_index] = TransformTuple(
scheme_name, scheme, args)
elif args.location == TransformLocation.OUTPUT:
output_tfms[part_index] = TransformTuple(
scheme_name, scheme, args)
else:
raise ValueError(f"Cannot apply `{args.location}` "
f"transform to `{layer_name}`")
return (input_tfms, output_tfms)
def get_schemes_args(
transform_config: Optional[TransformConfig]
) -> Generator[tuple[str, TransformScheme, TransformArgs]]:
if transform_config is None:
return
for scheme_name, scheme in transform_config.config_groups.items():
for args in scheme.apply:
yield (scheme_name, scheme, args)
def get_layer_partition_names(
layer_name: str, packed_modules_mapping: dict[str,
list[str]]) -> list[str]:
"""
Get all partition names associated with this layer.
Names are returned in order of their partition indices.
```python
mapping = {"gate_up_proj", "gate_proj", "up_proj"}
assert get_layer_partition_names(
"mlp.gate_up_proj", mapping) == ["gate_proj", "up_proj"]
assert get_layer_partition_names(
"mlp.down_proj", mapping) == ["down_proj"]
"""
for fused_suffix, part_suffixes in packed_modules_mapping.items():
if layer_name.endswith(fused_suffix):
return [
layer_name.removesuffix(fused_suffix) + part_suffix
for part_suffix in part_suffixes
]
return [layer_name]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Hashable
from typing import Callable, Optional
import torch
from compressed_tensors.transform import TransformLocation, TransformScheme
from torch import Tensor
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
TransformTuple)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parameter import SharedWeightParameter
class HadamardTransform(torch.nn.Module):
"""
Class which handles weight loading, postprocessing, and application of
transforms. Meant to be used with `CompressedTensorsLinearTransformMethod`
and attention transforms method (not implemented yet)
"""
transforms: dict[int, TransformTuple] # info parsed from transforms config
weight: SharedWeightParameter # container for shared tensors
kernel: Callable # function used during application
scales: dict[int, float] # hadamard scale, usually sqrt(matrix.size(0))
def __init__(self,
transforms: dict[int, TransformTuple],
layer: torch.nn.Module,
weight_loader: Callable,
input_size_per_partition: int,
output_partition_sizes: list[int],
kernel: Optional[Callable] = None):
super().__init__()
self.transforms = transforms
self.scales = {}
if get_tensor_model_parallel_world_size() > 1:
raise NotImplementedError("Online transforms with tensor "
"parallelism is not supported")
# Similar to row/col parallel params, but tensors are separate
# to allow for loading with shared memory
self.weight = SharedWeightParameter(weight_loader=weight_loader)
# create shared partition data for each partition of the original weight
input_size = input_size_per_partition
for part_index, (_scheme_name, scheme,
args) in self.transforms.items():
output_size = output_partition_sizes[part_index]
weight_size = self._get_weight_size(layer, args.location,
input_size, output_size)
data_key = self._get_data_key(scheme, weight_size)
self.weight.add_partition(
part_index,
data_key,
size=(weight_size, weight_size),
dtype=scheme.precision,
)
# validate that shared tensors and schemes are correct
self._validate_input_transforms()
# select kernel based on transform schemes
self.kernel = self._infer_kernel() if kernel is None else kernel
def process_weights_after_loading(self):
for part_id in self.weight.partitions:
data = self.weight.partitions[part_id].data
# required by torch.compile
self.weight.process_weights_after_loading()
# precompute scale as a runtime multiply, not division
# do not fold into weight in order to utilize FWHT
self.scales[part_id] = 1 / math.sqrt(data.size(0))
# FUTURE: avoid runtime tranpose by processing weights
# prior to apply
def forward(self, value: Tensor, part_id: int = 0) -> Tensor:
if part_id not in self.weight.partitions:
return value
weight = self.weight.partitions[part_id]
weight = weight if self.transforms[
part_id].args.inverse else weight.T # linear := x(W.T)
scale = self.scales[part_id]
return self.kernel(self, value.to(weight.dtype), weight, None).to(
value.dtype) * scale
def _get_data_key(self, scheme: TransformScheme,
weight_size: int) -> Hashable:
return (id(scheme), weight_size)
def _get_weight_size(self, layer: torch.nn.Module,
location: TransformLocation, input_size: int,
output_size: int) -> int:
if isinstance(layer, LinearBase):
if location == TransformLocation.INPUT:
return input_size
elif location == TransformLocation.OUTPUT:
return output_size
elif isinstance(layer, VocabParallelEmbedding):
if location == TransformLocation.INPUT:
return output_size
elif location == TransformLocation.OUTPUT:
return input_size
raise ValueError()
def _validate_input_transforms(self):
assert len(self.transforms) > 0
location = list(self.transforms.values())[0].args.location
if location == TransformLocation.INPUT:
first_data = self.weight.partitions[0].data
for partition in self.weight.partitions.values():
if partition.data.data_ptr() != first_data.data_ptr():
raise ValueError("")
def _infer_kernel(self) -> Callable:
# TODO (@ksayers): use fwht, hadacore
return dispatch_unquantized_gemm()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
CompressedTensorsLinearTransformMethod)
# Because qutlass fuses hadamard with quantization, it cannot automatically be
# composed with kernels in the way CompressedTensorsLinearTransformMethod does.
# Therefore, a separate scheme must be created for each quantized dtype
class QutlassLinearMethodNvFP4(CompressedTensorsLinearTransformMethod):
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# fused hadamard quant linear method
raise NotImplementedError()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from typing import NamedTuple
from compressed_tensors.transform import TransformArgs, TransformScheme
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
Since this module is V0 only, set VLLM_USE_V1=0 for
all tests in the module.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
\ No newline at end of file
__all__ = ["TransformTuple"]
class TransformTuple(NamedTuple):
scheme_name: str
scheme: TransformScheme
args: TransformArgs
......@@ -120,6 +120,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
......@@ -146,6 +147,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment