OpenDAS / AutoAWQ / Commits / 727172e9

Unverified commit 727172e9, authored Dec 23, 2023 by Casper, committed via GitHub on Dec 23, 2023.

Fused rope theta (#270)

Co-authored-by: Casper Hansen <casperbh96@gmail.com>

parent 5b9f3c47

Showing 7 changed files with 23 additions and 132 deletions (+23, -132)
awq/models/llama.py           +2   -1
awq/models/mixtral.py         +3   -3
awq/models/yi.py              +2   -1
awq/modules/fused/attn.py     +8   -5
awq/modules/fused/block.py    +7   -4
awq/utils/lm_eval_adaptor.py  +0   -114
examples/eval.py              +1   -4
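The change itself is narrow: the fused attention path used to hard-code the rotary embedding base to 10000, while newer checkpoints such as Mixtral are trained with a much larger rope_theta (1e6 in Mixtral's config). This commit reads rope_theta from each model's Hugging Face config in the fusers and threads it through LlamaLikeBlock/MixtralBlock, QuantAttentionFused, and finally RoPE, as the diffs below show. A minimal sketch of the config read the fusers now rely on; the model id is only an illustration, and the getattr fallback is an assumption for configs that predate the rope_theta field:

# Illustrative only: read the rotary base the same way the fusers below do
# via self.model.config.rope_theta. The checkpoint id is an example.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")

# Assumed fallback for older configs that do not define rope_theta.
rope_theta = getattr(config, "rope_theta", 10000.0)
print(rope_theta)  # Mixtral configs set this to 1e6, not the old hard-coded 10000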
awq/models/llama.py

@@ -118,7 +118,8 @@ class LlamaFuser:
                 norm_1=norm_1,
                 norm_2=norm_2,
                 dev=device,
-                max_seq_len=self.model.config.max_new_tokens
+                max_seq_len=self.model.config.max_new_tokens,
+                rope_theta=self.model.config.rope_theta
             ))
 
         self.model.model = LlamaLikeModel(
awq/models/mixtral.py

@@ -18,8 +18,7 @@ class MixtralAWQForCausalLM(BaseAWQForCausalLM):
     @staticmethod
     def fuse_layers(model: OldMixtralForCausalLM):
         fuser = MixtralFuser(model)
-        # TODO: Fix perplexity on fusing Mixtral
-        #fuser.fuse_transformer()
+        fuser.fuse_transformer()
 
     @staticmethod
     def get_model_layers(model: OldMixtralForCausalLM):

@@ -125,7 +124,8 @@ class MixtralFuser:
                 norm_1=norm_1,
                 norm_2=norm_2,
                 dev=device,
-                max_seq_len=self.model.config.max_new_tokens
+                max_seq_len=self.model.config.max_new_tokens,
+                rope_theta=self.model.config.rope_theta
             ))
 
         self.model.model = MixtralModel(
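With the rotary base now taken from the config, the Mixtral fuser is re-enabled rather than skipped. A short usage sketch, assuming the usual AutoAWQ entry point where fuse_layers=True makes the loader call the model class's fuse_layers(); the checkpoint path is a placeholder:

from awq import AutoAWQForCausalLM

# Placeholder path to an AWQ-quantized Mixtral checkpoint.
model_path = "path/to/mixtral-awq"

# fuse_layers=True routes through MixtralAWQForCausalLM.fuse_layers(), which with
# this commit runs fuser.fuse_transformer() instead of skipping it.
model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=True)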
awq/models/yi.py

@@ -113,7 +113,8 @@ class YiFuser:
                 norm_1=norm_1,
                 norm_2=norm_2,
                 dev=device,
-                max_seq_len=self.model.config.max_new_tokens
+                max_seq_len=self.model.config.max_new_tokens,
+                rope_theta=self.model.config.rope_theta
             ))
 
         self.model.model = LlamaLikeModel(
awq/modules/fused/attn.py

@@ -23,11 +23,13 @@ if HF_NEW_CACHE_FORMAT:
 class RoPE(nn.Module):
-    def __init__(self, hidden_size, n_heads, max_seq_len, device):
+    def __init__(self, hidden_size, n_heads, max_seq_len, device, rope_theta):
         super(RoPE, self).__init__()
 
         self.freqs_cis = nn.Parameter(
-            self.precompute_freqs_cis(hidden_size // n_heads, max_seq_len * 2).to(device),
+            self.precompute_freqs_cis(
+                hidden_size // n_heads, max_seq_len * 2, rope_theta
+            ).to(device),
             requires_grad=False
         )

@@ -97,7 +99,7 @@ class ALiBi(nn.Module):
 class QuantAttentionFused(nn.Module):
     def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
-                 use_alibi=False, attention_shapes=None):
+                 use_alibi=False, attention_shapes=None, rope_theta=10000):
         super().__init__()
         self.hidden_size = hidden_size
         self.n_heads = n_heads

@@ -111,6 +113,7 @@ class QuantAttentionFused(nn.Module):
         self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
         self.max_seq_len = max_seq_len
         self.is_hf_transformers = False
+        self.rope_theta = rope_theta
 
         # attention shapes for self attention
         self.attention_shapes = get_attention_shapes(

@@ -127,7 +130,7 @@ class QuantAttentionFused(nn.Module):
             self.is_neox = False
         else:
             self.alibi = None
-            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev)
+            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev, rope_theta)
             self.rotary_dim = self.head_dim
             self.is_neox = True

@@ -221,7 +224,7 @@ class QuantAttentionFused(nn.Module):
                 alibi_slopes,  # alibi slopes
                 self.start_pos,  # timestep
                 self.rotary_dim,  # rotary embedding dimension
-                10000,  # rotary embedding base
+                self.rope_theta,  # rotary embedding base
                 self.is_neox,  # is neox
             )
             attention_weight = attention_weight.reshape(bsz, 1, -1)
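For context, the table that RoPE precomputes is the standard Llama-style rotary frequency table, and the base theta sets how fast each dimension pair rotates. A minimal sketch of a precompute_freqs_cis that takes the base as its third argument, matching the call shown above; this is the reference formula, not necessarily a line-for-line copy of AutoAWQ's implementation:

import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # Inverse frequencies for each even dimension; theta is the rotary base.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # Angles for every (position, dimension-pair) combination.
    t = torch.arange(end, device=freqs.device).float()
    freqs = torch.outer(t, freqs)
    # Complex cis(angle) table of shape (end, dim // 2).
    return torch.polar(torch.ones_like(freqs), freqs)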
awq/modules/fused/block.py

@@ -5,7 +5,7 @@ from awq.modules.fused.attn import QuantAttentionFused
 class MixtralBlock(nn.Module):
     def __init__(
         self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
-        moe, norm_1, norm_2, dev, max_seq_len
+        moe, norm_1, norm_2, dev, max_seq_len, rope_theta
     ):
         super().__init__()
         self.n_heads = n_heads

@@ -14,7 +14,7 @@ class MixtralBlock(nn.Module):
         self.norm_1 = norm_1.to(dev)
         self.attn = QuantAttentionFused(
             self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
-            dev=dev, max_seq_len=max_seq_len, use_alibi=False
+            dev=dev, max_seq_len=max_seq_len, use_alibi=False, rope_theta=rope_theta
         ).to(dev)
         self.norm_2 = norm_2.to(dev)
         self.moe = moe

@@ -41,7 +41,10 @@ class LlamaLikeBlock(nn.Module):
     LlamaLikeBlock is intended to be reused across blocks that have
     an architecture that closely resembles Llama, e.g. Mistral and Aquila.
     """
-    def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, mlp, norm_1, norm_2, dev, max_seq_len):
+    def __init__(
+        self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, mlp,
+        norm_1, norm_2, dev, max_seq_len, rope_theta
+    ):
         super().__init__()
         self.n_heads = n_heads
         self.n_kv_heads = n_kv_heads

@@ -49,7 +52,7 @@ class LlamaLikeBlock(nn.Module):
         self.norm_1 = norm_1.to(dev)
         self.attn = QuantAttentionFused(
             self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
-            dev=dev, max_seq_len=max_seq_len, use_alibi=False
+            dev=dev, max_seq_len=max_seq_len, use_alibi=False, rope_theta=rope_theta
         ).to(dev)
         self.norm_2 = norm_2.to(dev)
         self.mlp = mlp.to(dev)
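The blocks only forward the value, but the reason it matters is easy to see numerically: the base determines the longest wavelength the rotary embedding can represent, so applying the old hard-coded 10000 to a model trained with 1e6 compresses its usable position range. A toy calculation, not part of the diff; the head_dim of 128 is only an example:

import math
import torch

def longest_wavelength(theta: float, head_dim: int = 128) -> float:
    # Period (in positions) of the slowest-rotating dimension pair.
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    return 2 * math.pi / float(inv_freq[-1])

print(longest_wavelength(10_000.0))     # roughly 5.4e4 positions
print(longest_wavelength(1_000_000.0))  # roughly 5.1e6 positions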
awq/utils/lm_eval_adaptor.py  deleted 100644 → 0

@@ -1,114 +0,0 @@
-import transformers
-import torch
-from lm_eval.base import BaseLM
-import fnmatch
-import logging
-
-
-class LMEvalAdaptor(BaseLM):
-
-    def __init__(self, model_name, model, tokenizer, device, batch_size=1, max_length=-1):
-        super().__init__()
-
-        assert isinstance(batch_size, int)
-
-        self.model_name = model_name
-        self.model = model.to(device)
-        self.model.eval()
-
-        self.tokenizer = tokenizer
-
-        # assert isinstance(self.tokenizer, (
-        #     transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
-        #     transformers.T5Tokenizer, transformers.T5TokenizerFast,
-        # )), "this tokenizer has not been checked for compatibility yet!"
-
-        self.vocab_size = self.tokenizer.vocab_size
-
-        self._batch_size = batch_size
-
-        self._max_length = max_length
-
-    @property
-    def eot_token_id(self):
-        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
-        return self.tokenizer.eos_token_id
-
-    @property
-    def max_length(self):
-        if self._max_length != -1:
-            return self._max_length
-        if hasattr(self.model.config, 'n_ctx'):
-            return self.model.config.n_ctx
-        elif hasattr(self.model.config, 'max_position_embeddings'):
-            return self.model.config.max_position_embeddings
-        elif hasattr(self.model.config, 'n_positions'):
-            return self.model.config.n_positions
-        elif 'bloom' in self.model_name:
-            return 2048
-        elif 'llama' in self.model_name:
-            return 2048  # TODO: did not check this
-        elif 'mpt' in self.model_name:
-            return 2048
-        elif 'falcon' in self.model_name:
-            return 2048
-        else:
-            logging.debug(self.model.config)
-            raise NotImplementedError
-
-    @property
-    def max_gen_toks(self):
-        return 256
-
-    @property
-    def batch_size(self):
-        return self._batch_size
-
-    @property
-    def device(self):
-        return "cuda"
-
-    def tok_encode(self, string: str):
-        return self.tokenizer.encode(string, add_special_tokens=False)
-
-    def tok_decode(self, tokens):
-        return self.tokenizer.decode(tokens)
-
-    def _model_call(self, inps):
-        """
-        inps: a torch tensor of shape [batch, sequence]
-        the size of sequence may vary from call to call
-
-        returns: a torch tensor of shape [batch, sequence, vocab] with the
-        logits returned from the model
-        """
-        with torch.no_grad():
-            if isinstance(self.model, transformers.models.t5.modeling_t5.T5ForConditionalGeneration):
-                dec_inps = torch.cat(
-                    [
-                        torch.tensor(
-                            self.model.generation_config.decoder_start_token_id,
-                        )
-                        .tile(len(inps), 1)
-                        .to(inps),
-                        inps,
-                    ],
-                    dim=1,
-                )
-
-                kwargs = {"decoder_input_ids": dec_inps,}
-            else:
-                kwargs = {}
-            out = self.model(inps, **kwargs)[0]
-            if "opt" in self.model_name:
-                # there are a few extra tokens in opt, which we should omit
-                return out[:, :, :50257]
-            else:
-                return out  # [:, :, :self.tokenizer.vocab_size]
-
-    def _model_generate(self, context, max_length, eos_token_id):
-        return self.model.generate(
-            context,
-            max_length=max_length,
-            eos_token_id=eos_token_id,
-            do_sample=False
-        )
examples/eval.py

@@ -2,7 +2,6 @@ import argparse
 from lm_eval import evaluator
 from awq import AutoAWQForCausalLM
 from transformers import AutoTokenizer
-from awq.utils.lm_eval_adaptor import LMEvalAdaptor
 from awq.utils.eval_utils import evaluate_perplexity
 
 def run_eval(

@@ -26,11 +25,9 @@ def run_eval(
         evaluate_perplexity(model.model, tokenizer)
     else:
-        lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=task_batch_size)
-
         # Evaluate perplexity of quantized model
         results = evaluator.simple_evaluate(
-            model=lm_eval_model,
+            model=model,
             tasks=tasks,
             batch_size=task_batch_size,
             no_cache=True,
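With the adaptor removed, lm_eval's evaluator is handed the model directly, and the perplexity branch of run_eval needs only the quantized model and a tokenizer. A minimal standalone sketch of that flow; the checkpoint path is a placeholder and evaluate_perplexity's signature is taken from the call shown in the diff above:

from awq import AutoAWQForCausalLM
from awq.utils.eval_utils import evaluate_perplexity
from transformers import AutoTokenizer

model_path = "path/to/quantized-model"  # placeholder
model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Mirrors the evaluate_perplexity(model.model, tokenizer) call in run_eval above.
evaluate_perplexity(model.model, tokenizer)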