Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
226688bd
Unverified
Commit
226688bd
authored
Oct 29, 2024
by
Michael Goin
Committed by
GitHub
Oct 29, 2024
Browse files
[Bugfix][VLM] Make apply_fp8_linear work with >2D input (#9812)
parent
64cb1cdc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
13 deletions
+20
-13
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+20
-13
No files found.
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
226688bd
...
@@ -96,21 +96,26 @@ def apply_fp8_linear(
...
@@ -96,21 +96,26 @@ def apply_fp8_linear(
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
# If static, layer.input_scale is scalar and x_scale is input_scale.
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[1]]
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if cutlass_fp8_supported:
if cutlass_fp8_supported:
qinput, x_scale = ops.scaled_fp8_quant(
qinput, x_scale = ops.scaled_fp8_quant(
input,
input_2d,
input_scale,
input_scale,
scale_ub=input_scale_ub,
scale_ub=input_scale_ub,
use_per_token_if_dynamic=use_per_token_if_dynamic)
use_per_token_if_dynamic=use_per_token_if_dynamic)
# Fused GEMM_DQ
# Fused GEMM_DQ
return ops.cutlass_scaled_mm(qinput,
output = ops.cutlass_scaled_mm(qinput,
weight,
weight,
out_dtype=input.dtype,
out_dtype=input.dtype,
scale_a=x_scale,
scale_a=x_scale,
scale_b=weight_scale,
scale_b=weight_scale,
bias=bias)
bias=bias)
return output.view(*output_shape)
# torch.scaled_mm supports per tensor weights + activations only
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
# so fallback to naive if per channel or per token
...
@@ -119,7 +124,7 @@ def apply_fp8_linear(
...
@@ -119,7 +124,7 @@ def apply_fp8_linear(
# for matrices with batch dimension > 16.
# for matrices with batch dimension > 16.
# This could change in the future.
# This could change in the future.
qinput, x_scale = ops.scaled_fp8_quant(
qinput, x_scale = ops.scaled_fp8_quant(
input,
input_2d,
input_scale,
input_scale,
num_token_padding=17,
num_token_padding=17,
use_per_token_if_dynamic=use_per_token_if_dynamic)
use_per_token_if_dynamic=use_per_token_if_dynamic)
...
@@ -138,8 +143,10 @@ def apply_fp8_linear(
...
@@ -138,8 +143,10 @@ def apply_fp8_linear(
# A fix for discrepancy in scaled_mm which returns tuple
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
if type(output) is tuple and len(output) == 2:
return torch.narrow(output[0], 0, 0, input.shape[0])
output = output[0]
return torch.narrow(output, 0, 0, input.shape[0])
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
else
:
else
:
# Fallback for channelwise case, where we use unfused DQ
# Fallback for channelwise case, where we use unfused DQ
...
@@ -176,15 +183,15 @@ def apply_fp8_linear(
...
@@ -176,15 +183,15 @@ def apply_fp8_linear(
if type(output) is tuple and len(output) == 2:
if type(output) is tuple and len(output) == 2:
output = output[0]
output = output[0]
# Unpad (undo num_token_padding)
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input.shape[0])
output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])
x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
# DQ
# DQ
# C = sw * sx * (X * W) + bias
# C = sw * sx * (X * W) + bias
output = output * x_scale * weight_scale.t()
output = output * x_scale * weight_scale.t()
if bias is not None:
if bias is not None:
output = output + bias
output = output + bias
return output.to(dtype=input.dtype)
return output.to(dtype=input.dtype).view(*output_shape)
def apply_int8_linear(
def apply_int8_linear(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment