Unverified Commit d52358c1 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Perf] Remove duplicated NVFP4 blockscales to save memory (#23379)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent 6ace2f72
......@@ -246,11 +246,11 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
return
# swizzle weight scales
layer.w13_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
layer.w13_weight_scale = torch.nn.Parameter(swizzle_blockscale(
layer.w13_weight_scale),
requires_grad=False)
layer.w2_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
layer.w2_weight_scale = torch.nn.Parameter(swizzle_blockscale(
layer.w2_weight_scale),
requires_grad=False)
......@@ -383,8 +383,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......@@ -406,8 +406,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
......@@ -427,8 +427,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
a=x,
w1_fp4=layer.w13_weight,
w2_fp4=layer.w2_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w2_blockscale=layer.w2_blockscale_swizzled,
w1_blockscale=layer.w13_weight_scale,
w2_blockscale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
......
......@@ -112,12 +112,11 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
torch.uint8), epilogue_tile_m).reshape(
weight_scale.shape).view(torch.float8_e4m3fn))
layer.weight_scale_swizzled = Parameter(weight_scale,
requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.weight_packed = Parameter(weight, requires_grad=False)
else:
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
layer.weight_scale = Parameter(swizzled_weight_scale,
requires_grad=False)
layer.weight_packed = Parameter(layer.weight_packed.data,
requires_grad=False)
......@@ -136,7 +135,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
x=x,
input_global_scale=layer.input_global_scale,
weight=layer.weight_packed,
weight_scale_swizzled=layer.weight_scale_swizzled,
weight_scale_swizzled=layer.weight_scale,
weight_global_scale=layer.weight_global_scale)
if bias is not None:
out = out + bias
......@@ -149,7 +148,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
mm_args = (x_fp4, layer.weight_packed, x_blockscale,
layer.weight_scale_swizzled, layer.alpha, output_dtype)
layer.weight_scale, layer.alpha, output_dtype)
if self.backend == "flashinfer-trtllm":
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
elif self.backend == "flashinfer-cutlass":
......
......@@ -907,12 +907,11 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
torch.uint8), epilogue_tile_m).reshape(
weight_scale.shape).view(torch.float8_e4m3fn))
layer.weight_scale_swizzled = Parameter(weight_scale,
requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.weight = Parameter(weight, requires_grad=False)
else:
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
layer.weight_scale = Parameter(swizzled_weight_scale,
requires_grad=False)
layer.weight = Parameter(layer.weight.data, requires_grad=False)
......@@ -920,7 +919,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
prepare_fp4_layer_for_marlin(layer)
del layer.alpha
del layer.input_scale
del layer.weight_scale_swizzled
def apply(
self,
......@@ -951,14 +949,14 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
assert (x_fp4.dtype == torch.uint8)
assert (layer.weight.dtype == torch.uint8)
assert (x_blockscale.dtype == torch.float8_e4m3fn)
assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
assert (layer.weight_scale.dtype == torch.float8_e4m3fn)
assert (layer.alpha.dtype == torch.float32)
mm_args = (
x_fp4,
layer.weight,
x_blockscale,
layer.weight_scale_swizzled,
layer.weight_scale,
layer.alpha,
output_dtype,
)
......@@ -1320,7 +1318,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"Weight Blockscale must be represented as FP8-E4M3")
w13_blockscale_swizzled = swizzle_blockscale(
layer.w13_weight_scale)
layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
layer.w13_weight_scale = Parameter(w13_blockscale_swizzled,
requires_grad=False)
assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
......@@ -1328,7 +1326,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
"Weight Blockscale must be represented as FP8-E4M3")
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
layer.w2_weight_scale = Parameter(w2_blockscale_swizzled,
requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight.data,
requires_grad=False)
......@@ -1339,8 +1337,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
del layer.g2_alphas
del layer.w13_input_scale_quant
del layer.w2_input_scale_quant
del layer.w13_blockscale_swizzled
del layer.w2_blockscale_swizzled
def apply(
self,
......@@ -1474,8 +1470,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif (self.allow_flashinfer
......@@ -1489,8 +1485,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
......@@ -1510,8 +1506,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
a=x,
w1_fp4=layer.w13_weight,
w2_fp4=layer.w2_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w2_blockscale=layer.w2_blockscale_swizzled,
w1_blockscale=layer.w13_weight_scale,
w2_blockscale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment