AutoAWQ · Commit 727172e9 (unverified)

Fused rope theta (#270)

Authored Dec 23, 2023 by Casper, committed by GitHub on Dec 23, 2023.
Co-authored-by: Casper Hansen <casperbh96@gmail.com>
Parent: 5b9f3c47
Showing 7 changed files with 23 additions and 132 deletions (+23 −132):

    awq/models/llama.py           +2   −1
    awq/models/mixtral.py         +3   −3
    awq/models/yi.py              +2   −1
    awq/modules/fused/attn.py     +8   −5
    awq/modules/fused/block.py    +7   −4
    awq/utils/lm_eval_adaptor.py  +0   −114
    examples/eval.py              +1   −4
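This commit threads the model's RoPE base frequency (rope_theta) from the Hugging Face config down into the fused attention path instead of hard-coding 10000. For orientation, here is a minimal sketch of the standard Llama-style rotary table that such a base parameterizes; it illustrates the idea and is not the exact precompute_freqs_cis implemented in awq/modules/fused/attn.py.

```python
import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # Llama-style RoPE table: one complex rotation e^{i * m * f_k} per
    # position m and frequency index k. `theta` is the base that this
    # commit stops hard-coding and reads from config.rope_theta instead.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    positions = torch.arange(end).float()
    angles = torch.outer(positions, freqs)
    return torch.polar(torch.ones_like(angles), angles)  # complex64, [end, dim // 2]

# Example: a 128-dim head with the common 10000 base vs. a long-context base.
default_table = precompute_freqs_cis(128, 4096, theta=10_000.0)
long_ctx_table = precompute_freqs_cis(128, 4096, theta=1_000_000.0)
```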
awq/models/llama.py

@@ -118,7 +118,8 @@ class LlamaFuser:
                 norm_1=norm_1,
                 norm_2=norm_2,
                 dev=device,
-                max_seq_len=self.model.config.max_new_tokens
+                max_seq_len=self.model.config.max_new_tokens,
+                rope_theta=self.model.config.rope_theta
             ))
 
         self.model.model = LlamaLikeModel(
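The fuser now reads rope_theta straight off the config. Older checkpoints do not always define rope_theta in config.json; a defensive variant (a hypothetical helper, not part of this commit) could fall back to the 10000.0 value that QuantAttentionFused keeps as its keyword default:

```python
from transformers import AutoConfig

def get_rope_theta(model_path: str, default: float = 10000.0) -> float:
    # Hypothetical helper: fall back to the previous hard-coded base when
    # the checkpoint's config.json carries no rope_theta entry.
    config = AutoConfig.from_pretrained(model_path)
    return getattr(config, "rope_theta", default)
```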
awq/models/mixtral.py

@@ -18,8 +18,7 @@ class MixtralAWQForCausalLM(BaseAWQForCausalLM):
     @staticmethod
     def fuse_layers(model: OldMixtralForCausalLM):
         fuser = MixtralFuser(model)
-        # TODO: Fix perplexity on fusing Mixtral
-        #fuser.fuse_transformer()
+        fuser.fuse_transformer()
 
     @staticmethod
     def get_model_layers(model: OldMixtralForCausalLM):

@@ -125,7 +124,8 @@ class MixtralFuser:
                 norm_1=norm_1,
                 norm_2=norm_2,
                 dev=device,
-                max_seq_len=self.model.config.max_new_tokens
+                max_seq_len=self.model.config.max_new_tokens,
+                rope_theta=self.model.config.rope_theta
             ))
 
         self.model.model = MixtralModel(
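With the TODO gone, fuse_layers once again calls MixtralFuser.fuse_transformer(), so loading a quantized Mixtral with fusing enabled exercises this code path. A minimal loading sketch, assuming the fuse_layers flag from the AutoAWQ README and a placeholder model path:

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

# Placeholder path; fuse_layers=True routes through MixtralFuser.fuse_transformer().
quant_path = "path/to/mixtral-awq"
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path)
```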
awq/models/yi.py

@@ -113,7 +113,8 @@ class YiFuser:
                 norm_1=norm_1,
                 norm_2=norm_2,
                 dev=device,
-                max_seq_len=self.model.config.max_new_tokens
+                max_seq_len=self.model.config.max_new_tokens,
+                rope_theta=self.model.config.rope_theta
             ))
 
         self.model.model = LlamaLikeModel(
awq/modules/fused/attn.py

@@ -23,11 +23,13 @@ if HF_NEW_CACHE_FORMAT:
 
 class RoPE(nn.Module):
-    def __init__(self, hidden_size, n_heads, max_seq_len, device):
+    def __init__(self, hidden_size, n_heads, max_seq_len, device, rope_theta):
         super(RoPE, self).__init__()
 
         self.freqs_cis = nn.Parameter(
-            self.precompute_freqs_cis(hidden_size // n_heads, max_seq_len * 2).to(device),
+            self.precompute_freqs_cis(
+                hidden_size // n_heads, max_seq_len * 2, rope_theta
+            ).to(device),
             requires_grad=False
         )

@@ -97,7 +99,7 @@ class ALiBi(nn.Module):
 
 class QuantAttentionFused(nn.Module):
     def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
-                       use_alibi=False, attention_shapes=None):
+                       use_alibi=False, attention_shapes=None, rope_theta=10000):
         super().__init__()
         self.hidden_size = hidden_size
         self.n_heads = n_heads

@@ -111,6 +113,7 @@ class QuantAttentionFused(nn.Module):
         self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
         self.max_seq_len = max_seq_len
         self.is_hf_transformers = False
+        self.rope_theta = rope_theta
 
         # attention shapes for self attention
         self.attention_shapes = get_attention_shapes(

@@ -127,7 +130,7 @@ class QuantAttentionFused(nn.Module):
             self.is_neox = False
         else:
             self.alibi = None
-            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev)
+            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev, rope_theta)
             self.rotary_dim = self.head_dim
             self.is_neox = True

@@ -221,7 +224,7 @@ class QuantAttentionFused(nn.Module):
                 alibi_slopes,  # alibi slopes
                 self.start_pos,  # timestep
                 self.rotary_dim,  # rotary embedding dimension
-                10000,  # rotary embedding base
+                self.rope_theta,  # rotary embedding base
                 self.is_neox,  # is neox
             )
             attention_weight = attention_weight.reshape(bsz, 1, -1)
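Passing self.rope_theta as the rotary base matters because the base sets the wavelength of each rotary frequency pair; models shipped with a larger base (e.g. rope_theta=1e6) get wrong rotations if the fused kernel keeps using 10000. A small standalone sketch of the effect (illustrative only, not AutoAWQ code):

```python
import math
import torch

def rope_wavelengths(head_dim: int, theta: float) -> torch.Tensor:
    # Wavelength, in tokens, of each rotary frequency pair. A larger base
    # stretches the low-frequency components, which long-context
    # checkpoints rely on.
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    return 2 * math.pi / inv_freq

print(rope_wavelengths(128, 10_000.0)[-1])     # roughly 5.4e4 tokens
print(rope_wavelengths(128, 1_000_000.0)[-1])  # roughly 5.1e6 tokens
```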
awq/modules/fused/block.py

@@ -5,7 +5,7 @@ from awq.modules.fused.attn import QuantAttentionFused
 
 class MixtralBlock(nn.Module):
     def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
-                       moe, norm_1, norm_2, dev, max_seq_len):
+                       moe, norm_1, norm_2, dev, max_seq_len, rope_theta):
         super().__init__()
         self.n_heads = n_heads

@@ -14,7 +14,7 @@ class MixtralBlock(nn.Module):
         self.norm_1 = norm_1.to(dev)
         self.attn = QuantAttentionFused(self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
-                                        dev=dev, max_seq_len=max_seq_len, use_alibi=False
+                                        dev=dev, max_seq_len=max_seq_len, use_alibi=False, rope_theta=rope_theta
         ).to(dev)
         self.norm_2 = norm_2.to(dev)
         self.moe = moe

@@ -41,7 +41,10 @@ class LlamaLikeBlock(nn.Module):
     LlamaLikeBlock is intended to be reused across blocks that have
     an architecture that closely resembles Llama, e.g. Mistral and Aquila.
     """
-    def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, mlp, norm_1, norm_2, dev, max_seq_len):
+    def __init__(
+        self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, mlp,
+        norm_1, norm_2, dev, max_seq_len, rope_theta
+    ):
         super().__init__()
         self.n_heads = n_heads
         self.n_kv_heads = n_kv_heads

@@ -49,7 +52,7 @@ class LlamaLikeBlock(nn.Module):
         self.norm_1 = norm_1.to(dev)
         self.attn = QuantAttentionFused(self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
-                                        dev=dev, max_seq_len=max_seq_len, use_alibi=False
+                                        dev=dev, max_seq_len=max_seq_len, use_alibi=False, rope_theta=rope_theta
         ).to(dev)
         self.norm_2 = norm_2.to(dev)
         self.mlp = mlp.to(dev)
awq/utils/lm_eval_adaptor.py (deleted, 100644 → 0)

import transformers
import torch
from lm_eval.base import BaseLM
import fnmatch
import logging


class LMEvalAdaptor(BaseLM):

    def __init__(self, model_name, model, tokenizer, device, batch_size=1, max_length=-1):
        super().__init__()

        assert isinstance(batch_size, int)

        self.model_name = model_name
        self.model = model.to(device)
        self.model.eval()

        self.tokenizer = tokenizer

        # assert isinstance(self.tokenizer, (
        #     transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
        #     transformers.T5Tokenizer, transformers.T5TokenizerFast,
        # )), "this tokenizer has not been checked for compatibility yet!"

        self.vocab_size = self.tokenizer.vocab_size

        self._batch_size = batch_size

        self._max_length = max_length

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        if self._max_length != -1:
            return self._max_length
        if hasattr(self.model.config, 'n_ctx'):
            return self.model.config.n_ctx
        elif hasattr(self.model.config, 'max_position_embeddings'):
            return self.model.config.max_position_embeddings
        elif hasattr(self.model.config, 'n_positions'):
            return self.model.config.n_positions
        elif 'bloom' in self.model_name:
            return 2048
        elif 'llama' in self.model_name:
            return 2048  # TODO: did not check this
        elif 'mpt' in self.model_name:
            return 2048
        elif 'falcon' in self.model_name:
            return 2048
        else:
            logging.debug(self.model.config)
            raise NotImplementedError

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def device(self):
        return "cuda"

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            if isinstance(self.model, transformers.models.t5.modeling_t5.T5ForConditionalGeneration):
                dec_inps = torch.cat(
                    [
                        torch.tensor(
                            self.model.generation_config.decoder_start_token_id,
                        )
                        .tile(len(inps), 1)
                        .to(inps),
                        inps,
                    ],
                    dim=1,
                )

                kwargs = {"decoder_input_ids": dec_inps,}
            else:
                kwargs = {}
            out = self.model(inps, **kwargs)[0]
            if "opt" in self.model_name:
                # there are a few extra tokens in opt, which we should omit
                return out[:, :, :50257]
            else:
                return out  # [:, :, :self.tokenizer.vocab_size]

    def _model_generate(self, context, max_length, eos_token_id):
        return self.model.generate(
            context,
            max_length=max_length,
            eos_token_id=eos_token_id,
            do_sample=False
        )
examples/eval.py

@@ -2,7 +2,6 @@ import argparse
 from lm_eval import evaluator
 from awq import AutoAWQForCausalLM
 from transformers import AutoTokenizer
-from awq.utils.lm_eval_adaptor import LMEvalAdaptor
 from awq.utils.eval_utils import evaluate_perplexity
 
 def run_eval(

@@ -26,11 +25,9 @@ def run_eval(
         evaluate_perplexity(model.model, tokenizer)
     else:
-        lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=task_batch_size)
-
         # Evaluate perplexity of quantized model
         results = evaluator.simple_evaluate(
-            model=lm_eval_model,
+            model=model,
             tasks=tasks,
             batch_size=task_batch_size,
             no_cache=True,
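After this change examples/eval.py no longer imports the deleted LMEvalAdaptor, while the perplexity path still goes through evaluate_perplexity. A minimal sketch of that path, with a placeholder model path and the fuse_layers flag assumed from the AutoAWQ README:

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from awq.utils.eval_utils import evaluate_perplexity

# Placeholder path; evaluate_perplexity is invoked as in examples/eval.py.
quant_path = "path/to/model-awq"
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path)
evaluate_perplexity(model.model, tokenizer)
```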