Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4cc24f01
Unverified
Commit
4cc24f01
authored
Jul 19, 2024
by
Robert Shaw
Committed by
GitHub
Jul 19, 2024
Browse files
[ Kernel ] Enable Dynamic Per Token `fp8` (#6547)
parent
07eb6f19
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
68 additions
and
39 deletions
+68
-39
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
...a-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
+11
-0
.buildkite/lm-eval-harness/configs/models-small.txt
.buildkite/lm-eval-harness/configs/models-small.txt
+1
-0
tests/kernels/test_fp8_quant.py
tests/kernels/test_fp8_quant.py
+2
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+13
-13
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+2
-1
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+2
-1
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+37
-23
No files found.
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
0 → 100644
View file @
4cc24f01
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors -b auto -l 1000 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.769
  - name: "exact_match,flexible-extract"
    value: 0.769
limit: 1000
num_fewshot: 5
.buildkite/lm-eval-harness/configs/models-small.txt
View file @
4cc24f01
...
...
@@ -3,4 +3,5 @@ Meta-Llama-3-8B-Instruct-FP8.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
tests/kernels/test_fp8_quant.py
View file @
4cc24f01
...
...
@@ -27,7 +27,8 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
device
=
"cuda"
)
+
1e-6
# avoid nans
ref_out
,
ref_scales
=
ref_dynamic_per_token_quant
(
x
,
torch
.
float8_e4m3fn
)
ops_out
,
ops_scales
=
ops
.
dynamic_per_token_scaled_fp8_quant
(
x
)
ops_out
,
ops_scales
=
ops
.
scaled_fp8_quant
(
x
,
use_per_token_if_dynamic
=
True
)
assert
torch
.
allclose
(
ref_scales
,
ops_scales
)
assert
torch
.
allclose
(
ref_out
.
to
(
dtype
=
torch
.
float32
),
...
...
vllm/_custom_ops.py
View file @
4cc24f01
...
...
@@ -300,6 +300,7 @@ def scaled_fp8_quant(
input
:
torch
.
Tensor
,
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_dim_padding
:
Optional
[
int
]
=
None
,
use_per_token_if_dynamic
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
...
...
@@ -315,6 +316,8 @@ def scaled_fp8_quant(
scale: Optional scaling factor for the FP8 quantization
batch_dim_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
...
...
@@ -328,22 +331,19 @@ def scaled_fp8_quant(
else
:
output
=
torch
.
empty_like
(
input
,
dtype
=
torch
.
float8_e4m3fn
)
if
scale
is
None
:
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
if
use_per_token_if_dynamic
:
scale
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
device
=
input
.
device
,
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
dynamic_per_token_scaled_fp8_quant
(
output
,
input
,
scale
)
else
:
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
else
:
torch
.
ops
.
_C
.
static_scaled_fp8_quant
(
output
,
input
,
scale
)
return
output
,
scale
def
dynamic_per_token_scaled_fp8_quant
(
input
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
output
=
torch
.
empty_like
(
input
,
dtype
=
torch
.
float8_e4m3fn
)
scales
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
device
=
input
.
device
,
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
dynamic_per_token_scaled_fp8_quant
(
output
,
input
,
scales
)
return
output
,
scales
return
output
,
scale
# int8
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
4cc24f01
...
...
@@ -103,4 +103,5 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
)
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
,
use_per_token_if_dynamic
=
True
)
vllm/model_executor/layers/quantization/fp8.py
View file @
4cc24f01
...
...
@@ -214,7 +214,8 @@ class Fp8LinearMethod(LinearMethodBase):
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
)
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
,
use_per_token_if_dynamic
=
False
)
class
Fp8MoEMethod
(
FusedMoEMethodBase
):
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
4cc24f01
...
...
@@ -107,31 +107,43 @@ def apply_fp8_linear(
input_scale
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
cutlass_fp8_supported
:
bool
=
True
,
use_per_token_if_dynamic
:
bool
=
False
,
)
->
torch
.
Tensor
:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if
cutlass_fp8_supported
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input
,
input_scale
)
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input
,
input_scale
,
use_per_token_if_dynamic
=
use_per_token_if_dynamic
)
# Fused GEMM_DQ
output
=
ops
.
cutlass_scaled_mm
(
qinput
,
weight
,
out_dtype
=
input
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
bias
=
bias
)
return
ops
.
cutlass_scaled_mm
(
qinput
,
weight
,
out_dtype
=
input
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
bias
=
bias
)
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
else
:
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
# This could change in the future.
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input
,
input_scale
,
batch_dim_padding
=
17
)
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input
,
input_scale
,
batch_dim_padding
=
17
,
use_per_token_if_dynamic
=
use_per_token_if_dynamic
)
per_tensor_weights
=
(
weight_scale
.
numel
()
==
1
)
per_tensor_activations
=
(
x_scale
.
numel
()
==
1
)
if
weight_scale
.
numel
()
==
1
:
if
per_tensor_weights
and
per_tensor_activations
:
# Fused GEMM_DQ
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
weight
,
...
...
@@ -139,9 +151,11 @@ def apply_fp8_linear(
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
bias
=
bias
)
return
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
else
:
# Fallback for channelwise case, where the weight scales are
# applied separately.
# Fallback for channelwise case, where we use unfused DQ
# due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
...
...
@@ -155,21 +169,21 @@ def apply_fp8_linear(
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# This computes C = sx * (X * W).
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
weight
,
out_dtype
=
torch
.
float32
,
scale_a
=
x_scale
)
out_dtype
=
torch
.
float32
)
# Unpad (undo batch_dim_padding)
output
=
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
# C = sw * sx * (X * W)
output
=
output
*
weight_scale
.
t
()
# DQ
# C = sw * sx * (X * W) + bias
output
=
output
*
x_scale
*
weight_scale
.
t
()
if
bias
is
not
None
:
# C = sw * sx * (X * W) + bias
output
=
output
+
bias
output
=
output
.
to
(
dtype
=
input
.
dtype
)
return
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
return
output
.
to
(
dtype
=
input
.
dtype
)
def
apply_int8_linear
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment