Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b6374e09
Unverified
Commit
b6374e09
authored
Nov 22, 2024
by
Isotr0py
Committed by
GitHub
Nov 22, 2024
Browse files
[Bugfix] Fix Phi-3 BNB quantization with tensor parallel (#9948)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
a111d015
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
56 additions
and
6 deletions
+56
-6
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+14
-5
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+42
-1
No files found.
vllm/model_executor/layers/linear.py
View file @
b6374e09
import
itertools
from
abc
import
abstractmethod
from
typing
import
Dict
,
List
,
Optional
,
Tuple
...
...
@@ -41,12 +42,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
def
adjust_bitsandbytes_4bit_shard
(
param
:
Parameter
,
qkv
_offsets
:
Dict
[
str
,
Tuple
[
int
,
int
]],
shard
_offsets
:
Dict
[
str
,
Tuple
[
int
,
int
]],
loaded_shard_id
:
str
)
->
Tuple
[
int
,
int
]:
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
total
,
_
=
qkv
_offsets
[
"total"
]
orig_offset
,
orig_size
=
qkv
_offsets
[
loaded_shard_id
]
total
,
_
=
shard
_offsets
[
"total"
]
orig_offset
,
orig_size
=
shard
_offsets
[
loaded_shard_id
]
quantized_total
=
param
.
data
.
shape
[
0
]
quantized_offset
=
orig_offset
*
quantized_total
//
total
...
...
@@ -499,9 +500,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for Marlin.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
if
use_bitsandbytes_4bit
:
shard_size
=
loaded_weight
.
shape
[
output_dim
]
//
2
shard_offset
=
shard_size
*
shard_id
index
=
list
(
itertools
.
accumulate
([
0
]
+
self
.
output_sizes
))
orig_offsets
=
{
str
(
i
):
(
index
[
i
],
size
)
for
i
,
size
in
enumerate
(
self
.
output_sizes
)
}
orig_offsets
[
"total"
]
=
(
self
.
output_size
,
0
)
shard_size
,
shard_offset
=
adjust_bitsandbytes_4bit_shard
(
param
,
orig_offsets
,
str
(
shard_id
))
loaded_weight_shard
=
loaded_weight
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
self
.
weight_loader
(
param
,
loaded_weight_shard
,
shard_id
)
...
...
vllm/model_executor/model_loader/loader.py
View file @
b6374e09
...
...
@@ -5,6 +5,7 @@ import dataclasses
import
fnmatch
import
glob
import
inspect
import
itertools
import
json
import
math
import
os
...
...
@@ -27,7 +28,9 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
ReplicatedLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizeMethodBase
)
...
...
@@ -936,6 +939,34 @@ class BitsAndBytesModelLoader(BaseModelLoader):
end_index
=
total_size
//
tp_size
*
(
tp_rank
+
1
)
weight_sub_tensor
=
weight_tensor
[...,
start_index
:
end_index
]
# Weights have fused on disk. In this case, we assume that the
# weight and module use same name.
elif
any
(
weight_name
.
startswith
(
module
)
for
module
in
self
.
maybe_fused_weights_modules
):
# special case for fused weights
# get the size of each shard weight tensor
total_shard_sizes
=
next
(
(
sizes
for
module
,
sizes
in
self
.
maybe_fused_weights_modules
.
items
()
if
weight_name
.
startswith
(
module
)))
total_size
=
weight_tensor
.
size
(
0
)
assert
total_size
==
sum
(
total_shard_sizes
)
# get the start/end index of each shard weight tensor
total_start_index
=
list
(
itertools
.
accumulate
([
0
]
+
total_shard_sizes
))[:
-
1
]
shard_weights_index
=
[
(
idx
+
size
//
tp_size
*
tp_rank
,
idx
+
size
//
tp_size
*
(
tp_rank
+
1
))
for
idx
,
size
in
zip
(
total_start_index
,
total_shard_sizes
)
]
# slice and reorder the weight tensor
weight_tensor
=
[
weight_tensor
[
start_index
:
end_index
,
...]
for
start_index
,
end_index
in
shard_weights_index
]
weight_sub_tensor
=
torch
.
cat
(
weight_tensor
,
dim
=
0
)
# Shard by row
else
:
total_size
=
weight_tensor
.
size
(
0
)
...
...
@@ -985,12 +1016,22 @@ class BitsAndBytesModelLoader(BaseModelLoader):
else
:
self
.
target_modules
=
self
.
default_target_modules
# Modules whose weights might have fused on disk
# we need their output_sizes to make shard in flight correctly with TP
self
.
maybe_fused_weights_modules
:
Dict
[
str
,
List
[
int
]]
=
{}
for
name
,
module
in
model
.
named_modules
():
# Some modules like `ReplicatedLinear` should not have their weights
# sharded. The reason for implementing it this way is to avoid new
# static variable in the model implementation.
if
isinstance
(
module
,
(
ReplicatedLinear
,
)):
self
.
unsharded_weights_modules
.
append
(
name
)
# `QKVParallelLinear` and `MergedColumnParallelLinear` might have
# fused weights on disk. We need to use the output sizes of these
# modules to shard the weights correctly.
elif
isinstance
(
module
,
(
QKVParallelLinear
,
MergedColumnParallelLinear
)):
self
.
maybe_fused_weights_modules
[
name
]
=
module
.
output_sizes
# In TP, these weights are partitioned along the column
# dimension (dim=-1)
elif
isinstance
(
module
,
(
RowParallelLinear
,
)):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment