Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
136825de
Unverified
Commit
136825de
authored
Aug 07, 2025
by
Woosuk Kwon
Committed by
GitHub
Aug 07, 2025
Browse files
[Misc] Enhance code formatting in mxfp4.py (#22423)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
c2dba2db
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
52 additions
and
33 deletions
+52
-33
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+52
-33
No files found.
vllm/model_executor/layers/quantization/mxfp4.py
View file @
136825de
...
@@ -109,55 +109,74 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
...
@@ -109,55 +109,74 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self
.
intermediate_size
=
intermediate_size_per_partition_after_pad
self
.
intermediate_size
=
intermediate_size_per_partition_after_pad
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
# Fused gate_up_proj (column parallel)
# Fused gate_up_proj (column parallel)
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
num_experts
,
2
*
intermediate_size_per_partition_after_pad
,
2
*
intermediate_size_per_partition_after_pad
,
hidden_size
//
2
,
hidden_size
//
2
,
dtype
=
weight_dtype
),
dtype
=
weight_dtype
,
requires_grad
=
False
)
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
num_experts
,
2
*
intermediate_size_per_partition_after_pad
,
2
*
intermediate_size_per_partition_after_pad
,
hidden_size
//
mxfp4_block
,
hidden_size
//
mxfp4_block
,
dtype
=
scale_dtype
),
dtype
=
scale_dtype
,
requires_grad
=
False
)
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
w13_bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
w13_bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
num_experts
,
2
*
intermediate_size_per_partition_after_pad
,
2
*
intermediate_size_per_partition_after_pad
,
dtype
=
torch
.
bfloat16
),
dtype
=
torch
.
bfloat16
,
requires_grad
=
False
)
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_bias"
,
w13_bias
)
layer
.
register_parameter
(
"w13_bias"
,
w13_bias
)
set_weight_attrs
(
w13_bias
,
extra_weight_attrs
)
set_weight_attrs
(
w13_bias
,
extra_weight_attrs
)
# down_proj (row parallel)
# down_proj (row parallel)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
num_experts
,
hidden_size
,
hidden_size
,
intermediate_size_per_partition_after_pad
//
2
,
intermediate_size_per_partition_after_pad
//
2
,
dtype
=
weight_dtype
),
dtype
=
weight_dtype
,
requires_grad
=
False
)
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
num_experts
,
hidden_size
,
hidden_size
,
intermediate_size_per_partition_after_pad
//
mxfp4_block
,
intermediate_size_per_partition_after_pad
//
mxfp4_block
,
dtype
=
scale_dtype
),
dtype
=
scale_dtype
,
requires_grad
=
False
)
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
w2_bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
w2_bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
hidden_size
,
hidden_size
,
dtype
=
torch
.
bfloat16
),
dtype
=
torch
.
bfloat16
,
requires_grad
=
False
)
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_bias"
,
w2_bias
)
layer
.
register_parameter
(
"w2_bias"
,
w2_bias
)
set_weight_attrs
(
w2_bias
,
extra_weight_attrs
)
set_weight_attrs
(
w2_bias
,
extra_weight_attrs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment