Unverified Commit 136825de authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Misc] Enhance code formatting in mxfp4.py (#22423)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent c2dba2db
...@@ -109,55 +109,74 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -109,55 +109,74 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.intermediate_size = intermediate_size_per_partition_after_pad self.intermediate_size = intermediate_size_per_partition_after_pad
self.hidden_size = hidden_size self.hidden_size = hidden_size
# Fused gate_up_proj (column parallel) # Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.zeros( w13_weight = torch.nn.Parameter(
num_experts, torch.zeros(
2 * intermediate_size_per_partition_after_pad, num_experts,
hidden_size // 2, 2 * intermediate_size_per_partition_after_pad,
dtype=weight_dtype), hidden_size // 2,
requires_grad=False) dtype=weight_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight) layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs) set_weight_attrs(w13_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(torch.zeros( w13_weight_scale = torch.nn.Parameter(
num_experts, torch.zeros(
2 * intermediate_size_per_partition_after_pad, num_experts,
hidden_size // mxfp4_block, 2 * intermediate_size_per_partition_after_pad,
dtype=scale_dtype), hidden_size // mxfp4_block,
requires_grad=False) dtype=scale_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w13_weight_scale", w13_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w13_weight_scale, extra_weight_attrs)
w13_bias = torch.nn.Parameter(torch.zeros( w13_bias = torch.nn.Parameter(
num_experts, torch.zeros(
2 * intermediate_size_per_partition_after_pad, num_experts,
dtype=torch.bfloat16), 2 * intermediate_size_per_partition_after_pad,
requires_grad=False) dtype=torch.bfloat16,
),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias) layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs) set_weight_attrs(w13_bias, extra_weight_attrs)
# down_proj (row parallel) # down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.zeros( w2_weight = torch.nn.Parameter(
num_experts, torch.zeros(
hidden_size, num_experts,
intermediate_size_per_partition_after_pad // 2, hidden_size,
dtype=weight_dtype), intermediate_size_per_partition_after_pad // 2,
requires_grad=False) dtype=weight_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter(torch.zeros( w2_weight_scale = torch.nn.Parameter(
num_experts, torch.zeros(
hidden_size, num_experts,
intermediate_size_per_partition_after_pad // mxfp4_block, hidden_size,
dtype=scale_dtype), intermediate_size_per_partition_after_pad // mxfp4_block,
requires_grad=False) dtype=scale_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w2_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w2_bias = torch.nn.Parameter(torch.zeros(num_experts, w2_bias = torch.nn.Parameter(
hidden_size, torch.zeros(
dtype=torch.bfloat16), num_experts,
requires_grad=False) hidden_size,
dtype=torch.bfloat16,
),
requires_grad=False,
)
layer.register_parameter("w2_bias", w2_bias) layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, extra_weight_attrs) set_weight_attrs(w2_bias, extra_weight_attrs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment