Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2cd402e1
Unverified
Commit
2cd402e1
authored
Jun 28, 2024
by
Robert Shaw
Committed by
GitHub
Jun 28, 2024
Browse files
[ Bugfix ] Enabling Loading Models With Fused QKV/MLP on Disk with FP8 (#5921)
Co-authored-by:
Robert Shaw
<
rshaw@neuralmagic
>
parent
b1852307
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
23 deletions
+32
-23
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+12
-2
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+20
-21
No files found.
vllm/model_executor/layers/linear.py
View file @
2cd402e1
...
@@ -383,8 +383,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -383,8 +383,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
None
)
None
)
if
loaded_shard_id
is
None
:
if
loaded_shard_id
is
None
:
# Loaded weight is already
packed
.
# Loaded weight is already
fused on disk (qkv/mlp)
.
if
output_dim
is
None
:
if
output_dim
is
None
:
# If fp8 + scale, need to send to each shard.
if
fp8_scales_shard_indexer
is
not
None
:
param_data
,
loaded_weight
=
fp8_scales_shard_indexer
(
param_data
,
loaded_weight
,
loaded_shard_id
)
assert
param_data
.
shape
==
loaded_weight
.
shape
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
return
return
...
@@ -567,8 +572,13 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -567,8 +572,13 @@ class QKVParallelLinear(ColumnParallelLinear):
None
)
None
)
if
loaded_shard_id
is
None
:
if
loaded_shard_id
is
None
:
# Loaded weight is already
packed
.
# Loaded weight is already
fused on disk (qkv/mlp)
.
if
output_dim
is
None
:
if
output_dim
is
None
:
# If fp8 + scale, need to send to each shard.
if
fp8_scales_shard_indexer
is
not
None
:
param_data
,
loaded_weight
=
fp8_scales_shard_indexer
(
param_data
,
loaded_weight
,
loaded_shard_id
)
assert
param_data
.
shape
==
loaded_weight
.
shape
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
return
return
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
2cd402e1
...
@@ -98,6 +98,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -98,6 +98,7 @@ class Fp8LinearMethod(LinearMethodBase):
"""
"""
def
__init__
(
self
,
quant_config
:
Fp8Config
):
def
__init__
(
self
,
quant_config
:
Fp8Config
):
self
.
fused_module_in_checkpoint
=
False
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
...
@@ -111,6 +112,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -111,6 +112,7 @@ class Fp8LinearMethod(LinearMethodBase):
scale
=
Parameter
(
torch
.
empty
(
len
(
output_partition_sizes
),
scale
=
Parameter
(
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
dtype
=
torch
.
float32
),
requires_grad
=
False
)
requires_grad
=
False
)
scale
[:]
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
min
layer
.
register_parameter
(
scale_name
,
scale
)
layer
.
register_parameter
(
scale_name
,
scale
)
set_weight_attrs
(
set_weight_attrs
(
scale
,
{
scale
,
{
...
@@ -170,10 +172,14 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -170,10 +172,14 @@ class Fp8LinearMethod(LinearMethodBase):
def
scales_shard_indexer
(
def
scales_shard_indexer
(
self
,
param
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
self
,
param
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
shard_id
:
Union
[
str
,
int
])
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
shard_id
:
Optional
[
Union
[
str
,
int
]])
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
qkv_idxs
=
{
"q"
:
0
,
"k"
:
1
,
"v"
:
2
}
qkv_idxs
=
{
"q"
:
0
,
"k"
:
1
,
"v"
:
2
}
if
isinstance
(
shard_id
,
int
):
if
shard_id
is
None
:
shard_id
=
0
self
.
fused_module_in_checkpoint
=
True
elif
isinstance
(
shard_id
,
int
):
pass
pass
elif
isinstance
(
shard_id
,
str
):
elif
isinstance
(
shard_id
,
str
):
if
shard_id
not
in
qkv_idxs
:
if
shard_id
not
in
qkv_idxs
:
...
@@ -205,11 +211,13 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -205,11 +211,13 @@ class Fp8LinearMethod(LinearMethodBase):
# WEIGHT_SCALE / WEIGHT
# WEIGHT_SCALE / WEIGHT
# Loop over logical weights, requantizing with single scale.
# Loop over logical weights, requantizing with single scale.
max_w_scale
=
layer
.
weight_scale
.
max
()
max_w_scale
=
layer
.
weight_scale
.
max
()
if
not
self
.
fused_module_in_checkpoint
:
start
=
0
start
=
0
for
idx
,
logical_width
in
enumerate
(
layer
.
logical_widths
):
for
idx
,
logical_width
in
enumerate
(
layer
.
logical_widths
):
end
=
start
+
logical_width
end
=
start
+
logical_width
weight_dq
=
per_tensor_dequantize
(
layer
.
weight
[
start
:
end
,
:],
weight_dq
=
per_tensor_dequantize
(
layer
.
weight_scale
[
idx
])
layer
.
weight
[
start
:
end
,
:],
layer
.
weight_scale
[
idx
])
layer
.
weight
[
start
:
end
,
:]
=
per_tensor_quantize
(
layer
.
weight
[
start
:
end
,
:]
=
per_tensor_quantize
(
weight_dq
,
layer
.
weight_scale
.
max
())
weight_dq
,
layer
.
weight_scale
.
max
())
...
@@ -227,10 +235,6 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -227,10 +235,6 @@ class Fp8LinearMethod(LinearMethodBase):
if
self
.
quant_config
.
activation_scheme
==
"dynamic"
:
if
self
.
quant_config
.
activation_scheme
==
"dynamic"
:
layer
.
input_scale
=
None
layer
.
input_scale
=
None
elif
self
.
quant_config
.
activation_scheme
==
"static"
:
elif
self
.
quant_config
.
activation_scheme
==
"static"
:
if
not
all_close_1d
(
layer
.
input_scale
):
raise
ValueError
(
"All the input_scales for the logical weights of a "
f
"layer must be equal. But got
{
layer
.
input_scale
}
"
)
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
requires_grad
=
False
)
else
:
else
:
...
@@ -317,11 +321,6 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
...
@@ -317,11 +321,6 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
del
layer
.
kv_scale
del
layer
.
kv_scale
def
all_close_1d
(
x
:
torch
.
Tensor
)
->
bool
:
assert
len
(
x
.
shape
)
==
1
return
all
(
torch
.
allclose
(
x
[
0
],
x
[
i
])
for
i
in
range
(
x
.
shape
[
0
]))
def
per_tensor_quantize
(
tensor
:
torch
.
Tensor
,
def
per_tensor_quantize
(
tensor
:
torch
.
Tensor
,
inv_scale
:
Union
[
float
,
torch
.
Tensor
])
->
torch
.
Tensor
:
inv_scale
:
Union
[
float
,
torch
.
Tensor
])
->
torch
.
Tensor
:
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment