Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c904fddd
Unverified
Commit
c904fddd
authored
Feb 22, 2025
by
Gregory Shtrasberg
Committed by
GitHub
Feb 22, 2025
Browse files
[ROCm] Apply FP8 weights padding to values not divisible by 512 bytes on ROCm (#13231)
parent
558db808
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
20 additions
and
1 deletion
+20
-1
vllm/envs.py
vllm/envs.py
+4
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+15
-0
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+1
-1
No files found.
vllm/envs.py
View file @
c904fddd
...
@@ -74,6 +74,7 @@ if TYPE_CHECKING:
...
@@ -74,6 +74,7 @@ if TYPE_CHECKING:
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_DISABLED_KERNELS
:
List
[
str
]
=
[]
VLLM_DISABLED_KERNELS
:
List
[
str
]
=
[]
VLLM_USE_V1
:
bool
=
False
VLLM_USE_V1
:
bool
=
False
VLLM_ROCM_FP8_PADDING
:
bool
=
True
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
True
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
True
VLLM_LOG_BATCHSIZE_INTERVAL
:
float
=
-
1
VLLM_LOG_BATCHSIZE_INTERVAL
:
float
=
-
1
VLLM_DISABLE_COMPILE_CACHE
:
bool
=
False
VLLM_DISABLE_COMPILE_CACHE
:
bool
=
False
...
@@ -507,6 +508,9 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -507,6 +508,9 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_USE_V1"
:
"VLLM_USE_V1"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_V1"
,
"0"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_V1"
,
"0"
))),
# Pad the fp8 weights by 256 bytes for ROCm (only when the row stride is a multiple of 512 bytes)
"VLLM_ROCM_FP8_PADDING"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ROCM_FP8_PADDING"
,
"1"
))),
# Divisor for dynamic key scale factor calculation for FP8 KV Cache
# Divisor for dynamic key scale factor calculation for FP8 KV Cache
"K_SCALE_CONSTANT"
:
"K_SCALE_CONSTANT"
:
lambda
:
int
(
os
.
getenv
(
"K_SCALE_CONSTANT"
,
"200"
)),
lambda
:
int
(
os
.
getenv
(
"K_SCALE_CONSTANT"
,
"200"
)),
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
c904fddd
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
import
torch
import
torch.nn.functional
as
F
from
torch.nn
import
Module
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
...
@@ -251,6 +252,17 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -251,6 +252,17 @@ class Fp8LinearMethod(LinearMethodBase):
else
:
else
:
layer
.
register_parameter
(
"input_scale"
,
None
)
layer
.
register_parameter
(
"input_scale"
,
None
)
def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor:
    """Return ``weight``, re-laid-out with extra row padding on ROCm.

    This is a ROCm-specific optimization: per the original author's note,
    that platform can benefit from tensors located far enough from one
    another in memory.

    The padding is applied only when all of the following hold, checked
    in this order:

    * ``envs.VLLM_ROCM_FP8_PADDING`` is truthy,
    * ``current_platform.is_rocm()`` is true,
    * the last dimension has stride 1,
    * the second-to-last stride, measured in bytes, is a multiple of 512.

    Otherwise the input tensor is returned untouched.

    Args:
        weight: Tensor to (possibly) pad; must have at least two
            dimensions when the earlier checks pass, since
            ``stride(-2)`` is queried.

    Returns:
        Either ``weight`` itself, or a slice of a zero-padded copy that
        has the same shape as ``weight``.
    """
    # Guard clauses mirror the short-circuit order of the checks above.
    if not envs.VLLM_ROCM_FP8_PADDING:
        return weight
    if not current_platform.is_rocm():
        return weight
    if weight.stride(-1) != 1:
        return weight

    row_stride_bytes = weight.stride(-2) * weight.element_size()
    if row_stride_bytes % 512 != 0:
        return weight

    # 256 bytes' worth of elements are appended to the last dimension and
    # then sliced away again, so the shape is unchanged while the result
    # is a view into the larger, padded allocation.
    pad_elems = 256 // weight.element_size()
    padded = F.pad(weight, (0, pad_elems), "constant", 0)
    weight = padded[..., :-pad_elems]
    torch.cuda.empty_cache()
    return weight
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# TODO(rob): refactor block quant into separate class.
# TODO(rob): refactor block quant into separate class.
if
self
.
block_quant
:
if
self
.
block_quant
:
...
@@ -264,6 +276,8 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -264,6 +276,8 @@ class Fp8LinearMethod(LinearMethodBase):
weight
=
layer
.
weight
.
data
weight
=
layer
.
weight
.
data
weight_scale_inv
=
layer
.
weight_scale_inv
.
data
weight_scale_inv
=
layer
.
weight_scale_inv
.
data
weight
=
self
.
add_padding_to_weight
(
weight
)
# Torch.compile cannot use Parameter subclasses.
# Torch.compile cannot use Parameter subclasses.
layer
.
weight
=
Parameter
(
weight
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
,
requires_grad
=
False
)
layer
.
weight_scale_inv
=
Parameter
(
weight_scale_inv
,
layer
.
weight_scale_inv
=
Parameter
(
weight_scale_inv
,
...
@@ -327,6 +341,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -327,6 +341,7 @@ class Fp8LinearMethod(LinearMethodBase):
logical_widths
=
layer
.
logical_widths
,
logical_widths
=
layer
.
logical_widths
,
)
)
weight
=
self
.
add_padding_to_weight
(
weight
)
# Update layer with new values.
# Update layer with new values.
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
c904fddd
...
@@ -494,7 +494,7 @@ def w8a8_block_fp8_matmul(
...
@@ -494,7 +494,7 @@ def w8a8_block_fp8_matmul(
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
As
.
shape
[
-
1
]
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
As
.
shape
[
-
1
]
M
=
A
.
numel
()
//
A
.
shape
[
-
1
]
M
=
A
.
numel
()
//
A
.
shape
[
-
1
]
assert
B
.
ndim
==
2
and
B
.
is_contiguous
()
and
Bs
.
ndim
==
2
assert
B
.
ndim
==
2
and
Bs
.
ndim
==
2
N
,
K
=
B
.
shape
N
,
K
=
B
.
shape
assert
triton
.
cdiv
(
N
,
block_n
)
==
Bs
.
shape
[
0
]
assert
triton
.
cdiv
(
N
,
block_n
)
==
Bs
.
shape
[
0
]
assert
triton
.
cdiv
(
K
,
block_k
)
==
Bs
.
shape
[
1
]
assert
triton
.
cdiv
(
K
,
block_k
)
==
Bs
.
shape
[
1
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment