OpenDAS / AutoAWQ

Commit 63346c34, authored Aug 24, 2023 by Casper Hansen
Parent: 870a9dc9

Integrate fused modules into AWQ model loading
Showing 9 changed files with 45 additions and 11 deletions (+45, -11)
awq/models/auto.py         +2   -2
awq/models/base.py         +4   -1
awq/models/llama.py        +7   -0
awq/models/mpt.py          +5   -0
awq/modules/__init__.py    +0   -0
awq/modules/fused_attn.py  +0   -0
awq/modules/fused_mlp.py   +27  -1
awq/modules/fused_norm.py  +0   -0
tinychat/demo.py           +0   -7
awq/models/auto.py

@@ -33,9 +33,9 @@ class AutoAWQForCausalLM:
     @classmethod
     def from_quantized(self, quant_path, quant_filename, max_new_tokens=None,
-                             device='balanced', trust_remote_code=True) -> BaseAWQForCausalLM:
+                             device='balanced', trust_remote_code=True, fuse_layers=True) -> BaseAWQForCausalLM:
         model_type = check_and_get_model_type(quant_path, trust_remote_code)

         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code
+            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code, fuse_layers=fuse_layers
         )
\ No newline at end of file
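With this change, layer fusion is toggled straight from the top-level loader instead of by patching modules in a demo script. A minimal usage sketch, assuming the package exports AutoAWQForCausalLM as in the project README; the checkpoint path and filename below are placeholders, not real artifacts:

    from awq import AutoAWQForCausalLM

    # Placeholder path and filename; substitute a real AWQ-quantized checkpoint.
    model = AutoAWQForCausalLM.from_quantized(
        "path/to/quantized-model",
        "awq_model_w4_g128.pt",
        device="balanced",
        fuse_layers=True,  # flag added in this commit; defaults to True at this level
    )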
awq/models/base.py

@@ -250,7 +250,7 @@ class BaseAWQForCausalLM(nn.Module):
     @classmethod
     def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
                              device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
-                             safetensors=False, is_quantized=True):
+                             safetensors=False, is_quantized=True, fuse_layers=False):
         # [STEP 1] Download model if path is not a directory
         if not os.path.isdir(model_path):
             ignore_patterns = ["*msgpack*", "*h5*"]

@@ -298,6 +298,9 @@ class BaseAWQForCausalLM(nn.Module):
         if is_quantized:
             model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])

+            if fuse_layers:
+                self.fuse_layers(model)
+
         else:
             # If not quantized, must load with AutoModelForCausalLM
             device_map = infer_auto_device_map(
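The base class only provides the hook: after load_checkpoint_and_dispatch, from_quantized calls self.fuse_layers(model), and each model class decides what fusion means for its architecture (note the differing defaults: False here in the base class, True in the AutoAWQForCausalLM wrapper above). A simplified sketch of that dispatch pattern, using stand-in class names rather than the actual AutoAWQ classes:

    class BaseExample:
        @classmethod
        def from_quantized(cls, model, fuse_layers=False):
            # ... checkpoint loading elided ...
            if fuse_layers:
                cls.fuse_layers(model)  # resolved on the concrete subclass
            return model

        @staticmethod
        def fuse_layers(model):
            pass  # default: no fusion

    class LlamaExample(BaseExample):
        @staticmethod
        def fuse_layers(model):
            print("fuse attention, norm and MLP for Llama")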
awq/models/llama.py

 from .base import BaseAWQForCausalLM
+from awq.modules import make_quant_norm, make_quant_attn, make_fused_mlp
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "LlamaDecoderLayer"
     max_new_tokens_key = "max_position_embeddings"

+    @staticmethod
+    def fuse_layers(awq_model):
+        make_quant_attn(awq_model, awq_model.device)
+        make_quant_norm(awq_model)
+        make_fused_mlp(awq_model)
+
     @staticmethod
     def get_model_layers(model: LlamaForCausalLM):
         return model.model.layers
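make_quant_attn, make_quant_norm and make_fused_mlp (now imported from awq.modules) all follow the same in-place replacement idea: walk the module tree, build a fused module for each matching submodule, and reattach it with setattr, as the make_fused_mlp diff further down shows. A generic, hedged sketch of that traversal pattern, not the actual AutoAWQ helpers:

    import torch.nn as nn

    def replace_modules(root: nn.Module, match, build):
        # Recursively swap each child where match(child) is True with build(child).
        for name, child in root.named_children():
            if match(child):
                setattr(root, name, build(child))
            else:
                replace_modules(child, match, build)
        return root

    # Toy example: swap every nn.GELU for nn.ReLU.
    toy = nn.Sequential(nn.Linear(4, 4), nn.GELU(), nn.Linear(4, 4))
    replace_modules(toy, lambda m: isinstance(m, nn.GELU), lambda m: nn.ReLU())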
awq/models/mpt.py

 from .base import BaseAWQForCausalLM
+from awq.modules import make_fused_mlp

 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"
     max_new_tokens_key = "max_seq_len"

+    @staticmethod
+    def fuse_layers(awq_model):
+        make_fused_mlp(awq_model)
+
     @staticmethod
     def get_model_layers(model):
         return model.transformer.blocks
tinychat/modules/__init__.py → awq/modules/__init__.py (file moved)
tinychat/modules/fused_attn.py → awq/modules/fused_attn.py (file moved)
tinychat/modules/fused_mlp.py → awq/modules/fused_mlp.py

@@ -7,6 +7,27 @@ from transformers.models.llama.modeling_llama import LlamaMLP
 import awq_inference_engine

+class QuantMPTMLP(nn.Module):
+    def __init__(self, up_proj, act, down_proj):
+        super().__init__()
+
+        self.register_buffer('up_proj_qweight', up_proj.qweight)
+        self.register_buffer('up_proj_scales', up_proj.scales)
+        self.register_buffer('up_proj_qzeros', up_proj.qzeros)
+
+        self.up_proj = up_proj
+        self.act = act
+        self.down_proj = down_proj
+
+    def forward(self, x: torch.Tensor):
+        x = x.reshape(-1, x.shape[-1])
+        x = awq_inference_engine.gemm_forward_cuda(x, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, 8)
+        return self.down_proj(self.act(x))
+
 class QuantLlamaMLP(nn.Module):

@@ -57,10 +78,15 @@ def make_fused_mlp(m, parent_name=''):
     """
     if isinstance(m, LlamaMLP):
         return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj)
+    elif "mptmlp" in str(m.__class__).lower():
+        return QuantMPTMLP(m.up_proj, m.act, m.down_proj)

     for name, child in m.named_children():
         child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")

         if isinstance(child, QuantLlamaMLP):
             setattr(m, name, child)
-    return m
+        elif isinstance(child, QuantMPTMLP):
+            setattr(m, name, child)
+    return m
\ No newline at end of file
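Unlike the Llama branch, the MPT branch matches the MLP by class name rather than with isinstance, presumably because the MPT modeling code is loaded via trust_remote_code and has no stable import path in transformers at this point. A tiny illustration of that matching rule, using a dummy stand-in class rather than the real remote-code module:

    class MPTMLP:  # dummy stand-in for the remote-code MPT MLP class
        pass

    m = MPTMLP()
    # In a script, str(m.__class__) is "<class '__main__.MPTMLP'>", so the check matches:
    print("mptmlp" in str(m.__class__).lower())  # True -> would be wrapped in QuantMPTMLP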
tinychat/modules/fused_norm.py → awq/modules/fused_norm.py (file moved)
tinychat/demo.py

@@ -116,13 +116,6 @@ if __name__ == '__main__':
     else:
         stream_generator = StreamGenerator

-    # Optimize AWQ quantized model
-    if args.precision == "W4A16" and isinstance(model, LlamaAWQForCausalLM):
-        from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
-        make_quant_attn(model.model, args.device)
-        make_quant_norm(model.model)
-        make_fused_mlp(model.model)
-
     model_prompter = get_prompter(model, args.model_path)
     stop_token_ids = get_stop_token_ids(model, args.model_path)
     count = 0