Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a1c8f379
Unverified
Commit
a1c8f379
authored
Mar 11, 2025
by
Jeff Daily
Committed by
GitHub
Mar 11, 2025
Browse files
dynamic distpatch of fp8 kernels (#14245)
Signed-off-by:
Jeff Daily
<
jeff.daily@amd.com
>
parent
08a1a112
Changes
25
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
56 additions
and
13 deletions
+56
-13
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+2
-10
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+4
-0
vllm/platforms/interface.py
vllm/platforms/interface.py
+30
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+17
-0
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+3
-3
No files found.
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
a1c8f379
...
@@ -22,10 +22,6 @@ from vllm.utils import direct_register_custom_op
...
@@ -22,10 +22,6 @@ from vllm.utils import direct_register_custom_op
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
current_platform_fp8_dtype
=
(
torch
.
float8_e4m3fnuz
if
current_platform
.
is_rocm
()
else
torch
.
float8_e4m3fn
)
def
is_fp8
(
x
:
Union
[
torch
.
dtype
,
torch
.
Tensor
])
->
bool
:
def
is_fp8
(
x
:
Union
[
torch
.
dtype
,
torch
.
Tensor
])
->
bool
:
if
isinstance
(
x
,
torch
.
Tensor
):
if
isinstance
(
x
,
torch
.
Tensor
):
...
@@ -165,9 +161,7 @@ def input_to_float8(
...
@@ -165,9 +161,7 @@ def input_to_float8(
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""This function quantizes input values to float8 values "
"""This function quantizes input values to float8 values "
"with tensor-wise quantization."""
"with tensor-wise quantization."""
if
dtype
is
None
:
dtype
=
current_platform
.
fp8_dtype
()
if
dtype
is
None
else
dtype
dtype
=
(
torch
.
float8_e4m3fnuz
if
current_platform
.
is_rocm
()
else
torch
.
float8_e4m3fn
)
finfo
=
torch
.
finfo
(
dtype
)
finfo
=
torch
.
finfo
(
dtype
)
min_val
,
max_val
=
x
.
aminmax
()
min_val
,
max_val
=
x
.
aminmax
()
amax
=
torch
.
maximum
(
min_val
.
abs
(),
max_val
.
abs
()).
clamp
(
min
=
1e-12
)
amax
=
torch
.
maximum
(
min_val
.
abs
(),
max_val
.
abs
()).
clamp
(
min
=
1e-12
)
...
@@ -311,9 +305,7 @@ def per_token_group_quant_fp8(
...
@@ -311,9 +305,7 @@ def per_token_group_quant_fp8(
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
scaling factor for quantization.
"""
"""
if
dtype
is
None
:
dtype
=
current_platform
.
fp8_dtype
()
if
dtype
is
None
else
dtype
dtype
=
(
torch
.
float8_e4m3fnuz
if
current_platform
.
is_rocm
()
else
torch
.
float8_e4m3fn
)
assert
(
x
.
shape
[
-
1
]
%
group_size
==
0
),
(
assert
(
x
.
shape
[
-
1
]
%
group_size
==
0
),
(
f
"the last dimension of `x`
{
x
.
shape
[
-
1
]
}
must be divisible "
f
"the last dimension of `x`
{
x
.
shape
[
-
1
]
}
must be divisible "
f
"by `group_size`
{
group_size
}
"
)
f
"by `group_size`
{
group_size
}
"
)
...
...
vllm/platforms/cuda.py
View file @
a1c8f379
...
@@ -293,6 +293,10 @@ class CudaPlatformBase(Platform):
...
@@ -293,6 +293,10 @@ class CudaPlatformBase(Platform):
def
get_device_communicator_cls
(
cls
)
->
str
:
def
get_device_communicator_cls
(
cls
)
->
str
:
return
"vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"
# noqa
return
"vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"
# noqa
@
classmethod
def
supports_fp8
(
cls
)
->
bool
:
return
cls
.
has_device_capability
(
89
)
# NVML utils
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
...
...
vllm/platforms/interface.py
View file @
a1c8f379
...
@@ -330,6 +330,36 @@ class Platform:
...
@@ -330,6 +330,36 @@ class Platform:
"""
"""
return
"vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"
# noqa
return
"vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"
# noqa
@
classmethod
def
supports_fp8
(
cls
)
->
bool
:
"""
Returns whether the current platform supports FP8 types.
"""
return
False
@
classmethod
def
is_fp8_fnuz
(
cls
)
->
bool
:
"""
Returns whether the preferred FP8 type is FNUZ on the current platform.
There are two representations of FP8, OCP FP8 and FNUZ FP8.
The OCP specification can be found at https://tinyurl.com/b7jvwpft.
The FNUZ specification can be found at https://tinyurl.com/5n6hwwu5.
AMD's MI300 and MI325 have native hardware support for FNUZ. All other
hardware has converged on the OCP FP8 standard.
"""
return
False
@
classmethod
def
fp8_dtype
(
cls
)
->
torch
.
dtype
:
"""
Returns the preferred FP8 type on the current platform.
See the documentation for is_fp8_fnuz for details.
"""
return
torch
.
float8_e4m3fn
@
classmethod
@
classmethod
def
use_all_gather
(
cls
)
->
bool
:
def
use_all_gather
(
cls
)
->
bool
:
"""
"""
...
...
vllm/platforms/rocm.py
View file @
a1c8f379
...
@@ -231,3 +231,20 @@ class RocmPlatform(Platform):
...
@@ -231,3 +231,20 @@ class RocmPlatform(Platform):
@
classmethod
@
classmethod
def
get_device_communicator_cls
(
cls
)
->
str
:
def
get_device_communicator_cls
(
cls
)
->
str
:
return
"vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"
# noqa
return
"vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"
# noqa
@
classmethod
def
supports_fp8
(
cls
)
->
bool
:
gcn_arch
=
torch
.
cuda
.
get_device_properties
(
0
).
gcnArchName
return
any
(
gfx
in
gcn_arch
for
gfx
in
[
'gfx94'
,
'gfx95'
,
'gfx12'
])
@
classmethod
def
is_fp8_fnuz
(
cls
)
->
bool
:
# only device 0 is checked, this assumes MI300 platforms are homogeneous
return
'gfx94'
in
torch
.
cuda
.
get_device_properties
(
0
).
gcnArchName
@
classmethod
def
fp8_dtype
(
cls
)
->
torch
.
dtype
:
if
cls
.
is_fp8_fnuz
():
return
torch
.
float8_e4m3fnuz
else
:
return
torch
.
float8_e4m3fn
vllm/v1/attention/backends/mla/common.py
View file @
a1c8f379
...
@@ -219,7 +219,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
...
@@ -219,7 +219,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Fp8
)
CompressedTensorsW8A8Fp8
)
from
vllm.model_executor.layers.quantization.fp8
import
Fp8LinearMethod
from
vllm.model_executor.layers.quantization.fp8
import
Fp8LinearMethod
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
Fp8LinearGenericOp
,
current_platform_fp8_dtype
,
is_fp8
)
Fp8LinearGenericOp
,
is_fp8
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
scaled_quantize
)
scaled_quantize
)
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
...
@@ -826,7 +826,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -826,7 +826,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
W_Q_UK
,
W_Q_UK_scales
=
scaled_quantize
(
W_Q_UK
,
W_Q_UK_scales
=
scaled_quantize
(
W_Q_UK
,
W_Q_UK
,
self
.
reqaunt_weight_group_shape
,
self
.
reqaunt_weight_group_shape
,
quant_dtype
=
current_platform
_
fp8_dtype
)
quant_dtype
=
current_platform
.
fp8_dtype
()
)
# For FP8 save the transpose so we can use
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
# `apply_w8a8_block_fp8_linear` directly
self
.
W_Q_UK
=
W_Q_UK
.
T
.
contiguous
()
self
.
W_Q_UK
=
W_Q_UK
.
T
.
contiguous
()
...
@@ -843,7 +843,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -843,7 +843,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
W_UV_O
,
W_UV_O_scales
=
scaled_quantize
(
W_UV_O
,
W_UV_O_scales
=
scaled_quantize
(
W_UV_O
,
W_UV_O
,
self
.
reqaunt_weight_group_shape
,
self
.
reqaunt_weight_group_shape
,
quant_dtype
=
current_platform
_
fp8_dtype
)
quant_dtype
=
current_platform
.
fp8_dtype
()
)
# For FP8 save the transpose so we can use
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
# `apply_w8a8_block_fp8_linear` directly
self
.
W_UV_O
=
W_UV_O
.
T
.
contiguous
()
self
.
W_UV_O
=
W_UV_O
.
T
.
contiguous
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment