"vscode:/vscode.git/clone" did not exist on "7d6b03381e236e592eb814ab10b6ccfadaeed610"
Unverified commit 9152a30d, authored by Robert Shaw and committed by GitHub
Browse files

[MoE Refactor][12/N] Marlin Fp8 MoE Pure Function (#31499)


Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
parent c2ff33cc
...@@ -964,12 +964,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -964,12 +964,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
elif self.use_marlin: elif self.use_marlin:
prepare_moe_fp8_layer_for_marlin( (
layer, False, input_dtype=self.marlin_input_dtype workspace,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
) = prepare_moe_fp8_layer_for_marlin(
layer,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
input_dtype=self.marlin_input_dtype,
) )
# Activations not quantized for marlin. layer.workspace = workspace
del layer.w13_input_scale replace_parameter(layer, "w13_weight", w13_weight)
del layer.w2_input_scale replace_parameter(layer, "w2_weight", w2_weight)
replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
if self.use_cutlass: if self.use_cutlass:
assert self.weight_quant.strategy != QuantizationStrategy.BLOCK assert self.weight_quant.strategy != QuantizationStrategy.BLOCK
......
...@@ -912,6 +912,23 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -912,6 +912,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights( w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
w13_weight, w2_weight w13_weight, w2_weight
) )
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
(
workspace,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
) = prepare_moe_fp8_layer_for_marlin(
layer,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
input_dtype=self.marlin_input_dtype,
)
layer.workspace = workspace
elif self.fp8_backend in [ elif self.fp8_backend in [
Fp8MoeBackend.FLASHINFER_CUTLASS, Fp8MoeBackend.FLASHINFER_CUTLASS,
Fp8MoeBackend.FLASHINFER_TRTLLM, Fp8MoeBackend.FLASHINFER_TRTLLM,
...@@ -937,17 +954,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -937,17 +954,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale) replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale)
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale) replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale)
# TODO(rob): we do this after replace_parameter() because
# prepare_moe_fp8_layer_for_marlin uses on the layer's params
# directly. We will refactor this in a follow up PR.
if self.fp8_backend == Fp8MoeBackend.MARLIN:
prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
def _setup_kernel(self, layer: Module) -> None: def _setup_kernel(self, layer: Module) -> None:
"""Setup Modular Kernel for TP Case""" """Setup Modular Kernel for TP Case"""
# NOTE(rob): this is a WIP refactor. We are first migrating # NOTE(rob): this is a WIP refactor. We are first migrating
...@@ -1194,20 +1200,14 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1194,20 +1200,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig | None:
if self.fp8_backend == Fp8MoeBackend.MARLIN: if self.fp8_backend == Fp8MoeBackend.MARLIN:
return fp8_w8a16_moe_quant_config( return fp8_w8a16_moe_quant_config(
w1_scale=layer.w13_weight_scale, w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
w2_scale=layer.w2_weight_scale, w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
block_shape=self.weight_block_size, block_shape=self.weight_block_size,
) )
return fp8_w8a8_moe_quant_config( return fp8_w8a8_moe_quant_config(
w1_scale=( w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
layer.w13_weight_scale_inv w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
),
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
block_shape=self.weight_block_size, block_shape=self.weight_block_size,
......
...@@ -315,10 +315,26 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -315,10 +315,26 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
elif self.use_marlin: elif self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False) (workspace, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale) = (
# Activations not quantized for marlin. prepare_moe_fp8_layer_for_marlin(
del layer.w13_input_scale layer,
del layer.w2_input_scale layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
)
)
layer.workspace = workspace
# TODO(rob): once we apply refactor to Quark, switch to using
# replace_parameter for compatibility with reloading in RL.
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(
w13_weight_scale, requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
)
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
......
...@@ -199,9 +199,18 @@ def prepare_fp8_layer_for_marlin( ...@@ -199,9 +199,18 @@ def prepare_fp8_layer_for_marlin(
def prepare_moe_fp8_layer_for_marlin( def prepare_moe_fp8_layer_for_marlin(
layer: torch.nn.Module, layer: torch.nn.Module,
size_k_first: bool = True, w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
input_dtype: torch.dtype | None = None, input_dtype: torch.dtype | None = None,
) -> None: ) -> tuple[
torch.Tensor, # workspace
torch.Tensor, # w13_weight
torch.Tensor, # w2_weight
torch.Tensor, # w13_weight_scale
torch.Tensor, # w2_weight_scale
]:
logger.warning_once( logger.warning_once(
"Your GPU does not have native support for FP8 computation but " "Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will " "FP8 quantization is being used. Weight-only FP8 compression will "
...@@ -209,7 +218,7 @@ def prepare_moe_fp8_layer_for_marlin( ...@@ -209,7 +218,7 @@ def prepare_moe_fp8_layer_for_marlin(
"performance for compute-heavy workloads." "performance for compute-heavy workloads."
) )
if input_dtype is not None and input_dtype.itemsize == 1: if input_dtype is not None and input_dtype.itemsize == 1:
raise RuntimeError("Marlin W8A8 is not supported.") raise NotImplementedError("Marlin W8A8 is not supported.")
e = layer.num_experts e = layer.num_experts
k = layer.hidden_size k = layer.hidden_size
...@@ -218,27 +227,22 @@ def prepare_moe_fp8_layer_for_marlin( ...@@ -218,27 +227,22 @@ def prepare_moe_fp8_layer_for_marlin(
# WORKSPACE # WORKSPACE
device = layer.w13_weight.device device = layer.w13_weight.device
layer.workspace = marlin_make_workspace_new(device, 4) workspace = marlin_make_workspace_new(device, 4)
perm = torch.empty(0, dtype=torch.int, device=device) perm = torch.empty(0, dtype=torch.int, device=device)
# WEIGHT # WEIGHT
# Repack weights to marlin format # Repack weights to marlin format
for name in ["w13_weight", "w2_weight"]: def repack_weight(name: str, weight: torch.Tensor) -> torch.Tensor:
weight = getattr(layer, name)
tensor_list = [] tensor_list = []
if "w13" in name: if "w13" in name:
size_n, size_k = n * 2, k size_n, size_k = n * 2, k
else: else:
size_n, size_k = k, n size_n, size_k = k, n
if size_k_first:
assert weight.shape == (e, size_k, size_n)
else:
assert weight.shape == (e, size_n, size_k) assert weight.shape == (e, size_n, size_k)
for i in range(e): for i in range(e):
qweight = pack_fp8_to_int32(weight[i], size_k_first) qweight = pack_fp8_to_int32(weight[i], size_k_first=False)
if not size_k_first:
qweight = qweight.T.contiguous() qweight = qweight.T.contiguous()
marlin_qweight = ops.gptq_marlin_repack( marlin_qweight = ops.gptq_marlin_repack(
...@@ -246,25 +250,17 @@ def prepare_moe_fp8_layer_for_marlin( ...@@ -246,25 +250,17 @@ def prepare_moe_fp8_layer_for_marlin(
) )
tensor_list.append(marlin_qweight) tensor_list.append(marlin_qweight)
weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) return torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
weight = torch.nn.Parameter(weight, requires_grad=False)
setattr(layer, name, weight) w13_weight = repack_weight("w13", w13_weight)
w2_weight = repack_weight("w2", w2_weight)
# WEIGHT SCALES # WEIGHT SCALES
# Permute scales # Permute scales
group_size = -1 if weight_block_size is None else weight_block_size[1] group_size = -1 if weight_block_size is None else weight_block_size[1]
for name in ["w13", "w2"]: def permute_scales(scales: torch.Tensor, name: str) -> torch.Tensor:
if name + "_weight_scale" in dir(layer): scales = scales.to(layer.orig_dtype)
new_name = name + "_weight_scale"
scales = getattr(layer, new_name).to(layer.orig_dtype)
delattr(layer, new_name)
elif name + "_weight_scale_inv" in dir(layer):
new_name = name + "_weight_scale_inv"
scales = getattr(layer, new_name).to(layer.orig_dtype)
delattr(layer, new_name)
tensor_list = [] tensor_list = []
if "w13" in name: if "w13" in name:
size_n, size_k = n * 2, k size_n, size_k = n * 2, k
...@@ -294,7 +290,6 @@ def prepare_moe_fp8_layer_for_marlin( ...@@ -294,7 +290,6 @@ def prepare_moe_fp8_layer_for_marlin(
# block-wise quantization -> group-wise quantization # block-wise quantization -> group-wise quantization
# (e, size_k // block_size[1], ceil(size_n / block_size[0])) # (e, size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (e, size_k // block_size[1], size_n) # =>(repeat)=> (e, size_k // block_size[1], size_n)
if not size_k_first:
scales = scales.permute(0, 2, 1) scales = scales.permute(0, 2, 1)
block_n = weight_block_size[0] block_n = weight_block_size[0]
scales = scales.repeat_interleave(block_n, 2) scales = scales.repeat_interleave(block_n, 2)
...@@ -310,26 +305,18 @@ def prepare_moe_fp8_layer_for_marlin( ...@@ -310,26 +305,18 @@ def prepare_moe_fp8_layer_for_marlin(
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
if input_dtype != torch.float8_e4m3fn: if input_dtype != torch.float8_e4m3fn:
scales = fp8_fused_exponent_bias_into_scales(scales) scales = fp8_fused_exponent_bias_into_scales(scales)
scales = torch.nn.Parameter(scales, requires_grad=False) return scales
setattr(layer, name + "_weight_scale", scales)
# BIAS w13_weight_scale = permute_scales(w13_weight_scale, "w13")
# Permute bias w2_weight_scale = permute_scales(w2_weight_scale, "w2")
for name in ["w13_bias", "w2_bias"]:
if not hasattr(layer, name):
continue
bias = getattr(layer, name).to(layer.orig_dtype)
tensor_list = [] return (
for i in range(e): workspace,
expert_bias = bias[i] w13_weight,
w2_weight,
tensor_list.append(marlin_permute_bias(expert_bias)) w13_weight_scale,
w2_weight_scale,
bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) )
bias = torch.nn.Parameter(bias, requires_grad=False)
setattr(layer, name, bias)
def pack_fp8_to_int32( def pack_fp8_to_int32(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment