Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9855b995
Unverified
Commit
9855b995
authored
Sep 17, 2024
by
chenqianfzh
Committed by
GitHub
Sep 17, 2024
Browse files
[Feature][kernel] tensor parallelism with bitsandbytes quantization (#8434)
parent
1009e93c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
80 additions
and
17 deletions
+80
-17
tests/quantization/test_bitsandbytes.py
tests/quantization/test_bitsandbytes.py
+21
-5
vllm/config.py
vllm/config.py
+0
-6
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+16
-5
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+43
-1
No files found.
tests/quantization/test_bitsandbytes.py
View file @
9855b995
...
@@ -64,6 +64,24 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
...
@@ -64,6 +64,24 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
model_name
)
model_name
)
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
reason
=
'Test requires at least 2 GPUs.'
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"bitsandbytes"
),
reason
=
'bitsandbytes is not supported on this GPU type.'
)
@
pytest
.
mark
.
parametrize
(
"model_name, description"
,
models_4bit_to_test
)
@
fork_new_process_for_each_test
def
test_load_tp_4bit_bnb_model
(
hf_runner
,
vllm_runner
,
example_prompts
,
model_name
,
description
)
->
None
:
hf_model_kwargs
=
{
"load_in_4bit"
:
True
}
validate_generated_texts
(
hf_runner
,
vllm_runner
,
example_prompts
[:
1
],
model_name
,
hf_model_kwargs
,
vllm_tp_size
=
2
)
def
log_generated_texts
(
prompts
,
outputs
,
runner_name
):
def
log_generated_texts
(
prompts
,
outputs
,
runner_name
):
logged_texts
=
[]
logged_texts
=
[]
for
i
,
(
_
,
generated_text
)
in
enumerate
(
outputs
):
for
i
,
(
_
,
generated_text
)
in
enumerate
(
outputs
):
...
@@ -80,22 +98,21 @@ def validate_generated_texts(hf_runner,
...
@@ -80,22 +98,21 @@ def validate_generated_texts(hf_runner,
vllm_runner
,
vllm_runner
,
prompts
,
prompts
,
model_name
,
model_name
,
hf_model_kwargs
=
None
):
hf_model_kwargs
=
None
,
vllm_tp_size
=
1
):
# NOTE: run vLLM first, as it requires a clean process
# NOTE: run vLLM first, as it requires a clean process
# when using distributed inference
# when using distributed inference
#Run with vLLM runner
with
vllm_runner
(
model_name
,
with
vllm_runner
(
model_name
,
quantization
=
'bitsandbytes'
,
quantization
=
'bitsandbytes'
,
load_format
=
'bitsandbytes'
,
load_format
=
'bitsandbytes'
,
tensor_parallel_size
=
vllm_tp_size
,
enforce_eager
=
True
,
enforce_eager
=
True
,
gpu_memory_utilization
=
0.8
)
as
llm
:
gpu_memory_utilization
=
0.8
)
as
llm
:
vllm_outputs
=
llm
.
generate_greedy
(
prompts
,
8
)
vllm_outputs
=
llm
.
generate_greedy
(
prompts
,
8
)
vllm_logs
=
log_generated_texts
(
prompts
,
vllm_outputs
,
"VllmRunner"
)
vllm_logs
=
log_generated_texts
(
prompts
,
vllm_outputs
,
"VllmRunner"
)
# Clean up the GPU memory for the next test
# Clean up the GPU memory for the next test
torch
.
cuda
.
synchronize
()
gc
.
collect
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
...
@@ -108,7 +125,6 @@ def validate_generated_texts(hf_runner,
...
@@ -108,7 +125,6 @@ def validate_generated_texts(hf_runner,
hf_logs
=
log_generated_texts
(
prompts
,
hf_outputs
,
"HfRunner"
)
hf_logs
=
log_generated_texts
(
prompts
,
hf_outputs
,
"HfRunner"
)
# Clean up the GPU memory for the next test
# Clean up the GPU memory for the next test
torch
.
cuda
.
synchronize
()
gc
.
collect
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
...
...
vllm/config.py
View file @
9855b995
...
@@ -393,12 +393,6 @@ class ModelConfig:
...
@@ -393,12 +393,6 @@ class ModelConfig:
"Pipeline parallelism is only supported for the following "
"Pipeline parallelism is only supported for the following "
f
" architectures:
{
_PP_SUPPORTED_MODELS
}
."
)
f
" architectures:
{
_PP_SUPPORTED_MODELS
}
."
)
if
self
.
quantization
==
"bitsandbytes"
and
(
parallel_config
.
tensor_parallel_size
>
1
or
parallel_config
.
pipeline_parallel_size
>
1
):
raise
ValueError
(
"BitAndBytes quantization with TP or PP is not supported yet."
)
# Remove the constraint after the bitsandbytes issue is fixed:
# Remove the constraint after the bitsandbytes issue is fixed:
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308
if
self
.
quantization
==
"bitsandbytes"
and
self
.
enforce_eager
is
False
:
if
self
.
quantization
==
"bitsandbytes"
and
self
.
enforce_eager
is
False
:
...
...
vllm/model_executor/layers/linear.py
View file @
9855b995
...
@@ -530,8 +530,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -530,8 +530,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
shard_size
)
start_idx
=
tp_rank
*
shard_size
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
# bitsandbytes loads the weights of the specific portion
shard_size
)
# no need to narrow here
if
not
use_bitsandbytes_4bit
:
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
# Special case for AQLM codebooks.
# Special case for AQLM codebooks.
elif
is_metadata
:
elif
is_metadata
:
# metadata indicates fixed size concatenated along dim 0
# metadata indicates fixed size concatenated along dim 0
...
@@ -899,8 +902,13 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -899,8 +902,13 @@ class QKVParallelLinear(ColumnParallelLinear):
else
:
else
:
shard_id
=
tp_rank
//
self
.
num_kv_head_replicas
shard_id
=
tp_rank
//
self
.
num_kv_head_replicas
start_idx
=
shard_id
*
shard_size
start_idx
=
shard_id
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if
not
use_bitsandbytes_4bit
:
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
# Special case for for AQLM codebooks.
# Special case for for AQLM codebooks.
elif
is_metadata
:
elif
is_metadata
:
# metadata indicates fixed size concatenated along dim 0
# metadata indicates fixed size concatenated along dim 0
...
@@ -1000,6 +1008,7 @@ class RowParallelLinear(LinearBase):
...
@@ -1000,6 +1008,7 @@ class RowParallelLinear(LinearBase):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
# Special case for GGUF
# Special case for GGUF
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
...
@@ -1015,7 +1024,9 @@ class RowParallelLinear(LinearBase):
...
@@ -1015,7 +1024,9 @@ class RowParallelLinear(LinearBase):
param
.
materialize
(
tuple
(
weight_shape
),
dtype
=
loaded_weight
.
dtype
)
param
.
materialize
(
tuple
(
weight_shape
),
dtype
=
loaded_weight
.
dtype
)
param_data
=
param
.
data
param_data
=
param
.
data
if
input_dim
is
not
None
:
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if
input_dim
is
not
None
and
not
use_bitsandbytes_4bit
:
shard_size
=
param_data
.
shape
[
input_dim
]
shard_size
=
param_data
.
shape
[
input_dim
]
start_idx
=
tp_rank
*
shard_size
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
input_dim
,
start_idx
,
loaded_weight
=
loaded_weight
.
narrow
(
input_dim
,
start_idx
,
...
...
vllm/model_executor/model_loader/loader.py
View file @
9855b995
...
@@ -22,6 +22,8 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
...
@@ -22,6 +22,8 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoadFormat
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
SchedulerConfig
)
ParallelConfig
,
SchedulerConfig
)
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
...
@@ -689,6 +691,8 @@ class ShardedStateLoader(BaseModelLoader):
...
@@ -689,6 +691,8 @@ class ShardedStateLoader(BaseModelLoader):
class
BitsAndBytesModelLoader
(
BaseModelLoader
):
class
BitsAndBytesModelLoader
(
BaseModelLoader
):
"""Model loader to load model weights with BitAndBytes quantization."""
"""Model loader to load model weights with BitAndBytes quantization."""
# TODO: these module names are for Llama only,
# change so that it works with other models as well
default_target_modules
=
[
default_target_modules
=
[
"gate_proj"
,
"down_proj"
,
"up_proj"
,
"q_proj"
,
"k_proj"
,
"v_proj"
,
"gate_proj"
,
"down_proj"
,
"up_proj"
,
"q_proj"
,
"k_proj"
,
"v_proj"
,
"o_proj"
"o_proj"
...
@@ -911,13 +915,44 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -911,13 +915,44 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def
_unquantized_generator
(
self
,
hf_weights_files
,
use_safetensors
,
def
_unquantized_generator
(
self
,
hf_weights_files
,
use_safetensors
,
quant_state_dict
)
->
Generator
:
quant_state_dict
)
->
Generator
:
from
bitsandbytes.functional
import
quantize_4bit
from
bitsandbytes.functional
import
quantize_4bit
tp_size
=
get_tensor_model_parallel_world_size
()
tp_rank
=
get_tensor_model_parallel_rank
()
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
hf_weights_files
,
use_safetensors
):
if
any
(
target_module
in
weight_name
if
any
(
target_module
in
weight_name
for
target_module
in
self
.
target_modules
):
for
target_module
in
self
.
target_modules
):
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
# weight partitions of different modules occur at
# different dimensions
# TODO: these module names are for Llama only,
# change so that it works with other models as well
if
'down_proj'
in
weight_name
or
'o_proj'
in
weight_name
:
total_size
=
weight_tensor
.
size
(
-
1
)
start_index
=
total_size
//
tp_size
*
tp_rank
end_index
=
total_size
//
tp_size
*
(
tp_rank
+
1
)
weight_sub_tensor
=
weight_tensor
[...,
start_index
:
end_index
]
else
:
total_size
=
weight_tensor
.
size
(
0
)
start_index
=
total_size
//
tp_size
*
tp_rank
end_index
=
total_size
//
tp_size
*
(
tp_rank
+
1
)
weight_sub_tensor
=
weight_tensor
[
start_index
:
end_index
,
...]
# bitsandbytes requires data in GPU
# bitsandbytes requires data in GPU
loaded_weight
=
weight_tensor
.
cuda
().
data
if
weight_sub_tensor
.
is_cuda
:
loaded_weight
=
weight_sub_tensor
else
:
loaded_weight
=
weight_sub_tensor
.
cuda
()
# remove the following after the issue is fixed:
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
if
loaded_weight
.
is_contiguous
()
is
False
:
loaded_weight
=
loaded_weight
.
contiguous
()
with
set_default_torch_dtype
(
torch
.
float32
):
with
set_default_torch_dtype
(
torch
.
float32
):
processed_weight
,
quant_state
=
quantize_4bit
(
processed_weight
,
quant_state
=
quantize_4bit
(
loaded_weight
,
loaded_weight
,
...
@@ -958,6 +993,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -958,6 +993,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
f
"BitsAndBytes loader does not support
{
quant_method
}
"
f
"BitsAndBytes loader does not support
{
quant_method
}
"
"quantization"
)
"quantization"
)
# The quant_states in pre_quantized models cannot work with a split
# weight tensor. So TP does not work with pre_quantized bnb models.
if
pre_quant
and
get_tensor_model_parallel_world_size
()
>
1
:
raise
ValueError
(
"Prequant BitsAndBytes models with TP is not supported."
"Please try with PP."
)
load_8bit
=
False
load_8bit
=
False
if
pre_quant
:
if
pre_quant
:
load_8bit
=
quant_config
.
get
(
'load_in_8bit'
,
False
)
load_8bit
=
quant_config
.
get
(
'load_in_8bit'
,
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment