Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d7a299ed
Unverified
Commit
d7a299ed
authored
Jul 30, 2024
by
Tyler Michael Smith
Committed by
GitHub
Jul 30, 2024
Browse files
[Kernel] Remove scaled_fp8_quant kernel padding footgun (#6842)
parent
052b6f8c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
14 deletions
+17
-14
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+1
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+13
-11
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+3
-2
No files found.
tests/quantization/test_fp8.py
View file @
d7a299ed
...
@@ -123,7 +123,7 @@ def test_scaled_fp8_quant(dtype) -> None:
...
@@ -123,7 +123,7 @@ def test_scaled_fp8_quant(dtype) -> None:
assert
torch
.
allclose
(
ref_y
,
per_tensor_dequantize
(
y
,
inv_scale
,
dtype
))
assert
torch
.
allclose
(
ref_y
,
per_tensor_dequantize
(
y
,
inv_scale
,
dtype
))
# Padding
# Padding
y
,
_
=
ops
.
scaled_fp8_quant
(
x
,
inv_scale
,
batch_dim
_padding
=
17
)
y
,
_
=
ops
.
scaled_fp8_quant
(
x
,
inv_scale
,
num_token
_padding
=
17
)
assert
y
.
shape
[
0
]
==
17
assert
y
.
shape
[
0
]
==
17
assert
torch
.
allclose
(
assert
torch
.
allclose
(
ref_y
,
ref_y
,
...
...
vllm/_custom_ops.py
View file @
d7a299ed
...
@@ -307,7 +307,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -307,7 +307,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
def
scaled_fp8_quant
(
def
scaled_fp8_quant
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_dim
_padding
:
Optional
[
int
]
=
None
,
num_token
_padding
:
Optional
[
int
]
=
None
,
scale_ub
:
Optional
[
torch
.
Tensor
]
=
None
,
scale_ub
:
Optional
[
torch
.
Tensor
]
=
None
,
use_per_token_if_dynamic
:
bool
=
False
,
use_per_token_if_dynamic
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
@@ -317,7 +317,7 @@ def scaled_fp8_quant(
...
@@ -317,7 +317,7 @@ def scaled_fp8_quant(
This function supports both static and dynamic quantization: If you
This function supports both static and dynamic quantization: If you
provide the scale, it will use static scaling and if you omit it,
provide the scale, it will use static scaling and if you omit it,
the scale will be determined dynamically. The function also allows
the scale will be determined dynamically. The function also allows
optional padding of the output tensor for downstream kernels that
optional padding of the output tensors for downstream kernels that
will benefit from padding.
will benefit from padding.
Args:
Args:
...
@@ -325,7 +325,7 @@ def scaled_fp8_quant(
...
@@ -325,7 +325,7 @@ def scaled_fp8_quant(
scale: Optional scaling factor for the FP8 quantization
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
per token case
batch_dim_padding: If specified, pad the first dimension
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
in the dynamic quantization case.
...
@@ -334,16 +334,16 @@ def scaled_fp8_quant(
...
@@ -334,16 +334,16 @@ def scaled_fp8_quant(
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
scaling factor.
scaling factor.
"""
"""
if
batch_dim_padding
:
# This code assumes batch_dim and num_tokens are flattened
shape
=
(
max
(
batch_dim_padding
,
input
.
shape
[
0
]),
*
input
.
shape
[
1
:]
)
assert
(
input
.
ndim
==
2
)
output
=
torch
.
empty
(
shape
,
shape
=
input
.
shape
device
=
input
.
device
,
if
num_token_padding
:
dtype
=
torch
.
float8_e4m3fn
)
shape
=
(
max
(
num_token_padding
,
input
.
shape
[
0
]),
shape
[
1
]
)
else
:
output
=
torch
.
empty
(
shape
,
device
=
input
.
device
,
dtype
=
torch
.
float8_e4m3fn
)
output
=
torch
.
empty_like
(
input
,
dtype
=
torch
.
float8_e4m3fn
)
if
scale
is
None
:
if
scale
is
None
:
if
use_per_token_if_dynamic
:
if
use_per_token_if_dynamic
:
scale
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
scale
=
torch
.
empty
((
shape
[
0
],
1
),
device
=
input
.
device
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
dynamic_per_token_scaled_fp8_quant
(
torch
.
ops
.
_C
.
dynamic_per_token_scaled_fp8_quant
(
...
@@ -352,6 +352,8 @@ def scaled_fp8_quant(
...
@@ -352,6 +352,8 @@ def scaled_fp8_quant(
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
torch
.
ops
.
_C
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
else
:
else
:
# num_token_padding not implemented for this case
assert
(
scale
.
numel
()
==
1
or
num_token_padding
is
None
)
torch
.
ops
.
_C
.
static_scaled_fp8_quant
(
output
,
input
,
scale
)
torch
.
ops
.
_C
.
static_scaled_fp8_quant
(
output
,
input
,
scale
)
return
output
,
scale
return
output
,
scale
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
d7a299ed
...
@@ -139,7 +139,7 @@ def apply_fp8_linear(
...
@@ -139,7 +139,7 @@ def apply_fp8_linear(
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input
,
input
,
input_scale
,
input_scale
,
batch_dim
_padding
=
17
,
num_token
_padding
=
17
,
use_per_token_if_dynamic
=
use_per_token_if_dynamic
)
use_per_token_if_dynamic
=
use_per_token_if_dynamic
)
per_tensor_weights
=
(
weight_scale
.
numel
()
==
1
)
per_tensor_weights
=
(
weight_scale
.
numel
()
==
1
)
...
@@ -177,8 +177,9 @@ def apply_fp8_linear(
...
@@ -177,8 +177,9 @@ def apply_fp8_linear(
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
weight
,
weight
,
out_dtype
=
torch
.
float32
)
out_dtype
=
torch
.
float32
)
# Unpad (undo batch_dim_padding)
# Unpad (undo num_token_padding)
output
=
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
output
=
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
x_scale
=
torch
.
narrow
(
x_scale
,
0
,
0
,
input
.
shape
[
0
])
# DQ
# DQ
# C = sw * sx * (X * W) + bias
# C = sw * sx * (X * W) + bias
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment