OpenDAS / AutoAWQ

Commit eea08aa6 (unverified): AwqConfig class (#132)
Authored Oct 31, 2023 by Casper; committed by GitHub on Oct 31, 2023
Parent commit: a7d87540

Showing 11 changed files with 156 additions and 54 deletions (+156 / -54):
README.md                        +5   -3
awq/models/_config.py            +89  -0   (new file)
awq/models/aquila.py             +3   -5
awq/models/base.py               +17  -28
awq/models/falcon.py             +1   -2
awq/models/llama.py              +3   -5
awq/models/mistral.py            +3   -5
awq/models/mpt.py                +1   -2
examples/basic_transformers.py   +30  -0   (new file)
examples/benchmark.py            +3   -3
examples/eval.py                 +1   -1
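In short, this commit replaces the free-form `quant_config` dict with a typed `AwqConfig` dataclass, writes a `quantization_config` entry into the model's `config.json` on save, and drops the explicit `quant_file` argument when loading quantized models. A minimal sketch of the user-facing flow after this change, based on the README example changed below (paths are illustrative, not part of the diff):

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    model_path = 'lmsys/vicuna-7b-v1.5'
    quant_path = 'vicuna-7b-v1.5-awq'

    # A plain dict is still accepted; quantize() now normalizes it into an
    # AwqConfig via AwqConfig.from_dict(), with "version" defaulting to "GEMM".
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model.quantize(tokenizer, quant_config={"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"})
    model.save_quantized(quant_path)

    # Loading no longer takes a quant_file argument; the quantization config is
    # resolved from quant_config.json in the local directory or Hub repo.
    model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)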
README.md (view file @ eea08aa6)

@@ -74,6 +74,7 @@ The detailed support list:
 | ---------| ----------------------------|
 | LLaMA-2  | 7B/13B/70B                  |
 | LLaMA    | 7B/13B/30B/65B              |
+| Mistral  | 7B                          |
 | Vicuna   | 7B/13B                      |
 | MPT      | 7B/30B                      |
 | Falcon   | 7B/40B                      |

@@ -97,6 +98,8 @@ There are two versions of AWQ: GEMM and GEMV. Both names relate to how matrix mu
 ### Examples
 
+More examples can be found in the [examples directory](examples).
+
 <details>
 <summary>Quantization</summary>

@@ -109,7 +112,7 @@ from transformers import AutoTokenizer
 model_path = 'lmsys/vicuna-7b-v1.5'
 quant_path = 'vicuna-7b-v1.5-awq'
-quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4 }
+quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
 
 # Load model
 model = AutoAWQForCausalLM.from_pretrained(model_path)

@@ -134,10 +137,9 @@ from awq import AutoAWQForCausalLM
 from transformers import AutoTokenizer, TextStreamer
 
 quant_path = "casperhansen/vicuna-7b-v1.5-awq"
-quant_file = "awq_model_w4_g128.pt"
 
 # Load model
-model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, fuse_layers=True)
+model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
 tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
 streamer = TextStreamer(tokenizer, skip_special_tokens=True)
awq/models/_config.py (new file, 0 → 100644, view file @ eea08aa6)

import os
import json
import logging
from typing import Dict
from dataclasses import dataclass, field, fields
from transformers.utils.hub import PushToHubMixin, cached_file

@dataclass
class AwqConfig(PushToHubMixin):
    quant_method: str = field(default="awq")
    zero_point: bool = field(default=True)
    q_group_size: int = field(default=128)
    w_bit: int = field(default=4)
    version: str = field(default="GEMM")
    config_file_name = "quant_config.json"

    def save_pretrained(self, save_dir: str, **kwargs):
        logging.warning(
            "`quant_config.json` is being deprecated in the future"
            " in favor of quantization_config in config.json."
        )
        with open(os.path.join(save_dir, self.config_file_name), "w+", encoding="utf-8") as file:
            file.write(json.dumps(self.to_dict(), indent=4))

    @classmethod
    def from_dict(cls, quant_config: Dict = {}):
        if not quant_config:
            quant_config = cls()
        else:
            quant_config = cls(**quant_config)

        return quant_config

    @classmethod
    def from_pretrained(cls, save_dir: str, **kwargs):
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        commit_hash = kwargs.pop("_commit_hash", None)

        if os.path.isdir(save_dir):  # Local
            resolved_config_file = os.path.join(save_dir, cls.config_file_name)
        else:  # Remote
            resolved_config_file = cached_file(
                save_dir,
                cls.config_file_name,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                use_auth_token=use_auth_token,
                revision=revision,
                local_files_only=local_files_only,
                subfolder=subfolder,
                _raise_exceptions_for_missing_entries=False,
                _raise_exceptions_for_connection_errors=False,
                _commit_hash=commit_hash,
            )

        if os.path.exists(resolved_config_file):
            with open(resolved_config_file, 'r', encoding="utf-8") as file:
                loaded_config = json.loads(file.read())
                quant_config = cls(**loaded_config)
        else:
            quant_config = cls()

        return quant_config

    def to_dict(self):
        return {
            "zero_point": self.zero_point,
            "q_group_size": self.q_group_size,
            "w_bit": self.w_bit,
            "version": self.version
        }

    def to_transformers_dict(self):
        return {
            "quant_method": self.quant_method,
            "zero_point": self.zero_point,
            "group_size": self.q_group_size,
            "bits": self.w_bit,
            "version": self.version.lower(),
        }
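For reference, a small sketch of how the class above round-trips a config. It only exercises methods defined in this file; the save directory is illustrative and assumed to exist:

    from awq.models._config import AwqConfig

    # Build from a user-supplied dict; missing keys fall back to the dataclass defaults.
    config = AwqConfig.from_dict({"zero_point": True, "q_group_size": 128, "w_bit": 4})
    print(config.version)                  # "GEMM" (default)
    print(config.to_dict())                # legacy quant_config.json payload
    print(config.to_transformers_dict())   # quantization_config payload for config.json

    # Writes quant_config.json into the directory (and logs the deprecation warning).
    config.save_pretrained("vicuna-7b-v1.5-awq")

    # Reads quant_config.json back from a local dir or Hub repo;
    # returns the defaults if the file is missing.
    config = AwqConfig.from_pretrained("vicuna-7b-v1.5-awq")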
awq/models/aquila.py (view file @ eea08aa6)

 ## Reference from llama.py
 from .base import BaseAWQForCausalLM
-from typing import Dict
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer as AquilaDecoderLayer,
     LlamaForCausalLM as AquilaForCausalLM,

@@ -14,8 +13,8 @@ class AquilaAWQForCausalLM(BaseAWQForCausalLM):
     max_new_tokens_key = "max_position_embeddings"
 
     @staticmethod
-    def fuse_layers(model: AquilaForCausalLM, quant_config: Dict):
-        fuser = AquilaFuser(model, quant_config)
+    def fuse_layers(model: AquilaForCausalLM):
+        fuser = AquilaFuser(model)
         fuser.fuse_attention()
         fuser.fuse_rmsnorm()
         fuser.fuse_mlp()

@@ -82,9 +81,8 @@ from awq.modules.fused.norm import FasterTransformerRMSNorm
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 
 class AquilaFuser:
-    def __init__(self, model, quant_config):
+    def __init__(self, model):
         self.model = model
-        self.quant_config = quant_config
 
         self.attention_modules: List[Tuple[str, AquilaAttention]] = [
             (name, module) for name, module in self.model.named_modules()
awq/models/base.py (view file @ eea08aa6)

@@ -4,18 +4,19 @@ import json
 import torch
 import torch.nn as nn
 from tqdm import tqdm
-from typing import List, Union, Dict
+from typing import List, Union
 from safetensors.torch import save_file
+from awq.models._config import AwqConfig
 from awq.modules.act import ScaledActivation
 from huggingface_hub import snapshot_download
 from awq.quantize.quantizer import AwqQuantizer
 from awq.utils.utils import simple_dispatch_model
 from transformers.modeling_utils import shard_checkpoint
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 from awq.utils.module import get_named_linears, set_op_by_name
 from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
 from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
 
 class BaseAWQForCausalLM(nn.Module):
     def __init__(self, model, model_type, is_quantized, quant_config):
         super().__init__()
@@ -23,7 +24,7 @@ class BaseAWQForCausalLM(nn.Module):
         self.model_type: str = model_type
         self.is_quantized: bool = is_quantized
         self.search_result = None
-        self.quant_config: Dict = quant_config
+        self.quant_config: AwqConfig = quant_config
 
     def to(self, device: str):
         return self.model.to(device)
@@ -39,18 +40,17 @@ class BaseAWQForCausalLM(nn.Module):
     def quantize(self, tokenizer=None, quant_config={},
                        calib_data: Union[str, List[str]]="pileval",
                        split="train", text_column="text"):
-        self.quant_config = quant_config
-        quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]
+        self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)
 
         quantizer = AwqQuantizer(
-            self, self.model, tokenizer, quant_config["w_bit"], quant_config["q_group_size"],
-            quant_config["version"], calib_data, split, text_column
+            self, self.model, tokenizer, self.quant_config.w_bit, self.quant_config.q_group_size,
+            self.quant_config.version, calib_data, split, text_column
         )
         quantizer.quantize()
         self.is_quantized = True
 
     @staticmethod
-    def fuse_layers(model, quant_config):
+    def fuse_layers(model):
         pass
 
     def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
@@ -61,8 +61,10 @@ class BaseAWQForCausalLM(nn.Module):
             def __init__(self): super(EmptyModule, self).__init__()
             def forward(self, x): return x
 
-        # Save model files with empty state dict
+        # Save model and config files with empty state dict
+        self.model.config.quantization_config = self.quant_config.to_transformers_dict()
         self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
+        self.quant_config.save_pretrained(save_dir)
 
         # Remove empty state dict
         os.remove(f'{save_dir}/pytorch_model.bin')
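With the two added lines, save_quantized now records the quantization settings in two places: the legacy quant_config.json (written by AwqConfig.save_pretrained) and a quantization_config entry inside the model's config.json (from to_transformers_dict). A sketch of what that entry contains for the default settings, derived from the _config.py code above rather than taken from this diff:

    from awq.models._config import AwqConfig

    # Default config -> quantization_config payload stored in config.json
    print(AwqConfig().to_transformers_dict())
    # {'quant_method': 'awq', 'zero_point': True, 'group_size': 128,
    #  'bits': 4, 'version': 'gemm'}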
@@ -89,10 +91,6 @@ class BaseAWQForCausalLM(nn.Module):
         if index is not None:
             with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
                 file.write(json.dumps(index, indent=4))
-
-        # Save config
-        with open(f'{save_dir}/quant_config.json', 'w+') as file:
-            file.write(json.dumps(self.quant_config, indent=4))
 
     @classmethod
@@ -146,7 +144,7 @@ class BaseAWQForCausalLM(nn.Module):
         model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
 
         # Prepare WQLinear layers, replace nn.Linear
-        self._load_quantized_modules(self, model, quant_config, quant_config["version"])
+        self._load_quantized_modules(self, model, quant_config, quant_config.version)
 
         model.tie_weights()
@@ -169,7 +167,7 @@ class BaseAWQForCausalLM(nn.Module):
         # Dispath to devices
         if fuse_layers:
-            self.fuse_layers(model, quant_config)
+            self.fuse_layers(model)
 
         # Offloading dispatch
         from accelerate import dispatch_model
@@ -201,16 +199,7 @@ class BaseAWQForCausalLM(nn.Module):
         # [STEP 2] Load config and set sequence length
         # TODO: Create BaseAWQConfig class
-        quant_config_path = f'{model_path}/quant_config.json'
-        if os.path.exists(quant_config_path):
-            with open(quant_config_path, 'r') as file:
-                quant_config = json.loads(file.read())
-
-            if "version" not in quant_config.keys():
-                quant_config["version"] = version
-        else:
-            # Default config that works for most models
-            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": version}
+        quant_config = AwqConfig.from_pretrained(model_path)
 
         # Load model config and set max generation length
         if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
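The removed block of ad-hoc JSON handling is replaced by a single AwqConfig.from_pretrained call, which resolves quant_config.json locally or from the Hub and, per _config.py above, falls back to the dataclass defaults when the file is missing. A small sketch of that fallback behaviour, with an illustrative path:

    from awq.models._config import AwqConfig

    # model_path is illustrative; local directories and Hub repo ids both work.
    model_path = "vicuna-7b-v1.5-awq"

    # Resolves <model_path>/quant_config.json (locally or via the Hub cache);
    # if the file is missing, the dataclass defaults are returned:
    # zero_point=True, q_group_size=128, w_bit=4, version="GEMM".
    quant_config = AwqConfig.from_pretrained(model_path)
    print(quant_config.w_bit, quant_config.q_group_size, quant_config.version)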
@@ -225,7 +214,7 @@ class BaseAWQForCausalLM(nn.Module):
     def _load_quantized_modules(self, model, quant_config, version):
         # Real quantization of weights
-        assert quant_config["zero_point"], "We only support zero_point quantization now."
+        assert quant_config.zero_point, "We only support zero_point quantization now."
 
         # Get blocks of model
         layers = self.get_model_layers(model)
@@ -248,8 +237,8 @@ class BaseAWQForCausalLM(nn.Module):
             q_linear = q_linear_module.from_linear(
                 module,
-                quant_config['w_bit'],
-                quant_config['q_group_size'],
+                quant_config.w_bit,
+                quant_config.q_group_size,
                 True
             )
             q_linear.to(next(layer.parameters()).device)
awq/models/falcon.py (view file @ eea08aa6)

 from .base import BaseAWQForCausalLM
-from typing import Dict
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer as OldFalconDecoderLayer, FalconForCausalLM, FalconAttention
 
 class FalconAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "FalconDecoderLayer"
 
     @staticmethod
-    def fuse_layers(model: FalconForCausalLM, quant_config: Dict):
+    def fuse_layers(model: FalconForCausalLM):
         fuser = FalconFuser(model)
 
         # TODO: Implement correctly fused modules for Falcon 40B and Falcon 180B
awq/models/llama.py (view file @ eea08aa6)

 from .base import BaseAWQForCausalLM
-from typing import Dict
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
 
 class LlamaAWQForCausalLM(BaseAWQForCausalLM):

@@ -7,8 +6,8 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
     max_new_tokens_key = "max_position_embeddings"
 
     @staticmethod
-    def fuse_layers(model: LlamaForCausalLM, quant_config: Dict):
-        fuser = LlamaFuser(model, quant_config)
+    def fuse_layers(model: LlamaForCausalLM):
+        fuser = LlamaFuser(model)
         fuser.fuse_attention()
         fuser.fuse_rmsnorm()
         fuser.fuse_mlp()

@@ -76,9 +75,8 @@ from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
 
 class LlamaFuser:
-    def __init__(self, model, quant_config):
+    def __init__(self, model):
         self.model = model
-        self.quant_config = quant_config
 
         self.attention_modules: List[Tuple[str, LlamaAttention]] = [
             (name, module) for name, module in self.model.named_modules()
awq/models/mistral.py (view file @ eea08aa6)

-from typing import Dict
 from .base import BaseAWQForCausalLM
 from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralForCausalLM

@@ -7,8 +6,8 @@ class MistralAWQForCausalLM(BaseAWQForCausalLM):
     max_new_tokens_key = "max_position_embeddings"
 
     @staticmethod
-    def fuse_layers(model: MistralForCausalLM, quant_config: Dict):
-        fuser = MistralFuser(model, quant_config)
+    def fuse_layers(model: MistralForCausalLM):
+        fuser = MistralFuser(model)
         fuser.fuse_attention()
         fuser.fuse_rmsnorm()
         fuser.fuse_mlp()

@@ -76,9 +75,8 @@ from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 from transformers.models.mistral.modeling_mistral import MistralAttention, MistralRMSNorm, MistralMLP
 
 class MistralFuser:
-    def __init__(self, model, quant_config):
+    def __init__(self, model):
         self.model = model
-        self.quant_config = quant_config
 
         self.attention_modules: List[Tuple[str, MistralAttention]] = [
             (name, module) for name, module in self.model.named_modules()
awq/models/mpt.py (view file @ eea08aa6)

 from .base import BaseAWQForCausalLM
-from typing import Dict
 from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM
 
 class MptAWQForCausalLM(BaseAWQForCausalLM):

@@ -7,7 +6,7 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
     max_new_tokens_key = "max_seq_len"
 
     @staticmethod
-    def fuse_layers(model: MptForCausalLM, quant_config: Dict):
+    def fuse_layers(model: MptForCausalLM):
         fuser = MptFuser(model)
 
         fuser.fuse_transformer()
examples/basic_transformers.py (new file, 0 → 100644, view file @ eea08aa6)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

# NOTE: Must install from PR until merged
# pip install --upgrade git+https://github.com/younesbelkada/transformers.git@add-awq

model_id = "casperhansen/mistral-7b-instruct-v0.1-awq"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="cuda:0"
)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Convert prompt to tokens
text = "[INST] What are the basic steps to use the Huggingface transformers library? [/INST]"
tokens = tokenizer(text, return_tensors='pt').input_ids.cuda()

# Generate output
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=512
)
\ No newline at end of file
examples/benchmark.py (view file @ eea08aa6)

@@ -85,7 +85,7 @@ def run_round(model_path, quant_file, n_generate, input_ids, batch_size, safeten
         "Prefill tokens/s": prefill_tokens_per_second,
         "Decode tokens/s": decode_tokens_per_second,
         "Memory (VRAM)": f"{memory_used:.2f} GB ({memory_pct:.2f}%)"
-    }, model.quant_config["version"]
+    }, model.quant_config.version
 
 def main(args):
     rounds = [

@@ -126,8 +126,8 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_path", type=str, default="casperhansen/vicuna-7b-v1.5-awq", help="path to the model")
-    parser.add_argument("--quant_file", type=str, default="awq_model_w4_g128.pt", help="weights filename")
+    parser.add_argument("--model_path", type=str, default="casperhansen/mistral-7b-instruct-v0.1-awq", help="path to the model")
+    parser.add_argument("--quant_file", type=str, default="", help="weights filename")
     parser.add_argument("--batch_size", type=int, default=1, help="Batch size for cache and generation")
     parser.add_argument("--safetensors", default=False, action="store_true", help="Use for enabling safetensors")
     args = parser.parse_args()
examples/eval.py (view file @ eea08aa6)

@@ -33,7 +33,7 @@ def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot
 if __name__ == '__main__':
     """
     - Run perplexity of quantized model:
-        python examples/eval.py --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt
+        python examples/eval.py --model_path casperhansen/mistral-7b-instruct-v0.1-awq
     - Run perplexity unquantized FP16 model:
         python examples/eval.py --use_pretrained --model_path lmsys/vicuna-7b-v1.5