Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f065de4e
Unverified
Commit
f065de4e
authored
May 12, 2025
by
Michael Goin
Committed by
GitHub
May 12, 2025
Browse files
Fix FBGEMM integration (#18002)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
dc990536
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
13 deletions
+11
-13
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+3
-1
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
...el_executor/layers/quantization/utils/marlin_utils_fp8.py
+8
-12
No files found.
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
f065de4e
...
@@ -63,7 +63,9 @@ class FBGEMMFp8Config(QuantizationConfig):
...
@@ -63,7 +63,9 @@ class FBGEMMFp8Config(QuantizationConfig):
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
if
is_layer_skipped
(
prefix
,
self
.
ignore_list
):
if
is_layer_skipped
(
prefix
=
prefix
,
ignored_layers
=
self
.
ignore_list
,
fused_mapping
=
self
.
packed_modules_mapping
):
return
UnquantizedLinearMethod
()
return
UnquantizedLinearMethod
()
return
FBGEMMFp8LinearMethod
(
self
)
return
FBGEMMFp8LinearMethod
(
self
)
return
None
return
None
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
View file @
f065de4e
...
@@ -86,6 +86,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
...
@@ -86,6 +86,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
part_size_n
=
layer
.
output_size_per_partition
part_size_n
=
layer
.
output_size_per_partition
part_size_k
=
layer
.
input_size_per_partition
part_size_k
=
layer
.
input_size_per_partition
weight_block_size
=
getattr
(
layer
,
"weight_block_size"
,
None
)
if
size_k_first
:
if
size_k_first
:
assert
layer
.
weight
.
shape
==
(
part_size_k
,
part_size_n
)
assert
layer
.
weight
.
shape
==
(
part_size_k
,
part_size_n
)
...
@@ -119,14 +120,11 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
...
@@ -119,14 +120,11 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
scales
=
layer
.
weight_scale_inv
.
to
(
layer
.
orig_dtype
)
scales
=
layer
.
weight_scale_inv
.
to
(
layer
.
orig_dtype
)
del
layer
.
weight_scale_inv
del
layer
.
weight_scale_inv
if
layer
.
weight_block_size
is
None
:
group_size
=
-
1
if
weight_block_size
is
None
else
weight_block_size
[
1
]
group_size
=
-
1
else
:
group_size
=
layer
.
weight_block_size
[
1
]
# marlin kernel only support channel-wise and group-wise quantization
# marlin kernel only support channel-wise and group-wise quantization
# we need to convert the scales
# we need to convert the scales
if
layer
.
weight_block_size
is
None
:
if
weight_block_size
is
None
:
if
scales
.
nelement
()
==
1
:
if
scales
.
nelement
()
==
1
:
# tensor-wise quantization -> channel-wise quantization
# tensor-wise quantization -> channel-wise quantization
# (1, 1) =>(repeat)=> (1, size_n)
# (1, 1) =>(repeat)=> (1, size_n)
...
@@ -149,7 +147,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
...
@@ -149,7 +147,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
# =>(repeat)=> (size_k // block_size[1], size_n)
# =>(repeat)=> (size_k // block_size[1], size_n)
if
not
size_k_first
:
if
not
size_k_first
:
scales
=
scales
.
T
.
contiguous
()
scales
=
scales
.
T
.
contiguous
()
block_n
=
layer
.
weight_block_size
[
0
]
block_n
=
weight_block_size
[
0
]
scales
=
scales
.
repeat_interleave
(
block_n
,
1
)
scales
=
scales
.
repeat_interleave
(
block_n
,
1
)
# size_n may not divisible by block_size[0]
# size_n may not divisible by block_size[0]
scales
=
scales
[:,
:
part_size_n
]
scales
=
scales
[:,
:
part_size_n
]
...
@@ -173,6 +171,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
...
@@ -173,6 +171,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
e
=
layer
.
num_experts
e
=
layer
.
num_experts
k
=
layer
.
hidden_size
k
=
layer
.
hidden_size
n
=
layer
.
intermediate_size_per_partition
n
=
layer
.
intermediate_size_per_partition
weight_block_size
=
getattr
(
layer
,
"weight_block_size"
,
None
)
# WORKSPACE
# WORKSPACE
device
=
layer
.
w13_weight
.
device
device
=
layer
.
w13_weight
.
device
...
@@ -213,10 +212,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
...
@@ -213,10 +212,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
# WEIGHT SCALES
# WEIGHT SCALES
# Permute scales
# Permute scales
if
layer
.
weight_block_size
is
None
:
group_size
=
-
1
if
weight_block_size
is
None
else
weight_block_size
[
1
]
group_size
=
-
1
else
:
group_size
=
layer
.
weight_block_size
[
1
]
for
name
in
[
"w13"
,
"w2"
]:
for
name
in
[
"w13"
,
"w2"
]:
if
name
+
"_weight_scale"
in
dir
(
layer
):
if
name
+
"_weight_scale"
in
dir
(
layer
):
...
@@ -236,7 +232,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
...
@@ -236,7 +232,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
# marlin kernel only support channel-wise and group-wise quantization
# marlin kernel only support channel-wise and group-wise quantization
# we need to convert the scales
# we need to convert the scales
if
layer
.
weight_block_size
is
None
:
if
weight_block_size
is
None
:
if
scales
.
nelement
()
==
e
:
if
scales
.
nelement
()
==
e
:
# tensor-wise quantization -> channel-wise quantization
# tensor-wise quantization -> channel-wise quantization
# (e, 1, 1) =>(repeat)=> (e, 1, size_n)
# (e, 1, 1) =>(repeat)=> (e, 1, size_n)
...
@@ -259,7 +255,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
...
@@ -259,7 +255,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
# =>(repeat)=> (e, size_k // block_size[1], size_n)
# =>(repeat)=> (e, size_k // block_size[1], size_n)
if
not
size_k_first
:
if
not
size_k_first
:
scales
=
scales
.
permute
(
0
,
2
,
1
)
scales
=
scales
.
permute
(
0
,
2
,
1
)
block_n
=
layer
.
weight_block_size
[
0
]
block_n
=
weight_block_size
[
0
]
scales
=
scales
.
repeat_interleave
(
block_n
,
2
)
scales
=
scales
.
repeat_interleave
(
block_n
,
2
)
# size_n may not divisible by block_size[0]
# size_n may not divisible by block_size[0]
scales
=
scales
[...,
:
size_n
].
contiguous
()
scales
=
scales
[...,
:
size_n
].
contiguous
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment