OpenDAS / AutoAWQ · Commit 63346c34
"src/vscode:/vscode.git/clone" did not exist on "ef3844d3a83583f36d0166be6753d062b3cbd7dc"
Commit 63346c34 · authored Aug 24, 2023 by Casper Hansen
Integrate fused modules into AWQ model loading
parent 870a9dc9
Showing 9 changed files with 45 additions and 11 deletions (+45 −11)
awq/models/auto.py          +2  -2
awq/models/base.py          +4  -1
awq/models/llama.py         +7  -0
awq/models/mpt.py           +5  -0
awq/modules/__init__.py     +0  -0
awq/modules/fused_attn.py   +0  -0
awq/modules/fused_mlp.py    +27 -1
awq/modules/fused_norm.py   +0  -0
tinychat/demo.py            +0  -7
awq/models/auto.py

@@ -33,9 +33,9 @@ class AutoAWQForCausalLM:
     @classmethod
     def from_quantized(self, quant_path, quant_filename, max_new_tokens=None,
-                       device='balanced', trust_remote_code=True) -> BaseAWQForCausalLM:
+                       device='balanced', trust_remote_code=True, fuse_layers=True) -> BaseAWQForCausalLM:
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code
+            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code, fuse_layers=fuse_layers
         )
\ No newline at end of file
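
For orientation, here is a hedged usage sketch of the new flag. The checkpoint directory and weight filename are hypothetical placeholders, and the sketch imports AutoAWQForCausalLM directly from awq.models.auto as defined in this file.

# Hedged sketch, not part of this commit: quant_path and quant_filename below
# are hypothetical placeholders for a locally available AWQ checkpoint.
from awq.models.auto import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "vicuna-7b-v1.5-awq",      # hypothetical quant_path (local dir or hub id)
    "awq_model_w4_g128.pt",    # hypothetical quant_filename
    fuse_layers=True,          # new in this commit; fusion now happens during loading
)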
awq/models/base.py

@@ -250,7 +250,7 @@ class BaseAWQForCausalLM(nn.Module):
     @classmethod
     def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
                        device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
-                       safetensors=False, is_quantized=True):
+                       safetensors=False, is_quantized=True, fuse_layers=False):
         # [STEP 1] Download model if path is not a directory
         if not os.path.isdir(model_path):
             ignore_patterns = ["*msgpack*", "*h5*"]

@@ -298,6 +298,9 @@ class BaseAWQForCausalLM(nn.Module):
         if is_quantized:
             model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
+
+            if fuse_layers:
+                self.fuse_layers(model)
         else:
             # If not quantized, must load with AutoModelForCausalLM
             device_map = infer_auto_device_map(
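
The base class only invokes the hook; each model family supplies the fusion itself. Below is a minimal sketch of the contract a subclass now fulfils, using a hypothetical model family that is not part of this commit.

# Hypothetical subclass, for illustration of the fuse_layers hook only.
from awq.models.base import BaseAWQForCausalLM

class MyModelAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MyDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(awq_model):
        # Called by BaseAWQForCausalLM.from_quantized(..., fuse_layers=True)
        # right after load_checkpoint_and_dispatch; swap modules in-place here.
        pass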
awq/models/llama.py

 from .base import BaseAWQForCausalLM
+from awq.modules import make_quant_norm, make_quant_attn, make_fused_mlp
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "LlamaDecoderLayer"
     max_new_tokens_key = "max_position_embeddings"

+    @staticmethod
+    def fuse_layers(awq_model):
+        make_quant_attn(awq_model, awq_model.device)
+        make_quant_norm(awq_model)
+        make_fused_mlp(awq_model)
+
     @staticmethod
     def get_model_layers(model: LlamaForCausalLM):
         return model.model.layers
...
awq/models/mpt.py

 from .base import BaseAWQForCausalLM
+from awq.modules import make_fused_mlp

 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"
     max_new_tokens_key = "max_seq_len"

+    @staticmethod
+    def fuse_layers(awq_model):
+        make_fused_mlp(awq_model)
+
     @staticmethod
     def get_model_layers(model):
         return model.transformer.blocks
...
tinychat/modules/__init__.py → awq/modules/__init__.py (file moved)
tinychat/modules/fused_attn.py → awq/modules/fused_attn.py (file moved)
tinychat/modules/fused_mlp.py → awq/modules/fused_mlp.py

@@ -7,6 +7,27 @@ from transformers.models.llama.modeling_llama import LlamaMLP
 import awq_inference_engine

+class QuantMPTMLP(nn.Module):
+    def __init__(self, up_proj, act, down_proj):
+        super().__init__()
+        self.register_buffer('up_proj_qweight', up_proj.qweight)
+        self.register_buffer('up_proj_scales', up_proj.scales)
+        self.register_buffer('up_proj_qzeros', up_proj.qzeros)
+        self.up_proj = up_proj
+        self.act = act
+        self.down_proj = down_proj
+
+    def forward(self, x: torch.Tensor):
+        x = x.reshape(-1, x.shape[-1])
+        x = awq_inference_engine.gemm_forward_cuda(x, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, 8)
+        return self.down_proj(self.act(x))
+
 class QuantLlamaMLP(nn.Module):
...

@@ -57,10 +78,15 @@ def make_fused_mlp(m, parent_name=''):
     """
     if isinstance(m, LlamaMLP):
         return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj)
+    elif "mptmlp" in str(m.__class__).lower():
+        return QuantMPTMLP(m.up_proj, m.act, m.down_proj)

     for name, child in m.named_children():
         child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")

         if isinstance(child, QuantLlamaMLP):
             setattr(m, name, child)
-    return m
+        elif isinstance(child, QuantMPTMLP):
+            setattr(m, name, child)
+    return m
\ No newline at end of file
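
Note that the new MPT branch matches on the class name string rather than using isinstance, presumably because MPT's MLP class ships as remote code and cannot be imported here. A minimal, self-contained illustration of that dispatch check follows; the MPTMLP class below is a stand-in, not the real module.

# Stand-in class for illustration; the real MPTMLP lives in the model's
# remote code, which is why make_fused_mlp recognizes it by name.
class MPTMLP:
    pass

m = MPTMLP()
print("mptmlp" in str(m.__class__).lower())  # True: the check used by the new branch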
tinychat/modules/fused_norm.py → awq/modules/fused_norm.py (file moved)
tinychat/demo.py

@@ -116,13 +116,6 @@ if __name__ == '__main__':
     else:
         stream_generator = StreamGenerator

-    # Optimize AWQ quantized model
-    if args.precision == "W4A16" and isinstance(model, LlamaAWQForCausalLM):
-        from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
-        make_quant_attn(model.model, args.device)
-        make_quant_norm(model.model)
-        make_fused_mlp(model.model)
-
     model_prompter = get_prompter(model, args.model_path)
     stop_token_ids = get_stop_token_ids(model, args.model_path)
     count = 0
...