Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4ffffccb
Unverified
Commit
4ffffccb
authored
Jul 18, 2024
by
Tyler Michael Smith
Committed by
GitHub
Jul 18, 2024
Browse files
[Kernel] Implement fallback for FP8 channelwise using torch._scaled_mm (#6552)
parent
f53b8f0d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
21 deletions
+40
-21
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+0
-11
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+40
-10
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
4ffffccb
...
...
@@ -23,16 +23,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
# On Lovelace, fail for now if channelwise.
# TODO: (@tms) fallback
if
(
not
self
.
cutlass_fp8_supported
and
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
):
raise
ValueError
(
"Channelwise fp8 quantization requires vLLM's custom "
"cutlass kernels, which are not supported on your device."
"Consider quantizing with per tensor scales or upgrading "
"to Hopper."
)
def
get_min_capability
(
self
)
->
int
:
# lovelace and up
return
89
...
...
@@ -53,7 +43,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
# If channelwise, scales are already lined up, so just transpose.
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
assert
self
.
cutlass_fp8_supported
weight
=
layer
.
weight
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
4ffffccb
...
...
@@ -124,20 +124,50 @@ def apply_fp8_linear(
bias
=
bias
)
else
:
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
# This could change in the future.
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input
,
input_scale
,
batch_dim_padding
=
17
)
# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
if
weight_scale
.
numel
()
==
1
:
# Fused GEMM_DQ
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
weight
,
out_dtype
=
input
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
bias
=
bias
)
else
:
# Fallback for channelwise case, where the weight scales are
# applied separately.
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# This computes C = sx * (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
weight
,
out_dtype
=
torch
.
float32
,
scale_a
=
x_scale
)
# C = sw * sx * (X * W)
output
=
output
*
weight_scale
.
t
()
if
bias
is
not
None
:
# C = sw * sx * (X * W) + bias
output
=
output
+
bias
output
=
output
.
to
(
dtype
=
input
.
dtype
)
return
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment