OpenDAS / AutoAWQ / Commits / 09c73fb2

Commit 09c73fb2 (unverified), authored Nov 14, 2023 by Casper, committed via GitHub on Nov 14, 2023

Fix multi-GPU loading and inference (#190)

Parent: 299c460b

Showing 6 changed files with 83 additions and 43 deletions.
awq/models/auto.py           +2   -2
awq/models/base.py           +22  -24
awq/modules/fused/block.py   +6   -3
awq/modules/fused/model.py   +43  -13
awq/modules/linear.py        +1   -1
awq/utils/fused_utils.py     +9   -0
awq/models/auto.py (view file @ 09c73fb2)

@@ -45,13 +45,13 @@ class AutoAWQForCausalLM:
     def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
                        trust_remote_code=True, fuse_layers=True,
                        batch_size=1, safetensors=True,
-                       max_memory=None, offload_folder=None,
+                       device_map="balanced", offload_folder=None,
                        **config_kwargs) -> BaseAWQForCausalLM:
         os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
         model_type = check_and_get_model_type(quant_path, trust_remote_code)

         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
             quant_path, model_type, quant_filename, max_new_tokens,
             trust_remote_code=trust_remote_code, fuse_layers=fuse_layers, safetensors=safetensors,
-            max_memory=max_memory, offload_folder=offload_folder,
+            device_map=device_map, offload_folder=offload_folder,
             **config_kwargs
         )
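
For context, a minimal usage sketch of the updated entry point; the checkpoint path is only a placeholder and not part of this commit:

from awq import AutoAWQForCausalLM

# Hypothetical AWQ checkpoint path used purely for illustration.
model = AutoAWQForCausalLM.from_quantized(
    "TheBloke/Mistral-7B-OpenOrca-AWQ",
    fuse_layers=True,
    safetensors=True,
    device_map="balanced",   # new argument: replaces the old max_memory parameter
)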
awq/models/base.py (view file @ 09c73fb2)

@@ -14,8 +14,12 @@ from transformers.modeling_utils import shard_checkpoint
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 from awq.utils.module import get_named_linears, set_op_by_name
 from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
-from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
+from accelerate.big_modeling import (
+    init_empty_weights,
+    infer_auto_device_map,
+    load_checkpoint_and_dispatch,
+)
+from accelerate.utils import get_balanced_memory

 class BaseAWQForCausalLM(nn.Module):
     def __init__(self, model, model_type, is_quantized, quant_config):

@@ -109,9 +113,17 @@ class BaseAWQForCausalLM(nn.Module):
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

         # Evenly distribute memory on GPUs
         max_memory = get_balanced_memory(
             model,
             no_split_module_classes=[self.layer_type],
             dtype=torch_dtype
         )

         # Get device map
         device_map = infer_auto_device_map(
             model,
             max_memory=max_memory,
             no_split_module_classes=[self.layer_type],
             dtype=torch_dtype
         )

@@ -123,6 +135,7 @@ class BaseAWQForCausalLM(nn.Module):
             trust_remote_code=trust_remote_code,
             torch_dtype=torch_dtype,
             use_safetensors=safetensors,
+            device_map=device_map,
             **model_init_kwargs
         )

@@ -135,7 +148,7 @@ class BaseAWQForCausalLM(nn.Module):
                        max_new_tokens=None, torch_dtype=torch.float16,
                        trust_remote_code=True, safetensors=True, is_quantized=True,
                        fuse_layers=False, version='GEMM',
-                       max_memory=None, offload_folder=None,
+                       device_map="balanced", offload_folder=None,
                        **config_kwargs):
         # [STEP 1-2] Load weights path and configs
         model_weights_path, config, quant_config = self._load_config(

@@ -153,36 +166,21 @@ class BaseAWQForCausalLM(nn.Module):
         model.tie_weights()

-        # Get device map
-        device_map = infer_auto_device_map(
-            model,
-            no_split_module_classes=[self.layer_type],
-            max_memory=max_memory,
-            dtype=torch_dtype
-        )
-
-        # Load checkpoint
-        load_checkpoint_in_model(
+        # loads the weights into modules and distributes
+        # across available devices automatically
+        load_checkpoint_and_dispatch(
             model,
             checkpoint=model_weights_path,
             device_map=device_map,
+            no_split_module_classes=[self.layer_type],
             offload_folder=offload_folder,
-            dtype=torch_dtype
+            dtype=torch_dtype,
         )

         # Dispath to devices
         if fuse_layers:
             self.fuse_layers(model)

-        # Offloading dispatch
-        from accelerate import dispatch_model
-        model = dispatch_model(
-            model,
-            device_map=device_map,
-            offload_dir=offload_folder
-        )
-
         return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

     def _load_config(self, model_path, model_filename, safetensors=True,
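
For background, the balanced-placement pattern used above can be sketched on a toy module as follows; the toy model, layer sizes, and dtype are assumptions for illustration (not AutoAWQ code), and at least one visible CUDA device is assumed:

import torch
import torch.nn as nn
from accelerate import init_empty_weights, infer_auto_device_map
from accelerate.utils import get_balanced_memory

# Toy stand-in for a decoder-only LM: a stack of blocks that should never be split across devices.
class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1024, 1024)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(ToyBlock() for _ in range(8))

with init_empty_weights():
    model = ToyModel()  # parameters live on the meta device, so no real memory is allocated

# Compute a per-device memory budget that spreads the model evenly...
max_memory = get_balanced_memory(
    model,
    no_split_module_classes=["ToyBlock"],
    dtype=torch.float16,
)
# ...and turn that budget into a module -> device mapping.
device_map = infer_auto_device_map(
    model,
    max_memory=max_memory,
    no_split_module_classes=["ToyBlock"],
    dtype=torch.float16,
)
print(device_map)

This is the same recipe the new from_pretrained path follows before handing the device map to transformers, and load_checkpoint_and_dispatch applies an analogous map while loading and placing the quantized weights.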
awq/modules/fused/block.py (view file @ 09c73fb2)

@@ -19,6 +19,7 @@ class LlamaLikeBlock(nn.Module):
         ).to(dev)
         self.norm_2 = norm_2.to(dev)
         self.mlp = mlp.to(dev)
+        self.device = dev

     def forward(
         self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None

@@ -30,7 +31,7 @@ class LlamaLikeBlock(nn.Module):
             attention_mask=attention_mask
         )
-        h = hidden_states + attn_output
+        h = hidden_states.to(attn_output.device) + attn_output
         out = h + self.mlp.forward(self.norm_2(h))

         return out, None, past_key_value

@@ -48,6 +49,7 @@ class MPTBlock(nn.Module):
         ).to(dev)
         self.norm_2 = norm_2
         self.ffn = mpt_mlp.to(dev)
+        self.device = dev

     def forward(
         self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None

@@ -62,7 +64,7 @@ class MPTBlock(nn.Module):
             use_cache=True
         )
-        h = hidden_states + attn_output
+        h = hidden_states.to(attn_output.device) + attn_output
         out = h + self.ffn.forward(self.norm_2(h))

         return out, None, past_key_value

@@ -94,6 +96,7 @@ class FalconDecoderLayer(nn.Module):
         self.input_layernorm = input_layernorm # before attention
         self.mlp = mlp
+        self.device = dev

     def _get_attention_shapes(self, n_heads, max_seq_len, head_dim):
         batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))

@@ -136,7 +139,7 @@ class FalconDecoderLayer(nn.Module):
             use_cache=True
         )
-        h_attn = hidden_states + attn_output
+        h_attn = hidden_states.to(attn_output.device) + attn_output

         if self.new_decoder_arch:
             h_mlp = self.mlp.forward(mlp_layernorm_out)
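
The `.to(attn_output.device)` change guards the residual addition when consecutive blocks sit on different GPUs. A minimal reproduction of the underlying failure, assuming a machine with two CUDA devices (not AutoAWQ code):

import torch

if torch.cuda.device_count() >= 2:
    hidden_states = torch.randn(1, 16, 64, device="cuda:0")
    attn_output = torch.randn(1, 16, 64, device="cuda:1")

    # Before the fix: adding tensors that live on different GPUs raises a
    # RuntimeError ("Expected all tensors to be on the same device").
    try:
        h = hidden_states + attn_output
    except RuntimeError as err:
        print(f"device mismatch: {err}")

    # After the fix: move the residual onto the attention output's device first.
    h = hidden_states.to(attn_output.device) + attn_output
    print(h.device)  # cuda:1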
awq/modules/fused/model.py (view file @ 09c73fb2)

 import torch
 import torch.nn as nn
 from typing import List
+from awq.utils import fused_utils
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock
-from awq.utils.fused_utils import prepare_attention_mask, prepare_input_ids, prepare_cache

 class LlamaLikeModel(nn.Module):
     """

@@ -20,17 +20,17 @@ class LlamaLikeModel(nn.Module):
     @torch.inference_mode()
     def forward(self, input_ids: torch.Tensor, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
-        input_ids, self.last_forward_num_tokens = prepare_input_ids(
+        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
             input_ids,
             self.last_forward_num_tokens
         )
         _bsz, seqlen = input_ids.shape

-        prepare_cache(self.blocks, seqlen)
+        fused_utils.prepare_cache(self.blocks, seqlen)

         h = self.embedding(input_ids)

-        mask = prepare_attention_mask(
+        mask = fused_utils.prepare_attention_mask(
             seqlen=seqlen,
             start_pos=self.blocks[0].attn.start_pos,
             device=input_ids.device,

@@ -38,7 +38,17 @@ class LlamaLikeModel(nn.Module):
         )

         for layer in self.blocks:
+            h, mask = fused_utils.prepare_correct_devices(
+                layer,
+                h,
+                mask,
+            )
             h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
         h = self.norm(h)

         return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())

@@ -56,17 +66,17 @@ class MPTModel(nn.Module):
     @torch.inference_mode()
     def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
-        input_ids, self.last_forward_num_tokens = prepare_input_ids(
+        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
             input_ids,
             self.last_forward_num_tokens
         )
         _bsz, seqlen = input_ids.shape

-        prepare_cache(self.blocks, seqlen)
+        fused_utils.prepare_cache(self.blocks, seqlen)

         h = self.wte(input_ids)

-        mask = prepare_attention_mask(
+        mask = fused_utils.prepare_attention_mask(
             seqlen=seqlen,
             start_pos=self.blocks[0].attn.start_pos,
             device=input_ids.device,

@@ -74,7 +84,17 @@ class MPTModel(nn.Module):
         )

         for layer in self.blocks:
+            h, mask = fused_utils.prepare_correct_devices(
+                layer,
+                h,
+                mask,
+            )
             h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
         h = self.norm_f(h)

         return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())

@@ -92,17 +112,17 @@ class FalconModel(nn.Module):
     @torch.inference_mode()
     def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
-        input_ids, self.last_forward_num_tokens = prepare_input_ids(
+        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
             input_ids,
             self.last_forward_num_tokens
         )
         _bsz, seqlen = input_ids.shape

-        prepare_cache(self.blocks, seqlen)
+        fused_utils.prepare_cache(self.blocks, seqlen)

         h = self.word_embeddings(input_ids)

-        mask = prepare_attention_mask(
+        mask = fused_utils.prepare_attention_mask(
             seqlen=seqlen,
             start_pos=self.blocks[0].attn.start_pos,
             device=input_ids.device,

@@ -110,7 +130,17 @@ class FalconModel(nn.Module):
         )

         for layer in self.blocks:
+            h, mask = fused_utils.prepare_correct_devices(
+                layer,
+                h,
+                mask,
+            )
             h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
         h = self.ln_f(h)

         return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
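
The per-layer prepare_correct_devices call follows the usual pipeline-style rule: when blocks are spread across devices, activations must be moved onto the device of the block they are about to enter. A self-contained sketch of that pattern; the toy blocks and device choices are assumptions, not AutoAWQ code:

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, dev):
        super().__init__()
        self.fc = nn.Linear(32, 32).to(dev)
        self.device = torch.device(dev)  # same convention the fused blocks now expose

    def forward(self, h):
        return self.fc(h)

# Fall back to CPU so the sketch runs anywhere; with 2+ GPUs this alternates cuda:0 / cuda:1.
devices = ["cuda:0", "cuda:1"] if torch.cuda.device_count() >= 2 else ["cpu", "cpu"]
blocks = [ToyBlock(devices[i % len(devices)]) for i in range(4)]

h = torch.randn(1, 8, 32)
for block in blocks:
    h = h.to(block.device)   # move activations onto the next block's device first
    h = block(h)
print(h.shape, h.device)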
awq/modules/linear.py (view file @ 09c73fb2) (diff not expanded in this view)
awq/utils/fused_utils.py (view file @ 09c73fb2)

 import torch
 from typing import List

 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

+def prepare_correct_devices(next_layer, hidden_states, mask):
+    hidden_states = hidden_states.to(next_layer.device)
+
+    if mask is not None:
+        mask = mask.to(next_layer.device)
+
+    return hidden_states, mask
+
 def prepare_cache(blocks, seqlen: int) -> int:
     for block in blocks:
         start_pos = block.attn.start_pos
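
A quick check of the new helper in isolation; the stand-in block with a device attribute is hypothetical, but any module exposing .device works, which is exactly what the block.py changes above provide:

import torch
from types import SimpleNamespace
from awq.utils.fused_utils import prepare_correct_devices

block = SimpleNamespace(device=torch.device("cpu"))  # stand-in for a fused block
h = torch.randn(1, 4, 16)

# The mask can legitimately be None (e.g. single-token decode steps),
# so the helper must tolerate that.
h, mask = prepare_correct_devices(block, h, None)
print(h.device, mask)  # cpu None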