Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e1744502
Unverified
Commit
e1744502
authored
Mar 07, 2025
by
Luka Govedič
Committed by
GitHub
Mar 07, 2025
Browse files
[FP8] Refactor apply_fp8_linear and apply_fp8_linear_generic into an object (#14390)
Signed-off-by:
luka
<
luka@neuralmagic.com
>
parent
dae68969
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
268 additions
and
242 deletions
+268
-242
tests/compile/test_fusion.py
tests/compile/test_fusion.py
+7
-13
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+4
-3
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+8
-11
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+9
-13
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+10
-11
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+7
-9
vllm/model_executor/layers/quantization/ptpc_fp8.py
vllm/model_executor/layers/quantization/ptpc_fp8.py
+9
-9
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
...cutor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
+7
-11
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+54
-38
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+149
-121
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+4
-3
No files found.
tests/compile/test_fusion.py
View file @
e1744502
...
...
@@ -13,7 +13,7 @@ from vllm.compilation.noop_elimination import NoOpEliminationPass
from
vllm.config
import
CompilationConfig
,
CompilationLevel
,
VllmConfig
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
CUTLASS_FP8_SUPPORTED
,
apply_fp8_l
inear
,
maybe_create_device_identity
)
CUTLASS_FP8_SUPPORTED
,
Fp8L
inear
Op
,
maybe_create_device_identity
)
from
.backend
import
TestBackend
...
...
@@ -34,26 +34,20 @@ class TestModel(torch.nn.Module):
torch
.
rand
(
hidden_size
,
hidden_size
).
to
(
dtype
=
FP8_DTYPE
).
t
()
for
_
in
range
(
2
)
]
self
.
fp8_linear
=
Fp8LinearOp
(
cutlass_fp8_supported
=
cutlass_fp8_enabled
,
use_per_token_if_dynamic
=
True
)
def
forward
(
self
,
x
):
resid
=
torch
.
sqrt
(
x
)
y
=
self
.
norm
[
0
](
x
)
x2
=
apply_fp8_linear
(
y
,
self
.
w
[
0
],
self
.
wscale
[
0
],
self
.
scale
[
0
],
use_per_token_if_dynamic
=
True
,
cutlass_fp8_supported
=
self
.
cutlass_fp8_enabled
)
x2
=
self
.
fp8_linear
.
apply
(
y
,
self
.
w
[
0
],
self
.
wscale
[
0
],
self
.
scale
[
0
])
# make sure resid is used for replacement to work
y2
,
resid
=
self
.
norm
[
1
](
x2
,
resid
)
x3
=
apply_fp8_linear
(
y2
,
self
.
w
[
1
],
self
.
wscale
[
1
],
self
.
scale
[
1
],
use_per_token_if_dynamic
=
True
,
cutlass_fp8_supported
=
self
.
cutlass_fp8_enabled
)
x3
=
self
.
fp8_linear
.
apply
(
y2
,
self
.
w
[
1
],
self
.
wscale
[
1
],
self
.
scale
[
1
])
y3
,
resid
=
self
.
norm
[
2
](
x3
,
resid
)
# use resid here
return
y3
...
...
vllm/attention/backends/mla/common.py
View file @
e1744502
...
...
@@ -226,7 +226,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Fp8
)
from
vllm.model_executor.layers.quantization.fp8
import
Fp8LinearMethod
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
apply_fp8_l
inear
_g
eneric
,
current_platform_fp8_dtype
,
is_fp8
)
Fp8L
inear
G
eneric
Op
,
current_platform_fp8_dtype
,
is_fp8
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
scaled_quantize
)
from
vllm.model_executor.layers.rotary_embedding
import
(
...
...
@@ -1057,6 +1057,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self
.
kv_b_proj
=
kv_b_proj
self
.
o_proj
=
o_proj
self
.
triton_fa_func
=
triton_attention
self
.
fp8_linear_generic
=
Fp8LinearGenericOp
()
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on ROCm and the
...
...
@@ -1071,7 +1072,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
def
_v_up_proj_and_o_proj
(
self
,
x
):
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
if
is_fp8
(
self
.
W_UV_O
):
output_parallel
=
apply_
fp8_linear_generic
(
output_parallel
=
self
.
fp8_linear_generic
.
apply
(
x
.
flatten
(
start_dim
=
1
),
self
.
W_UV_O
,
self
.
W_UV_O_scales
,
self
.
reqaunt_input_group_shape
,
self
.
reqaunt_weight_group_shape
)
...
...
@@ -1091,7 +1092,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
def
_q_proj_and_k_up_proj
(
self
,
x
):
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
if
is_fp8
(
self
.
W_Q_UK
):
return
apply_
fp8_linear_generic
(
return
self
.
fp8_linear_generic
.
apply
(
x
,
self
.
W_Q_UK
,
self
.
W_Q_UK_scales
,
self
.
reqaunt_input_group_shape
,
self
.
reqaunt_weight_group_shape
).
view
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
e1744502
...
...
@@ -9,8 +9,8 @@ from torch.nn import Parameter
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
cutlass_fp8_supported
,
maybe_create_device_identity
,
normalize_e4m3fn_to_e4m3fnuz
,
requantize_with_max_scale
)
Fp8LinearOp
,
maybe_create_device_identity
,
normalize_e4m3fn_to_e4m3fnuz
,
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
...
...
@@ -24,7 +24,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def
__init__
(
self
,
strategy
:
str
,
is_static_input_scheme
:
bool
):
self
.
strategy
=
strategy
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
(
)
self
.
fp8_linear
=
Fp8LinearOp
(
use_per_token_if_dynamic
=
True
)
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
...
...
@@ -140,11 +140,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
apply_fp8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
,
use_per_token_if_dynamic
=
True
)
return
self
.
fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
)
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
e1744502
...
...
@@ -11,14 +11,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.fp8
import
cutlass_fp8_supported
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
maybe_create_device_identity
,
normalize_e4m3fn_to_e4m3fnuz
)
Fp8LinearOp
,
maybe_create_device_identity
,
normalize_e4m3fn_to_e4m3fnuz
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
ModelWeightParameter
)
from
vllm.platforms
import
current_platform
...
...
@@ -37,6 +35,7 @@ class FBGEMMFp8Config(QuantizationConfig):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self
.
use_marlin
=
not
current_platform
.
has_device_capability
(
89
)
self
.
fp8_linear
=
Fp8LinearOp
()
@
classmethod
def
get_name
(
cls
)
->
str
:
...
...
@@ -73,7 +72,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
FBGEMMFp8Config
):
self
.
quant_config
=
quant_config
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
(
)
self
.
fp8_linear
=
Fp8LinearOp
(
use_per_token_if_dynamic
=
True
)
def
create_weights
(
self
,
...
...
@@ -159,12 +158,9 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
size_k
=
layer
.
input_size_per_partition
,
bias
=
bias
)
return
apply_fp8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
None
,
input_scale_ub
=
layer
.
input_scale_ub
,
bias
=
bias
,
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
,
use_per_token_if_dynamic
=
True
)
return
self
.
fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
None
,
input_scale_ub
=
layer
.
input_scale_ub
,
bias
=
bias
)
vllm/model_executor/layers/quantization/fp8.py
View file @
e1744502
...
...
@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
apply_fp8_linear
,
convert_to_channelwise
,
Fp8LinearOp
,
all_close_1d
,
convert_to_channelwise
,
cutlass_block_fp8_supported
,
cutlass_fp8_supported
,
maybe_create_device_identity
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
,
requantize_with_max_scale
)
...
...
@@ -137,7 +137,6 @@ class Fp8LinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
Fp8Config
):
self
.
quant_config
=
quant_config
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
self
.
cutlass_block_fp8_supported
=
cutlass_block_fp8_supported
()
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
...
...
@@ -153,6 +152,10 @@ class Fp8LinearMethod(LinearMethodBase):
# Marlin doesn't support block-wise fp8
self
.
use_marlin
=
False
self
.
fp8_linear
=
Fp8LinearOp
(
# Default to using per_token quantization if cutlass is supported
use_per_token_if_dynamic
=
cutlass_fp8_supported
())
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -381,15 +384,11 @@ class Fp8LinearMethod(LinearMethodBase):
cutlass_block_fp8_supported
=
self
.
cutlass_block_fp8_supported
,
)
return
apply_fp8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
,
# Default to using per_token quantization if cutlass is supported
use_per_token_if_dynamic
=
self
.
cutlass_fp8_supported
)
return
self
.
fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
)
class
Fp8MoEMethod
(
FusedMoEMethodBase
):
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
e1744502
...
...
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
cutlass_fp8_supported
,
requantize_with_max_scale
)
Fp8LinearOp
,
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ModelWeightParameter
,
PerTensorScaleParameter
)
...
...
@@ -95,7 +95,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
ModelOptFp8Config
):
self
.
quant_config
=
quant_config
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
self
.
fp8_linear
=
Fp8LinearOp
()
def
create_weights
(
self
,
...
...
@@ -157,10 +157,8 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
apply_fp8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
)
return
self
.
fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
)
vllm/model_executor/layers/quantization/ptpc_fp8.py
View file @
e1744502
...
...
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_l
inear
)
Fp8L
inear
Op
)
from
vllm.platforms
import
current_platform
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
...
...
@@ -93,6 +93,8 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
super
().
__init__
(
quant_config
=
quant_config
)
# Force weight quantization
self
.
quant_config
.
is_checkpoint_fp8_serialized
=
False
self
.
fp8_linear
=
Fp8LinearOp
(
cutlass_fp8_supported
=
False
,
use_per_token_if_dynamic
=
True
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
layer
.
weight
=
torch
.
nn
.
Parameter
(
layer
.
weight
.
data
,
...
...
@@ -115,11 +117,9 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
apply_fp8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
None
,
input_scale_ub
=
None
,
bias
=
bias
,
cutlass_fp8_supported
=
False
,
use_per_token_if_dynamic
=
True
)
return
self
.
fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
None
,
input_scale_ub
=
None
,
bias
=
bias
)
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
View file @
e1744502
...
...
@@ -7,8 +7,7 @@ from torch.nn import Parameter
from
vllm.model_executor.layers.quantization.quark.schemes
import
QuarkScheme
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
cutlass_fp8_supported
,
normalize_e4m3fn_to_e4m3fnuz
,
requantize_with_max_scale
)
Fp8LinearOp
,
normalize_e4m3fn_to_e4m3fnuz
,
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
...
...
@@ -22,7 +21,7 @@ class QuarkW8A8Fp8(QuarkScheme):
def
__init__
(
self
,
qscheme
:
str
,
is_static_input_scheme
:
Optional
[
bool
]):
self
.
qscheme
=
qscheme
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
(
)
self
.
fp8_linear
=
Fp8LinearOp
(
use_per_token_if_dynamic
=
True
)
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
...
...
@@ -132,11 +131,8 @@ class QuarkW8A8Fp8(QuarkScheme):
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
apply_fp8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
,
use_per_token_if_dynamic
=
True
)
return
self
.
fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
)
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
e1744502
...
...
@@ -15,7 +15,8 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
_normalize_quant_group_shape
,
scaled_dequantize
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
CUTLASS_BLOCK_FP8_SUPPORTED
,
CUTLASS_FP8_SUPPORTED
,
apply_fp8_linear
)
CUTLASS_BLOCK_FP8_SUPPORTED
,
Fp8LinearOp
,
cutlass_block_fp8_supported
,
cutlass_fp8_supported
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
...
...
@@ -32,6 +33,8 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
return
x
==
torch
.
float8_e4m3fn
or
x
==
torch
.
float8_e4m3fnuz
# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
def
apply_w8a8_block_fp8_linear
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
...
...
@@ -49,6 +52,7 @@ def apply_w8a8_block_fp8_linear(
shape_supported_by_cutlass
=
(
weight
.
shape
[
0
]
%
128
==
0
and
weight
.
shape
[
1
]
%
128
==
0
)
if
current_platform
.
is_rocm
():
# TODO this is never used, as cutlass_block_fp8_supported is False
scale_a_shape
=
((
input_2d
.
shape
[
-
1
]
//
block_size
[
1
],
)
+
input_2d
.
shape
[:
-
1
])[::
-
1
]
scale_b_shape
=
(
weight_scale
.
view
(
-
1
,
1
)
...
...
@@ -104,43 +108,55 @@ direct_register_custom_op(
# Unify the interface between `apply_w8a8_block_fp8_linear` and
# `apply_fp8_linear`
# NOTE(lucas): this is quite messy, we should think through this more formally
def
apply_fp8_linear_generic
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
input_group_shape
:
Tuple
[
int
,
int
],
weight_group_shape
:
Tuple
[
int
,
int
],
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
# static scale, if any
cutlass_fp8_supported
:
bool
=
CUTLASS_FP8_SUPPORTED
,
cutlass_block_fp8_supported
:
bool
=
CUTLASS_BLOCK_FP8_SUPPORTED
,
)
->
torch
.
Tensor
:
# View input as 2D matrix for fp8 methods
input
=
input
.
view
(
-
1
,
input
.
shape
[
-
1
])
weight_group_shape
=
_normalize_quant_group_shape
(
\
weight
,
weight_group_shape
)
input_group_shape
=
_normalize_quant_group_shape
(
input
,
input_group_shape
)
def
is_dim_blocked
(
dim
,
shape
,
group_shape
):
return
group_shape
<
shape
[
dim
]
and
group_shape
>
1
if
is_dim_blocked
(
0
,
weight
.
shape
,
weight_group_shape
[
0
])
\
and
is_dim_blocked
(
1
,
weight
.
shape
,
weight_group_shape
[
1
])
and
\
input_group_shape
==
(
1
,
weight_group_shape
[
1
]):
return
apply_w8a8_block_fp8_linear
(
input
,
weight
,
list
(
weight_group_shape
),
weight_scale
,
cutlass_block_fp8_supported
=
cutlass_block_fp8_supported
)
else
:
# Despite having linear in the name it doesn't conform to
# `torch.nn.functional.linear` which is defined as `input @ weight.T`
# so we explicitly transpose the weight matrix here
return
apply_fp8_linear
(
input
,
weight
.
T
,
weight_scale
.
T
,
cutlass_fp8_supported
=
cutlass_fp8_supported
,
use_per_token_if_dynamic
=
\
(
input_group_shape
==
(
1
,
input
.
shape
[
1
])))
# TODO(luka): unify this better
# https://github.com/vllm-project/vllm/issues/14397
class
Fp8LinearGenericOp
:
def
__init__
(
self
,
cutlass_fp8_supported
:
bool
=
cutlass_fp8_supported
(),
cutlass_block_fp8_supported
:
bool
=
cutlass_block_fp8_supported
(),
):
self
.
cutlass_block_fp8_supported
=
cutlass_block_fp8_supported
self
.
fp8_linear
=
Fp8LinearOp
(
cutlass_fp8_supported
=
cutlass_fp8_supported
)
def
apply
(
self
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
input_group_shape
:
Tuple
[
int
,
int
],
weight_group_shape
:
Tuple
[
int
,
int
],
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
# static scale, if any
)
->
torch
.
Tensor
:
# View input as 2D matrix for fp8 methods
input
=
input
.
view
(
-
1
,
input
.
shape
[
-
1
])
weight_group_shape
=
_normalize_quant_group_shape
(
\
weight
,
weight_group_shape
)
input_group_shape
=
_normalize_quant_group_shape
(
input
,
input_group_shape
)
def
is_dim_blocked
(
dim
,
shape
,
group_shape
):
return
group_shape
<
shape
[
dim
]
and
group_shape
>
1
if
is_dim_blocked
(
0
,
weight
.
shape
,
weight_group_shape
[
0
])
\
and
is_dim_blocked
(
1
,
weight
.
shape
,
weight_group_shape
[
1
])
and
\
input_group_shape
==
(
1
,
weight_group_shape
[
1
]):
return
apply_w8a8_block_fp8_linear
(
input
,
weight
,
list
(
weight_group_shape
),
weight_scale
,
cutlass_block_fp8_supported
=
self
.
cutlass_block_fp8_supported
)
else
:
# Despite having linear in the name it doesn't conform to
# `torch.nn.functional.linear` which is defined as
# `input @ weight.T` so we explicitly transpose the weight matrix
return
self
.
fp8_linear
.
apply
(
input
,
weight
.
T
,
weight_scale
.
T
,
use_per_token_if_dynamic
=
\
(
input_group_shape
==
(
1
,
input
.
shape
[
1
])))
def
input_to_float8
(
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
e1744502
...
...
@@ -121,134 +121,162 @@ def maybe_create_device_identity():
TORCH_DEVICE_IDENTITY
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
)
def
apply_fp8_linear
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
input_scale_ub
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
cutlass_fp8_supported
:
bool
=
CUTLASS_FP8_SUPPORTED
,
use_per_token_if_dynamic
:
bool
=
False
,
)
->
torch
.
Tensor
:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
# View input as 2D matrix for fp8 methods
input_2d
=
input
.
view
(
-
1
,
input
.
shape
[
-
1
])
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
1
]]
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if
cutlass_fp8_supported
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input_2d
,
input_scale
,
scale_ub
=
input_scale_ub
,
use_per_token_if_dynamic
=
use_per_token_if_dynamic
)
# Fused GEMM_DQ
output
=
ops
.
cutlass_scaled_mm
(
qinput
,
weight
,
out_dtype
=
input
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
bias
=
bias
)
return
output
.
view
(
*
output_shape
)
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
else
:
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
# https://github.com/vllm-project/vllm/issues/14397
class
Fp8LinearOp
:
"""
This class executes a FP8 linear layer using cutlass if supported and
torch.scaled_mm otherwise.
It needs to be a class instead of a method so that config can be read
in the __init__ method, as reading config is not allowed inside forward.
"""
def
__init__
(
self
,
cutlass_fp8_supported
:
bool
=
cutlass_fp8_supported
(),
use_per_token_if_dynamic
:
bool
=
False
,
pad_output
:
Optional
[
bool
]
=
None
):
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
self
.
use_per_token_if_dynamic
=
use_per_token_if_dynamic
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
# This could change in the future.
# We also don't pad when using torch.compile,
# as it breaks with dynamic shapes.
config
=
get_current_vllm_config
().
compilation_config
do_pad
=
config
.
level
<
CompilationLevel
.
PIECEWISE
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input_2d
,
input_scale
,
num_token_padding
=
17
if
do_pad
else
None
,
use_per_token_if_dynamic
=
use_per_token_if_dynamic
)
per_tensor_weights
=
(
weight_scale
.
numel
()
==
1
)
per_tensor_activations
=
(
x_scale
.
numel
()
==
1
)
if
per_tensor_weights
and
per_tensor_activations
:
# Fused GEMM_DQ
output
=
torch
.
_scaled_mm
(
qinput
,
weight
,
out_dtype
=
input
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
bias
=
bias
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if
type
(
output
)
is
tuple
and
len
(
output
)
==
2
:
output
=
output
[
0
]
return
torch
.
narrow
(
output
,
0
,
0
,
input_2d
.
shape
[
0
]).
view
(
*
output_shape
)
elif
(
use_per_token_if_dynamic
and
not
per_tensor_weights
and
not
per_tensor_activations
and
USE_ROWWISE_TORCH_SCALED_MM
):
# For now validated on ROCm platform
# fp8 rowwise scaling in torch._scaled_mm is introduced in
# https://github.com/pytorch/pytorch/pull/144432 using
# hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
# For CUDA platform please validate if the
# torch._scaled_mm supports rowwise scaled GEMM
# Fused GEMM_DQ Rowwise GEMM
output
=
torch
.
_scaled_mm
(
qinput
,
weight
,
out_dtype
=
input
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
.
t
(),
bias
=
bias
)
output
=
torch
.
narrow
(
output
,
0
,
0
,
input_2d
.
shape
[
0
])
output
=
output
.
view
(
*
output_shape
)
return
output
if
pad_output
is
None
:
config
=
get_current_vllm_config
().
compilation_config
pad_output
=
config
.
level
<
CompilationLevel
.
PIECEWISE
self
.
output_padding
=
17
if
pad_output
else
None
def
apply
(
self
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
input_scale_ub
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
# TODO(luka) remove this parameter in favor of __init__
use_per_token_if_dynamic
:
Optional
[
bool
]
=
None
)
->
torch
.
Tensor
:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
# View input as 2D matrix for fp8 methods
input_2d
=
input
.
view
(
-
1
,
input
.
shape
[
-
1
])
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
1
]]
# TODO(luka) this is here because currently MLA only decides this
# during the forward method instead of in __init__.
if
use_per_token_if_dynamic
is
None
:
use_per_token_if_dynamic
=
self
.
use_per_token_if_dynamic
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if
self
.
cutlass_fp8_supported
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input_2d
,
input_scale
,
scale_ub
=
input_scale_ub
,
use_per_token_if_dynamic
=
use_per_token_if_dynamic
)
# Fused GEMM_DQ
output
=
ops
.
cutlass_scaled_mm
(
qinput
,
weight
,
out_dtype
=
input
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
bias
=
bias
)
return
output
.
view
(
*
output_shape
)
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
else
:
# Fallback for channelwise case, where we use unfused DQ
# due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output
=
torch
.
_scaled_mm
(
qinput
,
weight
,
scale_a
=
TORCH_DEVICE_IDENTITY
,
scale_b
=
TORCH_DEVICE_IDENTITY
,
out_dtype
=
torch
.
float32
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if
type
(
output
)
is
tuple
and
len
(
output
)
==
2
:
output
=
output
[
0
]
# Unpad (undo num_token_padding)
output
=
torch
.
narrow
(
output
,
0
,
0
,
input_2d
.
shape
[
0
])
x_scale
=
torch
.
narrow
(
x_scale
,
0
,
0
,
input_2d
.
shape
[
0
])
# DQ
# C = sw * sx * (X * W) + bias
output
=
output
*
x_scale
*
weight_scale
.
t
()
if
bias
is
not
None
:
output
=
output
+
bias
return
output
.
to
(
dtype
=
input
.
dtype
).
view
(
*
output_shape
)
# Maybe apply padding to output, see comment in __init__
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input_2d
,
input_scale
,
num_token_padding
=
self
.
output_padding
,
use_per_token_if_dynamic
=
use_per_token_if_dynamic
)
per_tensor_weights
=
(
weight_scale
.
numel
()
==
1
)
per_tensor_activations
=
(
x_scale
.
numel
()
==
1
)
if
per_tensor_weights
and
per_tensor_activations
:
# Fused GEMM_DQ
output
=
torch
.
_scaled_mm
(
qinput
,
weight
,
out_dtype
=
input
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
bias
=
bias
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if
type
(
output
)
is
tuple
and
len
(
output
)
==
2
:
output
=
output
[
0
]
return
torch
.
narrow
(
output
,
0
,
0
,
input_2d
.
shape
[
0
]).
view
(
*
output_shape
)
elif
(
use_per_token_if_dynamic
and
not
per_tensor_weights
and
not
per_tensor_activations
and
USE_ROWWISE_TORCH_SCALED_MM
):
# For now validated on ROCm platform
# fp8 rowwise scaling in torch._scaled_mm is introduced in
# https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
# and ROCm 6.3, which only exists in torch 2.7 and above.
# For CUDA platform please validate if the
# torch._scaled_mm supports rowwise scaled GEMM
# Fused GEMM_DQ Rowwise GEMM
output
=
torch
.
_scaled_mm
(
qinput
,
weight
,
out_dtype
=
input
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
.
t
(),
bias
=
bias
)
output
=
torch
.
narrow
(
output
,
0
,
0
,
input_2d
.
shape
[
0
])
output
=
output
.
view
(
*
output_shape
)
return
output
else
:
# Fallback for channelwise case, where we use unfused DQ
# due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output
=
torch
.
_scaled_mm
(
qinput
,
weight
,
scale_a
=
TORCH_DEVICE_IDENTITY
,
scale_b
=
TORCH_DEVICE_IDENTITY
,
out_dtype
=
torch
.
float32
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if
type
(
output
)
is
tuple
and
len
(
output
)
==
2
:
output
=
output
[
0
]
# Unpad (undo num_token_padding)
output
=
torch
.
narrow
(
output
,
0
,
0
,
input_2d
.
shape
[
0
])
x_scale
=
torch
.
narrow
(
x_scale
,
0
,
0
,
input_2d
.
shape
[
0
])
# DQ
# C = sw * sx * (X * W) + bias
output
=
output
*
x_scale
*
weight_scale
.
t
()
if
bias
is
not
None
:
output
=
output
+
bias
return
output
.
to
(
dtype
=
input
.
dtype
).
view
(
*
output_shape
)
def
normalize_e4m3fn_to_e4m3fnuz
(
...
...
vllm/v1/attention/backends/mla/common.py
View file @
e1744502
...
...
@@ -219,7 +219,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Fp8
)
from
vllm.model_executor.layers.quantization.fp8
import
Fp8LinearMethod
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
apply_fp8_l
inear
_g
eneric
,
current_platform_fp8_dtype
,
is_fp8
)
Fp8L
inear
G
eneric
Op
,
current_platform_fp8_dtype
,
is_fp8
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
scaled_quantize
)
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
...
...
@@ -640,6 +640,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self
.
kv_b_proj
=
kv_b_proj
self
.
o_proj
=
o_proj
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
self
.
fp8_linear_generic
=
Fp8LinearGenericOp
()
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on ROCm and the
...
...
@@ -653,7 +654,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
def
_v_up_proj_and_o_proj
(
self
,
x
):
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
if
is_fp8
(
self
.
W_UV_O
):
output_parallel
=
apply_
fp8_linear_generic
(
output_parallel
=
self
.
fp8_linear_generic
.
apply
(
x
.
flatten
(
start_dim
=
1
),
self
.
W_UV_O
,
self
.
W_UV_O_scales
,
self
.
reqaunt_input_group_shape
,
self
.
reqaunt_weight_group_shape
)
...
...
@@ -673,7 +674,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
def
_q_proj_and_k_up_proj
(
self
,
x
):
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
if
is_fp8
(
self
.
W_Q_UK
):
return
apply_
fp8_linear_generic
(
return
self
.
fp8_linear_generic
.
apply
(
x
,
self
.
W_Q_UK
,
self
.
W_Q_UK_scales
,
self
.
reqaunt_input_group_shape
,
self
.
reqaunt_weight_group_shape
).
view
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment