Unverified Commit eb5cb5e5 authored by Dipika Sikka's avatar Dipika Sikka Committed by GitHub
Browse files

[BugFix] Fix parameter names and `process_after_weight_loading` for W4A16 MoE...


[BugFix] Fix parameter names and `process_after_weight_loading` for W4A16 MoE Group Act Order  (#11528)
Signed-off-by: default avatarElizaWszola <eliza@neuralmagic.com>
Co-authored-by: default avatarElizaWszola <eliza@neuralmagic.com>
Co-authored-by: default avatarMichael Goin <michael@neuralmagic.com>
parent 2cbeedad
...@@ -38,7 +38,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -38,7 +38,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@abstractmethod @abstractmethod
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError raise NotImplementedError
...@@ -65,22 +65,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -65,22 +65,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization.""" """MoE method without quantization."""
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
# Fused gate_up_proj (column parallel) # Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(num_experts, w13_weight = torch.nn.Parameter(torch.empty(
2 * intermediate_size, num_experts,
hidden_size, 2 * intermediate_size_per_partition,
dtype=params_dtype), hidden_size,
dtype=params_dtype),
requires_grad=False) requires_grad=False)
layer.register_parameter("w13_weight", w13_weight) layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs) set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel) # down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(num_experts, w2_weight = torch.nn.Parameter(torch.empty(
hidden_size, num_experts,
intermediate_size, hidden_size,
dtype=params_dtype), intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False) requires_grad=False)
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
...@@ -289,13 +291,20 @@ class FusedMoE(torch.nn.Module): ...@@ -289,13 +291,20 @@ class FusedMoE(torch.nn.Module):
self.quant_method = quant_config.get_quant_method(self, prefix) self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None assert self.quant_method is not None
self.quant_method.create_weights( moe_quant_params = {
layer=self, "num_experts": num_experts,
num_experts=num_experts, "hidden_size": hidden_size,
hidden_size=hidden_size, "intermediate_size_per_partition":
intermediate_size=self.intermediate_size_per_partition, self.intermediate_size_per_partition,
params_dtype=params_dtype, "params_dtype": params_dtype,
weight_loader=self.weight_loader) "weight_loader": self.weight_loader,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__ ==
"CompressedTensorsWNA16MoEMethod"):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
def _load_per_tensor_weight_scale(self, shard_id: str, def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter, param: torch.nn.Parameter,
...@@ -312,19 +321,30 @@ class FusedMoE(torch.nn.Module): ...@@ -312,19 +321,30 @@ class FusedMoE(torch.nn.Module):
elif shard_id == "w2": elif shard_id == "w2":
param_data[expert_id] = loaded_weight param_data[expert_id] = loaded_weight
def _load_model_weight_or_group_weight_scale(self, shard_dim: int, def _load_model_weight_or_group_weight_scale(self,
shard_dim: int,
expert_data: torch.Tensor, expert_data: torch.Tensor,
shard_id: str, shard_id: str,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
tp_rank: int): tp_rank: int,
# Load grouped weight scales for group quantization load_full_w2: bool = False):
# or model weights """
Load grouped weight scales for group quantization or model weights
:param shard_dim: dimension to shard
:param expert_data: parameter for a particular expert
:param shard_id: either w1, w2, or w3
:param loaded_weight: checkpoint weight to load into the param
:param tp_rank: tensor parallel rank
:param load_full_w2: whether or not the w2 loaded should be sharded.
"""
if shard_id == "w2": if shard_id == "w2":
self._load_w2(shard_id=shard_id, # In the case where we have actorder/g_idx, we do not partition the
shard_dim=shard_dim, # w2 scales, as indicated by `load_full` argument, for all tp cases
self._load_w2(shard_dim=shard_dim,
loaded_weight=loaded_weight, loaded_weight=loaded_weight,
expert_data=expert_data, expert_data=expert_data,
tp_rank=tp_rank) tp_rank=tp_rank,
load_full=load_full_w2)
elif shard_id in ("w1", "w3"): elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id, self._load_w13(shard_id=shard_id,
shard_dim=shard_dim, shard_dim=shard_dim,
...@@ -364,15 +384,21 @@ class FusedMoE(torch.nn.Module): ...@@ -364,15 +384,21 @@ class FusedMoE(torch.nn.Module):
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight) expert_data.copy_(loaded_weight)
def _load_w2(self, expert_data: torch.Tensor, shard_dim: int, def _load_w2(self,
shard_id: str, loaded_weight: torch.Tensor, tp_rank: int): expert_data: torch.Tensor,
shard_dim: int,
loaded_weight: torch.Tensor,
tp_rank: int,
load_full: bool = False):
# Index the loaded weight for tp sharding. # Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim # down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load. # Narrow parameter and load.
shard_size = expert_data.shape[shard_dim] shard_size = expert_data.shape[shard_dim]
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, if not load_full:
shard_size) loaded_weight = loaded_weight.narrow(shard_dim,
shard_size * tp_rank,
shard_size)
# w2, down_proj: Load into only logical weight of w2. # w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight) expert_data.copy_(loaded_weight)
...@@ -387,8 +413,7 @@ class FusedMoE(torch.nn.Module): ...@@ -387,8 +413,7 @@ class FusedMoE(torch.nn.Module):
shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int): shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
if shard_id == "w2": if shard_id == "w2":
self._load_w2(shard_id=shard_id, self._load_w2(shard_dim=shard_dim,
shard_dim=shard_dim,
loaded_weight=loaded_weight, loaded_weight=loaded_weight,
expert_data=expert_data, expert_data=expert_data,
tp_rank=tp_rank) tp_rank=tp_rank)
...@@ -416,7 +441,7 @@ class FusedMoE(torch.nn.Module): ...@@ -416,7 +441,7 @@ class FusedMoE(torch.nn.Module):
] ]
# Fetch the dim to shard the parameter/loaded weight # Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This will be whatever # based on the shard id. This will be whatever
# dimension intermediate_size is used. # dimension intermediate_size_per_partition is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
expert_data = param.data[expert_id] expert_data = param.data[expert_id]
...@@ -424,11 +449,11 @@ class FusedMoE(torch.nn.Module): ...@@ -424,11 +449,11 @@ class FusedMoE(torch.nn.Module):
# is_transposed: if the dim to shard the weight # is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors # should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size is # should be whatever dimension intermediate_size_per_partition is
is_transposed = getattr(param, "is_transposed", False) is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed: if is_transposed:
shard_dim = ~shard_dim shard_dim = int(not shard_dim)
# Case input scale: input_scale loading is only supported for fp8 # Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name: if "input_scale" in weight_name:
...@@ -480,7 +505,8 @@ class FusedMoE(torch.nn.Module): ...@@ -480,7 +505,8 @@ class FusedMoE(torch.nn.Module):
shard_dim=shard_dim, shard_dim=shard_dim,
loaded_weight=loaded_weight, loaded_weight=loaded_weight,
expert_data=expert_data, expert_data=expert_data,
tp_rank=tp_rank) tp_rank=tp_rank,
load_full_w2=getattr(param, "load_full_w2", False))
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value: elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
self._load_per_tensor_weight_scale(shard_id=shard_id, self._load_per_tensor_weight_scale(shard_id=shard_id,
param=param, param=param,
......
...@@ -303,7 +303,7 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -303,7 +303,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
extra_weight_attrs.update({ extra_weight_attrs.update({
"is_transposed": "is_transposed":
...@@ -312,17 +312,18 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -312,17 +312,18 @@ class AWQMoEMethod(FusedMoEMethodBase):
FusedMoeWeightScaleSupported.GROUP.value, FusedMoeWeightScaleSupported.GROUP.value,
}) })
w13_qweight = Parameter(torch.empty(num_experts, w13_qweight = Parameter(
hidden_size, torch.empty(num_experts,
2 * intermediate_size // hidden_size,
self.quant_config.pack_factor, 2 * intermediate_size_per_partition //
dtype=torch.int32), self.quant_config.pack_factor,
requires_grad=False) dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_qweight", w13_qweight) layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs) set_weight_attrs(w13_qweight, extra_weight_attrs)
w2_qweight = Parameter(torch.empty(num_experts, w2_qweight = Parameter(torch.empty(num_experts,
intermediate_size, intermediate_size_per_partition,
hidden_size // hidden_size //
self.quant_config.pack_factor, self.quant_config.pack_factor,
dtype=torch.int32), dtype=torch.int32),
...@@ -331,13 +332,14 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -331,13 +332,14 @@ class AWQMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_qweight, extra_weight_attrs) set_weight_attrs(w2_qweight, extra_weight_attrs)
num_groups_w13 = hidden_size // self.quant_config.group_size num_groups_w13 = hidden_size // self.quant_config.group_size
num_groups_w2 = intermediate_size // self.quant_config.group_size num_groups_w2 = (intermediate_size_per_partition //
self.quant_config.group_size)
# WEIGHT_SCALES # WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively. # Allocate 2 scales for w1 and w3 respectively.
w13_scales = Parameter(torch.empty(num_experts, w13_scales = Parameter(torch.empty(num_experts,
num_groups_w13, num_groups_w13,
intermediate_size * 2, intermediate_size_per_partition * 2,
dtype=params_dtype), dtype=params_dtype),
requires_grad=False) requires_grad=False)
layer.register_parameter("w13_scales", w13_scales) layer.register_parameter("w13_scales", w13_scales)
...@@ -353,12 +355,13 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -353,12 +355,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
# WEIGHT_ZERO_POINT # WEIGHT_ZERO_POINT
# Allocate 2 zero points for w1 and w3 respectively. # Allocate 2 zero points for w1 and w3 respectively.
w13_qzeros = Parameter(torch.empty(num_experts, w13_qzeros = Parameter(
num_groups_w13, torch.empty(num_experts,
2 * intermediate_size // num_groups_w13,
self.quant_config.pack_factor, 2 * intermediate_size_per_partition //
dtype=torch.int32), self.quant_config.pack_factor,
requires_grad=False) dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_qzeros", w13_qzeros) layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs) set_weight_attrs(w13_qzeros, extra_weight_attrs)
......
...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, ...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
WNA16_SUPPORTED_BITS) WNA16_SUPPORTED_BITS)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
...@@ -75,24 +76,26 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -75,24 +76,26 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self.static_input_scales = not self.input_quant.dynamic self.static_input_scales = not self.input_quant.dynamic
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
params_dtype = torch.float8_e4m3fn params_dtype = torch.float8_e4m3fn
# WEIGHTS # WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(num_experts, w13_weight = torch.nn.Parameter(torch.empty(
2 * intermediate_size, num_experts,
hidden_size, 2 * intermediate_size_per_partition,
dtype=params_dtype), hidden_size,
dtype=params_dtype),
requires_grad=False) requires_grad=False)
layer.register_parameter("w13_weight", w13_weight) layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs) set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(num_experts, w2_weight = torch.nn.Parameter(torch.empty(
hidden_size, num_experts,
intermediate_size, hidden_size,
dtype=params_dtype), intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False) requires_grad=False)
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
...@@ -254,6 +257,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -254,6 +257,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
self.packed_factor = 32 // config.num_bits self.packed_factor = 32 // config.num_bits
self.strategy = config.strategy self.strategy = config.strategy
self.group_size = config.group_size self.group_size = config.group_size
self.actorder = config.actorder
assert config.symmetric, ( assert config.symmetric, (
"Only symmetric quantization is supported for MoE") "Only symmetric quantization is supported for MoE")
...@@ -266,9 +270,16 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -266,9 +270,16 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
f"{WNA16_SUPPORTED_BITS}") f"{WNA16_SUPPORTED_BITS}")
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
assert params_dtype == torch.float16, (
"float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501
)
intermediate_size_full = extra_weight_attrs.pop(
"intermediate_size_full")
# Will transpose the loaded weight along the # Will transpose the loaded weight along the
# intermediate and hidden dim sizes. Will # intermediate and hidden dim sizes. Will
# shard for TP along the transposed dims # shard for TP along the transposed dims
...@@ -276,35 +287,45 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -276,35 +287,45 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
"is_transposed": True, "is_transposed": True,
"quant_method": self.strategy "quant_method": self.strategy
}) })
w13_weight = torch.nn.Parameter(torch.empty(num_experts, w13_weight = torch.nn.Parameter(torch.empty(
hidden_size // num_experts,
self.packed_factor, hidden_size // self.packed_factor,
2 * intermediate_size, 2 * intermediate_size_per_partition,
dtype=torch.int32), dtype=torch.int32),
requires_grad=False) requires_grad=False)
layer.register_parameter("w13_weight_packed", w13_weight) layer.register_parameter("w13_weight_packed", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs) set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(num_experts, w2_weight = torch.nn.Parameter(torch.empty(
intermediate_size // num_experts,
self.packed_factor, intermediate_size_per_partition // self.packed_factor,
hidden_size, hidden_size,
dtype=torch.int32), dtype=torch.int32),
requires_grad=False) requires_grad=False)
layer.register_parameter("w2_weight_packed", w2_weight) layer.register_parameter("w2_weight_packed", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
# In the case where we have actorder/g_idx,
# we do not partition the w2 scales
load_full_w2 = self.actorder and self.group_size != -1
w2_scales_size = (intermediate_size_full
if load_full_w2 else intermediate_size_per_partition)
self.is_k_full = (not self.actorder) or (
intermediate_size_per_partition == intermediate_size_full)
if self.strategy == "channel": if self.strategy == "channel":
num_groups_w2 = num_groups_w13 = 1 num_groups_w2 = num_groups_w13 = 1
self.group_size = -1 self.group_size = -1
else: else:
num_groups_w2 = intermediate_size // self.group_size num_groups_w2 = w2_scales_size // self.group_size
num_groups_w13 = hidden_size // self.group_size num_groups_w13 = hidden_size // self.group_size
w13_scale = torch.nn.Parameter(torch.ones(num_experts, w13_scale = torch.nn.Parameter(torch.ones(
num_groups_w13, num_experts,
2 * intermediate_size, num_groups_w13,
dtype=params_dtype), 2 * intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False) requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_scale) layer.register_parameter("w13_weight_scale", w13_scale)
set_weight_attrs(w13_scale, extra_weight_attrs) set_weight_attrs(w13_scale, extra_weight_attrs)
...@@ -316,6 +337,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -316,6 +337,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
requires_grad=False) requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_scale) layer.register_parameter("w2_weight_scale", w2_scale)
set_weight_attrs(w2_scale, extra_weight_attrs) set_weight_attrs(w2_scale, extra_weight_attrs)
set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2})
w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
requires_grad=False) requires_grad=False)
...@@ -335,18 +357,18 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -335,18 +357,18 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
), ),
requires_grad=False, requires_grad=False,
) )
layer.register_parameter("w13_g_idx", w13_g_idx) layer.register_parameter("w13_weight_g_idx", w13_g_idx)
set_weight_attrs(w13_g_idx, extra_weight_attrs) set_weight_attrs(w13_g_idx, extra_weight_attrs)
w2_g_idx = torch.nn.Parameter( w2_g_idx = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, num_experts,
intermediate_size, intermediate_size_per_partition,
dtype=torch.int32, dtype=torch.int32,
), ),
requires_grad=False, requires_grad=False,
) )
layer.register_parameter("w2_g_idx", w2_g_idx) layer.register_parameter("w2_weight_g_idx", w2_g_idx)
set_weight_attrs(w2_g_idx, extra_weight_attrs) set_weight_attrs(w2_g_idx, extra_weight_attrs)
w13_g_idx_sort_indices = torch.nn.Parameter( w13_g_idx_sort_indices = torch.nn.Parameter(
...@@ -364,7 +386,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -364,7 +386,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
w2_g_idx_sort_indices = torch.nn.Parameter( w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, num_experts,
intermediate_size, intermediate_size_per_partition,
dtype=torch.int32, dtype=torch.int32,
), ),
requires_grad=False, requires_grad=False,
...@@ -422,24 +444,55 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -422,24 +444,55 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
size_k2 = layer.w2_weight_packed.shape[2] size_k2 = layer.w2_weight_packed.shape[2]
size_k13 = layer.w13_weight_packed.shape[2] size_k13 = layer.w13_weight_packed.shape[2]
num_experts = layer.w13_g_idx.shape[0] num_experts = layer.w13_weight_g_idx.shape[0]
device = layer.w13_g_idx.device device = layer.w13_weight_g_idx.device
layer.w13_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device), # when running models with grouped act order,
requires_grad=False, # resort to g_idx values provided in checkpoint
) if self.actorder == "group":
layer.w2_g_idx = torch.nn.Parameter( w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx)
torch.empty((num_experts, 0), dtype=torch.int32, device=device), w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx)
requires_grad=False, w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx)
) w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx)
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device), for e in range(num_experts):
requires_grad=False, w13_g_idx_sort_indices[e] = torch.argsort(
) layer.w13_weight_g_idx[e]).to(torch.int32)
layer.w2_g_idx_sort_indices = torch.nn.Parameter( w2_g_idx_sort_indices[e] = torch.argsort(
torch.empty((num_experts, 0), dtype=torch.int32, device=device), layer.w2_weight_g_idx[e]).to(torch.int32)
requires_grad=False, w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][
) w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][
w2_g_idx_sort_indices[e]]
replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx)
replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx)
replace_parameter(layer, "w13_g_idx_sort_indices",
w13_g_idx_sort_indices)
replace_parameter(layer, "w2_g_idx_sort_indices",
w2_g_idx_sort_indices)
else:
layer.w13_weight_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w2_weight_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
marlin_w13_qweight = ops.gptq_marlin_moe_repack( marlin_w13_qweight = ops.gptq_marlin_moe_repack(
layer.w13_weight_packed, layer.w13_weight_packed,
...@@ -511,9 +564,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -511,9 +564,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
router_logits, router_logits,
topk_weights, topk_weights,
topk_ids, topk_ids,
g_idx1=layer.w13_g_idx, g_idx1=layer.w13_weight_g_idx,
g_idx2=layer.w2_g_idx, g_idx2=layer.w2_weight_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices, sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.num_bits, num_bits=self.num_bits,
) is_k_full=self.is_k_full)
...@@ -62,7 +62,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): ...@@ -62,7 +62,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
**kwargs): **kwargs):
assert params_dtype == torch.float16, ( assert params_dtype == torch.float16, (
"float16 is required for marlin24 compressd models. Set dtype=torch.float16" # noqa: E501 "float16 is required for marlin24 compressed models. Set dtype=torch.float16" # noqa: E501
) )
pack_factor = 32 // self.quant_type.size_bits pack_factor = 32 // self.quant_type.size_bits
......
...@@ -52,7 +52,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -52,7 +52,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
int8_dtype = torch.int8 int8_dtype = torch.int8
...@@ -64,26 +64,29 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -64,26 +64,29 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
extra_weight_attrs['weight_loader'] = wrapped_weight_loader extra_weight_attrs['weight_loader'] = wrapped_weight_loader
# Fused gate_up_proj (column parallel) # Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(num_experts, w13_weight = torch.nn.Parameter(torch.empty(
2 * intermediate_size, num_experts,
hidden_size, 2 * intermediate_size_per_partition,
dtype=int8_dtype), hidden_size,
dtype=int8_dtype),
requires_grad=False) requires_grad=False)
layer.register_parameter("w13_weight", w13_weight) layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs) set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel) # down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(num_experts, w2_weight = torch.nn.Parameter(torch.empty(
hidden_size, num_experts,
intermediate_size, hidden_size,
dtype=int8_dtype), intermediate_size_per_partition,
dtype=int8_dtype),
requires_grad=False) requires_grad=False)
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
w13_scale = torch.nn.Parameter(torch.zeros(num_experts, w13_scale = torch.nn.Parameter(torch.zeros(
2 * intermediate_size, num_experts,
dtype=torch.float32), 2 * intermediate_size_per_partition,
dtype=torch.float32),
requires_grad=False) requires_grad=False)
layer.register_parameter("w13_scale", w13_scale) layer.register_parameter("w13_scale", w13_scale)
......
...@@ -386,8 +386,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -386,8 +386,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.block_quant = self.quant_config.weight_block_size is not None self.block_quant = self.quant_config.weight_block_size is not None
def create_weights(self, layer: Module, num_experts: int, hidden_size: int, def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size: int, params_dtype: torch.dtype, intermediate_size_per_partition: int,
**extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
if self.quant_config.is_checkpoint_fp8_serialized: if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn params_dtype = torch.float8_e4m3fn
...@@ -402,30 +402,34 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -402,30 +402,34 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# scales, the output_size of the weights for both the gate and up # scales, the output_size of the weights for both the gate and up
# layers must be divisible by block_n. # layers must be divisible by block_n.
# Required by column parallel or enabling merged weights # Required by column parallel or enabling merged weights
if intermediate_size % block_n != 0: if intermediate_size_per_partition % block_n != 0:
raise ValueError( raise ValueError(
f"The output_size of gate's and up's weight = " f"The output_size of gate's and up's weight = "
f"{intermediate_size} is not divisible by " f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_n = {block_n}.") f"weight quantization block_n = {block_n}.")
if (tp_size > 1 and intermediate_size % block_k != 0): if (tp_size > 1
and intermediate_size_per_partition % block_k != 0):
# Required by row parallel # Required by row parallel
raise ValueError(f"The input_size of down's weight = " raise ValueError(
f"{intermediate_size} is not divisible by " f"The input_size of down's weight = "
f"weight quantization block_k = {block_k}.") f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}.")
# WEIGHTS # WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(num_experts, w13_weight = torch.nn.Parameter(torch.empty(
2 * intermediate_size, num_experts,
hidden_size, 2 * intermediate_size_per_partition,
dtype=params_dtype), hidden_size,
dtype=params_dtype),
requires_grad=False) requires_grad=False)
layer.register_parameter("w13_weight", w13_weight) layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs) set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(num_experts, w2_weight = torch.nn.Parameter(torch.empty(
hidden_size, num_experts,
intermediate_size, hidden_size,
dtype=params_dtype), intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False) requires_grad=False)
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
...@@ -446,7 +450,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -446,7 +450,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w13_weight_scale = torch.nn.Parameter( w13_weight_scale = torch.nn.Parameter(
torch.ones( torch.ones(
num_experts, num_experts,
2 * ((intermediate_size + block_n - 1) // block_n), 2 * ((intermediate_size_per_partition + block_n - 1) //
block_n),
(hidden_size + block_k - 1) // block_k, (hidden_size + block_k - 1) // block_k,
dtype=torch.float32, dtype=torch.float32,
), ),
...@@ -456,7 +461,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -456,7 +461,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
torch.ones( torch.ones(
num_experts, num_experts,
(hidden_size + block_n - 1) // block_n, (hidden_size + block_n - 1) // block_n,
(intermediate_size + block_k - 1) // block_k, (intermediate_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32, dtype=torch.float32,
), ),
requires_grad=False, requires_grad=False,
......
...@@ -317,7 +317,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -317,7 +317,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
num_experts: int, num_experts: int,
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
...@@ -326,7 +326,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -326,7 +326,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
# Supports only sym for now (no zp) # Supports only sym for now (no zp)
if self.quant_config.group_size != -1: if self.quant_config.group_size != -1:
scales_size13 = hidden_size // self.quant_config.group_size scales_size13 = hidden_size // self.quant_config.group_size
scales_size2 = intermediate_size // self.quant_config.group_size scales_size2 = (intermediate_size_per_partition //
self.quant_config.group_size)
strategy = FusedMoeWeightScaleSupported.GROUP.value strategy = FusedMoeWeightScaleSupported.GROUP.value
else: else:
scales_size13 = 1 scales_size13 = 1
...@@ -342,7 +343,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -342,7 +343,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
torch.empty( torch.empty(
num_experts, num_experts,
hidden_size // self.quant_config.pack_factor, hidden_size // self.quant_config.pack_factor,
2 * intermediate_size, 2 * intermediate_size_per_partition,
dtype=torch.int32, dtype=torch.int32,
), ),
requires_grad=False, requires_grad=False,
...@@ -353,7 +354,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -353,7 +354,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w2_qweight = torch.nn.Parameter( w2_qweight = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, num_experts,
intermediate_size // self.quant_config.pack_factor, intermediate_size_per_partition //
self.quant_config.pack_factor,
hidden_size, hidden_size,
dtype=torch.int32, dtype=torch.int32,
), ),
...@@ -365,7 +367,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -365,7 +367,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w13_scales = torch.nn.Parameter( w13_scales = torch.nn.Parameter(
torch.empty(num_experts, torch.empty(num_experts,
scales_size13, scales_size13,
2 * intermediate_size, 2 * intermediate_size_per_partition,
dtype=torch.half), dtype=torch.half),
requires_grad=False, requires_grad=False,
) )
...@@ -385,7 +387,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -385,7 +387,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w13_qzeros = torch.nn.Parameter( w13_qzeros = torch.nn.Parameter(
torch.empty(num_experts, torch.empty(num_experts,
scales_size13, scales_size13,
2 * intermediate_size // self.quant_config.pack_factor, 2 * intermediate_size_per_partition //
self.quant_config.pack_factor,
dtype=params_dtype), dtype=params_dtype),
requires_grad=False, requires_grad=False,
) )
...@@ -414,7 +417,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -414,7 +417,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w2_g_idx = torch.nn.Parameter( w2_g_idx = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, num_experts,
intermediate_size, intermediate_size_per_partition,
dtype=torch.int32, dtype=torch.int32,
), ),
requires_grad=False, requires_grad=False,
...@@ -435,7 +438,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -435,7 +438,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w2_g_idx_sort_indices = torch.nn.Parameter( w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, num_experts,
intermediate_size, intermediate_size_per_partition,
dtype=torch.int32, dtype=torch.int32,
), ),
requires_grad=False, requires_grad=False,
......
...@@ -60,24 +60,26 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -60,24 +60,26 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
self.static_input_scales = not self.input_quant.get("is_dynamic") self.static_input_scales = not self.input_quant.get("is_dynamic")
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
params_dtype = torch.float8_e4m3fn params_dtype = torch.float8_e4m3fn
# WEIGHTS # WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(num_experts, w13_weight = torch.nn.Parameter(torch.empty(
2 * intermediate_size, num_experts,
hidden_size, 2 * intermediate_size_per_partition,
dtype=params_dtype), hidden_size,
dtype=params_dtype),
requires_grad=False) requires_grad=False)
layer.register_parameter("w13_weight", w13_weight) layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs) set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(num_experts, w2_weight = torch.nn.Parameter(torch.empty(
hidden_size, num_experts,
intermediate_size, hidden_size,
dtype=params_dtype), intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False) requires_grad=False)
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment