OpenDAS / AutoAWQ
"git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "349a499f339fd0e758ee8bd54c09ee073cfbaf3b"
Commit b491c2d6, authored Aug 27, 2023 by Casper
Replace print with logging, remove commented-out code
Parent: f741f406
Showing 7 changed files with 9 additions and 17 deletions.
awq/modules/fused_attn.py      +1 -8
awq/modules/fused_mlp.py       +0 -1
awq/modules/fused_norm.py      +0 -2
awq/quantize/auto_scale.py     +2 -2
awq/utils/calib_data.py        +2 -1
awq/utils/lm_eval_adaptor.py   +2 -2
awq/utils/parallel.py          +2 -1
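
All of the print calls touched by this commit become logging.debug calls, so their output is hidden unless the application opts in. A minimal sketch of how a caller could surface them, assuming the caller configures the root logger itself (the format string here is illustrative, not part of this commit):

    import logging

    # The awq modules log via the module-level logging functions (the root
    # logger), so raising the root level to DEBUG makes their messages visible.
    # basicConfig only takes effect if no handlers are configured yet.
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(levelname)s %(message)s",
    )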
awq/modules/fused_attn.py
@@ -34,8 +34,6 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)
-        # self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        # self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
         self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

     def forward(
@@ -46,7 +44,6 @@ class QuantLlamaRotaryEmbedding(nn.Module):
     ):
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
-        # print(positions.shape, query.shape, key.shape, self.cos_sin_cache.shape)
         query = query.contiguous()
         key = key.contiguous()
         awq_inference_engine.rotary_embedding_neox(
@@ -146,7 +143,7 @@ def make_quant_attn(model, dev):
         qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
         qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
         scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
-        # g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
         g_idx = None
         bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
@@ -156,8 +153,6 @@ def make_quant_attn(model, dev):
         qkv_layer.scales = scales
         qkv_layer.bias = bias
-        # We're dropping the rotary embedding layer m.rotary_emb here. We don't need it in the triton branch.
         attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, dev)

         if '.' in name:
@@ -169,6 +164,4 @@ def make_quant_attn(model, dev):
             parent = model
             child_name = name
-        #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")
         setattr(parent, child_name, attn)
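
For context, the unchanged lines above fuse the separate q/k/v projections into one qkv_layer by concatenating their packed tensors. A rough float-precision sketch of the same idea, assuming plain nn.Linear layers rather than the packed quantized layout (which presumably keeps out_features on dim 1, hence the dim=1 concatenation above):

    import torch
    import torch.nn as nn

    hidden = 4096
    q = nn.Linear(hidden, hidden, bias=False)
    k = nn.Linear(hidden, hidden, bias=False)
    v = nn.Linear(hidden, hidden, bias=False)

    # One fused projection whose output is later split back into q, k, v.
    # nn.Linear stores weight as (out_features, in_features), so the float
    # analogue concatenates along dim 0 instead of dim 1.
    qkv = nn.Linear(hidden, 3 * hidden, bias=False)
    with torch.no_grad():
        qkv.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))

    x = torch.randn(2, hidden)
    q_out, k_out, v_out = qkv(x).chunk(3, dim=-1)
    assert torch.allclose(q_out, q(x), atol=1e-5)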
awq/modules/fused_mlp.py
@@ -71,7 +71,6 @@ class QuantLlamaMLP(nn.Module):
 def make_fused_mlp(m, parent_name=''):
     if not hasattr(make_fused_mlp, "called"):
-        # print("[Warning] Calling a fake MLP fusion. But still faster than Huggingface Implimentation.")
         make_fused_mlp.called = True
     """
     Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
awq/modules/fused_norm.py
@@ -38,6 +38,4 @@ def make_quant_norm(model):
             parent = model
             child_name = name
-        #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")
         setattr(parent, child_name, norm)
awq/quantize/auto_scale.py
 import gc
 import torch
 import torch.nn as nn
+import logging

 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
 from transformers.models.opt.modeling_opt import OPTDecoderLayer
@@ -154,9 +155,8 @@ def auto_scale_block(awq_model,
                 best_scales = scales
             block.load_state_dict(org_sd)
         if best_ratio == -1:
-            print(history)
+            logging.debug(history)
             raise Exception
-        # print(best_ratio)
         best_scales = best_scales.view(-1)
         assert torch.isnan(best_scales).sum() == 0, best_scales
awq/utils/calib_data.py
 import torch
+import logging
 from datasets import load_dataset

 def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=512):
@@ -25,5 +26,5 @@ def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=
     # now concatenate all samples and split according to block size
     cat_samples = torch.cat(samples, dim=1)
     n_split = cat_samples.shape[1] // block_size
-    print(f" * Split into {n_split} blocks")
+    logging.debug(f" * Split into {n_split} blocks")
     return [cat_samples[:, i*block_size:(i+1)*block_size] for i in range(n_split)]
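
The new call keeps an f-string, which is formatted eagerly even when DEBUG output is disabled. That is usually harmless here, but logging also supports lazy %-style arguments; a small sketch of the two equivalent spellings, assuming n_split is already computed:

    import logging

    n_split = 42  # illustrative value

    # Eager: the f-string is built before logging decides whether to emit it.
    logging.debug(f" * Split into {n_split} blocks")

    # Lazy: interpolation happens only if a handler actually emits the record.
    logging.debug(" * Split into %d blocks", n_split)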
awq/utils/lm_eval_adaptor.py
@@ -2,7 +2,7 @@ import transformers
 import torch
 from lm_eval.base import BaseLM
 import fnmatch
+import logging

 class LMEvalAdaptor(BaseLM):
@@ -52,7 +52,7 @@ class LMEvalAdaptor(BaseLM):
         elif 'falcon' in self.model_name:
             return 2048
         else:
-            print(self.model.config)
+            logging.debug(self.model.config)
             raise NotImplementedError

     @property
awq/utils/parallel.py
 import os
 import torch
 import gc
+import logging

 def auto_parallel(args):
@@ -23,5 +24,5 @@ def auto_parallel(args):
         cuda_visible_devices = list(range(8))
     os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
         [str(dev) for dev in cuda_visible_devices[:n_gpu]])
-    print("CUDA_VISIBLE_DEVICES: ", os.environ["CUDA_VISIBLE_DEVICES"])
+    logging.debug("CUDA_VISIBLE_DEVICES: ", os.environ["CUDA_VISIBLE_DEVICES"])
     return cuda_visible_devices
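
Note that logging.debug does not concatenate positional arguments the way print does: anything after the message string is treated as %-format parameters for it. A hedged sketch of the difference, using the same value as the changed line above (the %s spelling is one way to get the value into the logged message):

    import logging
    import os

    logging.basicConfig(level=logging.DEBUG)
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")  # illustrative value

    # print joins its arguments with a space.
    print("CUDA_VISIBLE_DEVICES: ", os.environ["CUDA_VISIBLE_DEVICES"])

    # logging treats the second argument as a %-format parameter, so the
    # message needs a placeholder for the value to appear in the record.
    logging.debug("CUDA_VISIBLE_DEVICES: %s", os.environ["CUDA_VISIBLE_DEVICES"])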