OpenDAS / AutoAWQ / Commits

Commit b0a0fecf (unverified), authored Jul 20, 2023 by Jiaming Tang, committed by GitHub on Jul 20, 2023

Merge pull request #41 from mit-han-lab/dev/more_models

Parents: 25e92c4c, ce4a6bb1

Showing 11 changed files with 619 additions and 58 deletions (+619, -58)
awq/entry.py                  +68   -17
awq/kernels/gemm_cuda_gen.cu  +232  -15
awq/kernels/setup.py          +2    -2
awq/quantize/auto_clip.py     +6    -2
awq/quantize/auto_scale.py    +149  -5
awq/quantize/pre_quant.py     +41   -2
awq/quantize/qmodule.py       +14   -0
awq/quantize/quantizer.py     +46   -15
awq/utils/lm_eval_adaptor.py  +4    -0
awq/utils/module.py           +14   -0
awq/utils/utils.py            +43   -0
awq/entry.py  (view file @ b0a0fecf)

  from lm_eval import evaluator, tasks
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
  import torch
  import argparse
  import os
  import json
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+ from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_in_model
  from awq.utils.parallel import auto_parallel
  from awq.quantize.pre_quant import run_awq, apply_awq
  from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
  from awq.utils.lm_eval_adaptor import LMEvalAdaptor
+ from awq.utils.utils import simple_dispatch_model

  parser = argparse.ArgumentParser()

@@ -20,6 +21,12 @@ parser.add_argument('--num_fewshot', type=int, default=0)
  # model config
  parser.add_argument('--parallel', action='store_true',
                      help="enable model parallelism")
+ # max memory to offload larger models to CPU
+ parser.add_argument('--max_memory', type=str, nargs='*',
+                     help="List of device_id:max_memory pairs to be parsed into a dictionary; " \
+                     + "Example: 0:10GiB 1:10GiB cpu:30GiB; " \
+                     + "mode details here: " \
+                     + "https://huggingface.co/docs/accelerate/usage_guides/big_modeling")
  parser.add_argument('--auto_parallel', action='store_true',
                      help="automatically set parallel and batch_size")
  # quantization config

@@ -43,6 +50,9 @@ parser.add_argument('--load_awq', type=str, default=None,
                      help="load the awq search results")
  args = parser.parse_args()

+ max_memory = [v.split(':') for v in (args.max_memory or [])]
+ max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory}
+
  if args.auto_parallel:
      gpu_list = auto_parallel(args)

@@ -62,39 +72,67 @@ def build_model_and_enc(model_path):
      print(f"* Building model {model_path}")

      # all hf model
-     config = AutoConfig.from_pretrained(model_path)
-     enc = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+     config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+     if "mpt" in config.__class__.__name__.lower():
+         enc = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
+     else:
+         enc = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)

      if args.load_quant:  # directly load quantized weights
          # no need to really load the fp16 weights... just to get the model structure
          print("Loading pre-computed quantized weights...")
          with init_empty_weights():
-             model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
-                                                          torch_dtype=torch.float16)
+             model = AutoModelForCausalLM.from_config(config=config,
+                                                      torch_dtype=torch.float16, trust_remote_code=True)
          real_quantize_model_weight(
              model, w_bit=args.w_bit, q_config=q_config, init_only=True)

-         model = load_checkpoint_and_dispatch(
-             model, args.load_quant, device_map="balanced",
-             no_split_module_classes=["OPTDecoderLayer", "LlamaDecoderLayer"]
-         )
+         model.tie_weights()
+         # TODO: can we remove this?
+
+         # Infer device map
+         kwargs = {"max_memory": max_memory} if len(max_memory) else {}
+         device_map = infer_auto_device_map(
+             model,
+             no_split_module_classes=[
+                 "OPTDecoderLayer", "LlamaDecoderLayer",
+                 "BloomBlock", "MPTBlock", "DecoderLayer"],
+             **kwargs
+         )
+         # Load checkpoint in the model
+         load_checkpoint_in_model(
+             model,
+             checkpoint=args.load_quant,
+             device_map=device_map,
+             offload_state_dict=True,
+         )
+         # Dispatch model
+         model = simple_dispatch_model(model, device_map=device_map)
+
+         model.eval()
      else:  # fp16 to quantized
-         kwargs = {"device_map": "balanced", "torch_dtype": torch.float16}
-         model = AutoModelForCausalLM.from_pretrained(
-             model_path, config=config, **kwargs)
+         args.run_awq &= not args.load_awq  # if load_awq, no need to run awq
+
+         # Init model on CPU:
+         kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
+         model = AutoModelForCausalLM.from_pretrained(
+             model_path, config=config, trust_remote_code=True, **kwargs)

          model.eval()

          if args.run_awq:
+             assert args.dump_awq, "Please save the awq results with --dump_awq"
+
              awq_results = run_awq(
                  model, enc,
                  w_bit=args.w_bit, q_config=q_config,
                  n_samples=128, seqlen=512,
              )
              if args.dump_awq:
+                 dirpath = os.path.dirname(args.dump_awq)
+                 os.makedirs(dirpath, exist_ok=True)
+
                  torch.save(awq_results, args.dump_awq)
                  print("AWQ results saved at", args.dump_awq)

+             exit(0)
+
          if args.load_awq:
              print("Loading pre-computed AWQ results from", args.load_awq)
              awq_results = torch.load(args.load_awq, map_location="cpu")

@@ -113,6 +151,9 @@ def build_model_and_enc(model_path):
                  model, w_bit=args.w_bit, q_config=q_config
              )
              if args.dump_quant:
+                 dirpath = os.path.dirname(args.dump_quant)
+                 os.makedirs(dirpath, exist_ok=True)
+
                  print(
                      f"Saving the quantized model at {args.dump_quant}...")
                  torch.save(model.cpu().state_dict(), args.dump_quant)

@@ -120,6 +161,17 @@ def build_model_and_enc(model_path):
          else:
              raise NotImplementedError

+     # Move the model to GPU (as much as possible) for LM evaluation
+     kwargs = {"max_memory": max_memory} if len(max_memory) else {}
+     device_map = infer_auto_device_map(
+         model,
+         # TODO: can we remove this?
+         no_split_module_classes=[
+             "OPTDecoderLayer", "LlamaDecoderLayer",
+             "BloomBlock", "MPTBlock", "DecoderLayer"],
+         **kwargs
+     )
+     model = dispatch_model(model, device_map=device_map)
+
      return model, enc

@@ -136,11 +188,10 @@ def main():
      # a hack here to auto set model group
      model, enc = build_model_and_enc(args.model_path)

-     lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
      if args.tasks is not None:
          task_names = args.tasks.split(",")

+         lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
          results = evaluator.simple_evaluate(
              model=lm_eval_model,
              tasks=task_names,
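To make the new --max_memory handling concrete, here is a small worked example of the two comprehensions added above. It is only a sketch: the flag values and the resulting dictionary are illustrative, not taken from the commit.

    # Sketch of the --max_memory parsing added in entry.py (example values are hypothetical).
    args_max_memory = ["0:10GiB", "1:10GiB", "cpu:30GiB"]   # as passed on the command line
    max_memory = [v.split(':') for v in (args_max_memory or [])]
    max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory}
    # -> {0: '10GiB', 1: '10GiB', 'cpu': '30GiB'}
    # This dict is forwarded to accelerate's infer_auto_device_map(..., max_memory=...),
    # so layers that do not fit within the per-GPU budgets are placed on the CPU.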
awq/kernels/gemm_cuda_gen.cu  (view file @ b0a0fecf)

@@ -13,7 +13,7 @@ __pack_half2(const half x, const half y) {
    return (v1 << 16) | v0;
  }

- __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
+ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
  {
    static constexpr uint32_t ZERO = 0x0;
    float C_warp[32];

@@ -24,7 +24,6 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
    __shared__ half zeros_shared[128];
    int j_factors1 = ((OC + 128 - 1) / 128);
    int blockIdx_x = 0;
    int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
    int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);

@@ -53,6 +52,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
              + (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
              + (((int)blockIdx_y) % j_factors1) * (128 / 8)
              + (((int)threadIdx.x) % (128 / 8)) * 1;
+ // Why * 1 in the above line?
    half* A_shared_ptr = A_shared
                       + ((int)threadIdx.y) * row_stride_warp * (32 + 8)

@@ -80,7 +80,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
    // preload s.f. and zeros
    int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
-   if ((k_bound - 1) * 32 + blockIdx_z >= IC) k_bound -= 1;
+   if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
    for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
      int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
      __syncthreads();

@@ -95,9 +95,9 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
      }
      // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
-     uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / 128 * (OC / 8));
+     uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
      uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
-     uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / 128 * (OC));
+     uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
      /*
      if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
        printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);

@@ -107,6 +107,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
      int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
      for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
+       // TODO: Shang: double check how to get 8.
        // B: 32 x 136 (128+8) float16
        // each warp: 32 x 4

@@ -205,6 +206,204 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
  }
  }
+ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
+ {
+   static constexpr uint32_t ZERO = 0x0;
+   float C_warp[32];
+   __shared__ half A_shared[16 * (32 + 8)];
+   __shared__ half B_shared[32 * (64 + 8)];
+   __shared__ half scaling_factors_shared[64];
+   __shared__ half zeros_shared[64];
+   int j_factors1 = ((OC + 64 - 1) / 64);
+   int blockIdx_x = 0;
+   int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
+   int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
+   half A_shared_warp[8];
+   half B_shared_warp[16];
+   for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
+     for (int i = 0; i < 8; ++i) {
+       C_warp[(j_0_4_init * 8) + i] = 0.0;
+     }
+   }
+   static constexpr int row_stride_warp = 32 * 8 / 32;
+   static constexpr int row_stride = 2 * 32 * 8 / 64;
+   bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
+   // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+   bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;
+   // threadIdx.y is warp_id
+   // bool wb_C_flag = (threadIdx.x / 4) < M;
+   half* A_ptr = A
+               + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+               + (((int)threadIdx.x) % (32 / 8)) * 8;
+   int* B_ptr = B
+              + ((int)threadIdx.y) * (OC / 8) * 4
+              + (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
+              + (((int)blockIdx_y) % j_factors1) * (64 / 8)
+              + (((int)threadIdx.x) % (64 / 8)) * 1;
+ // Why * 1 in the above line?
+   half* A_shared_ptr = A_shared
+                      + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+                      + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+                      + (((int)threadIdx.x) % (32 / 8)) * 8;
+   half* B_shared_ptr = B_shared
+                      + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
+                      + (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
+                      + (((int)threadIdx.x) % (64 / 8)) * 8;
+   int* zeros_ptr = zeros
+                  + (((int)blockIdx_y) % j_factors1) * (64 / 8)
+                  + ((int)threadIdx.x) % (64 / 8);
+   half* scaling_factors_ptr = scaling_factors
+                             + (((int)blockIdx_y) % j_factors1) * (64)
+                             + (((int)threadIdx.x) % (64 / 8)) * 8;
+   half* C_ptr = C
+               + blockIdx_z * M * OC        // blockIdz.x -> split_k dim
+               + (((int)blockIdx_y) % j_factors1) * 64
+               + ((int)threadIdx.y) * 32
+               + (((int)threadIdx.x) % 4) * 2;
+   // preload s.f. and zeros
+   int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
+   if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
+   for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
+     int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
+     __syncthreads();
+     // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+     if (ld_A_flag)
+     {
+       *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
+     }
+     else
+     {
+       *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
+     }
+     // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
+     uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
+     uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+     uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
+     /*
+     if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
+       printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
+     }
+     */
+     // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
+     int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
+     for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
+       // B: 32 x 136 (128+8) float16
+       // each warp: 32 x 4
+       // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
+       // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
+       // row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
+       uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
+       uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+       //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
+       // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
+       // - zero and * scale
+       // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
+       asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
+       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
+       asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
+       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
+       asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
+       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
+       asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
+       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
+       /*
+       if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
+         printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
+       }
+       */
+       // write back
+       *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
+     }
+     __syncthreads();
+     for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
+       {
+         unsigned int addr;
+         __asm__ __volatile__(
+           "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+           : "=r"(addr)
+           : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
+         );
+         __asm__ __volatile__(
+           "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
+           "{%0, %1, %2, %3}, [%4];\n"
+           : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
+           : "r"(addr)
+         );
+       }
+       for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) {
+         {
+           unsigned int addr;
+           __asm__ __volatile__(
+             "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+             : "=r"(addr)
+             : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
+           );
+           __asm__ __volatile__(
+             "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
+             "{%0, %1, %2, %3}, [%4];\n"
+             : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
+             : "r"(addr)
+           );
+         }
+       }
+       for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) {
+         {
+           __asm__ __volatile__(
+             "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+             "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+             : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+             : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+         }
+         {
+           __asm__ __volatile__(
+             "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+             "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+             : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+             : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+         }
+       }
+     }
+   }
+   // TODO: Shang: Hoist loop invariance.
+   for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
+     for (int local_id = 0; local_id < 8; ++local_id) {
+       int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
+       if (row_offset < M) {
+         *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
+       }
+     }
+   }
+ }

  // in_feats: M, IC [float16]
  // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
  // scaling_factors: IC // G, OC [float16]

@@ -232,20 +431,38 @@ torch::Tensor gemm_forward_cuda(
      auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
      auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
      auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
+     int group_size = num_in_channels / _scaling_factors.size(0);

-     if (num_out_channels % 128 != 0)
-         throw std::invalid_argument("OC is not multiple of cta_N = 128");
+     if (num_out_channels % 64 != 0)
+         throw std::invalid_argument("OC is not multiple of cta_N = 64");
      if (num_out_channels % 8 != 0)
          throw std::invalid_argument("OC is not multiple of pack_num = 8");
+     if (group_size % 32 != 0)
+         throw std::invalid_argument("Group size should be a multiple of 32");
+     if (num_out_channels % group_size != 0)
+         throw std::invalid_argument("OC is not multiple of Group size");

+     if (num_out_channels % 128 == 0)
+     {
          int j_factors1 = num_out_channels / 128 / 1;
          dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
          // threadIdx.x: 32
          // threadIdx.y: i_factors[2] * j_factors[2]
          dim3 threads_per_block(32, 2);
          gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
-             split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
+             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
+     }
+     else if (num_out_channels % 64 == 0)
+     {
+         int j_factors1 = num_out_channels / 64 / 1;
+         dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
+         // threadIdx.x: 32
+         // threadIdx.y: i_factors[2] * j_factors[2]
+         dim3 threads_per_block(32, 2);
+         gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
+     }
      return _out_feats.sum(0);
  }
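For readers who do not follow the PTX, the sub.f16x2 / fma.rn.f16x2 pair above performs the usual group-wise INT4 dequantization, with the group size G now inferred on the host as num_in_channels / _scaling_factors.size(0). Below is a reference sketch of that math in PyTorch; the function name and the assumption that zeros are already dequantized per group are mine, not the kernel's.

    # Reference sketch (not the CUDA path) of per-group INT4 dequantization.
    # w_q:    (IC, OC) unpacked 4-bit values in [0, 15]
    # scales: (IC // group_size, OC) float16 scaling factors
    # zeros:  (IC // group_size, OC) dequantized zero points
    import torch

    def dequantize_int4_groupwise(w_q, scales, zeros, group_size):
        s = scales.repeat_interleave(group_size, dim=0)  # expand each group over its rows
        z = zeros.repeat_interleave(group_size, dim=0)
        return (w_q.to(s.dtype) - z) * s                 # deq = (q - zero) * scale, as sub + fma do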
awq/kernels/setup.py  (view file @ b0a0fecf)

@@ -3,7 +3,7 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtensio
  extra_compile_args = {
      "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"],
-     "nvcc": ["-O3", "-std=c++17", "-keep"],
+     "nvcc": ["-O3", "-std=c++17"],
  }

  setup(
awq/quantize/auto_clip.py  (view file @ b0a0fecf)

@@ -22,7 +22,7 @@ def auto_clip_layer(w, input_feat, n_bit, q_config,
      input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
      w = w.reshape(w.shape[0], 1, -1, group_size)
-     oc_batch_size = 256  # prevent OOM
+     oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64  # prevent OOM
      assert w.shape[0] % oc_batch_size == 0
      w_all = w
      best_max_val_all = []

@@ -73,11 +73,13 @@ def auto_clip_block(module,
      clip_list = []
      for name in named_linears:
          # due to qk bmm, it is hard to clip precisely
-         if any([_ in name for _ in ["q_", "k_"]]):
+         if any([_ in name for _ in ["q_", "k_", "query", "key", "Wqkv"]]):
              continue
+         named_linears[name].cuda()
          max_val = auto_clip_layer(
              named_linears[name].weight, input_feat[name], n_bit=w_bit, q_config=q_config)
          clip_list.append((name, max_val))
+         named_linears[name].cpu()
      return clip_list

@@ -86,8 +88,10 @@ def apply_clip(module, clip_list):
      from ..utils.module import get_op_by_name
      for name, max_val in clip_list:
          layer = get_op_by_name(module, name)
+         layer.cuda()
          max_val = max_val.to(layer.weight.device)
          org_shape = layer.weight.shape
          layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
          layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
          layer.weight.data = layer.weight.data.reshape(org_shape)
+         layer.cpu()
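As a small aside on what apply_clip is doing above: the weight is reshaped so that each clip threshold broadcasts over one quantization group, and torch.clamp applies the symmetric clip. The toy shapes below are hypothetical and only illustrate the broadcasting, not the repo's actual group layout.

    # Toy illustration of the per-group symmetric clipping in apply_clip (shapes are hypothetical).
    import torch

    w = torch.randn(4, 32)                     # (out_channels=4, in_features=32)
    max_val = torch.full((4, 2, 1), 0.5)       # one clip threshold per (row, group of 16)
    w_grouped = w.reshape(*max_val.shape[:2], -1)           # (4, 2, 16)
    w_clipped = torch.clamp(w_grouped, -max_val, max_val)   # clamp broadcasts per group
    w = w_clipped.reshape(4, 32)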
awq/quantize/auto_scale.py  (view file @ b0a0fecf)

+ import gc
  import torch
  import torch.nn as nn

+ from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
  from transformers.models.opt.modeling_opt import OPTDecoderLayer
  from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm

- from ..utils.module import get_op_by_name, get_op_name
+ from .qmodule import ScaledActivation
+ from ..utils.module import get_op_by_name, get_op_name, set_op_by_name

  __all__ = ["auto_scale_block", "apply_scale"]

@@ -32,6 +35,13 @@ def scale_ln_fcs(ln, fcs, scales):
      scales = scales.to(ln.weight.device)

+     # debugging start even scales = 1 does not work?
+     """
+     scales = scales * 0
+     scales = scales + 1
+     """
+     # debugging end
+
      ln.weight.div_(scales)
      if hasattr(ln, 'bias') and ln.bias is not None:
          ln.bias.div_(scales)

@@ -50,11 +60,12 @@ def scale_ln_fcs(ln, fcs, scales):
  def scale_fc_fc(fc1, fc2, scales):
      assert isinstance(fc1, nn.Linear)
      assert isinstance(fc2, nn.Linear)
-     assert fc1.out_features == fc2.in_features
+     # assert fc1.out_features == fc2.in_features

      scales = scales.to(fc1.weight.device)

-     fc1.weight.div_(scales.view(-1, 1))
+     # fc1.weight.div_(scales.view(-1, 1))
+     fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
      if fc1.bias is not None:
          fc1.bias.div_(scales.view(-1))

@@ -66,6 +77,17 @@ def scale_fc_fc(fc1, fc2, scales):
          assert torch.isnan(p).sum() == 0

+ @torch.no_grad()
+ def scale_gelu_fc(gelu, fc, scales):
+     assert isinstance(gelu, nn.GELU) or isinstance(gelu, BloomGelu)
+     assert isinstance(fc, nn.Linear)
+
+     fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
+
+     for p in fc.parameters():
+         assert torch.isnan(p).sum() == 0
+
  @torch.no_grad()
  def auto_scale_block(module, module_kwargs,
                       w_bit, q_config,

@@ -86,11 +108,15 @@ def auto_scale_block(module, module_kwargs,
      def _search_module_scale(block, linears2scale: list, x, kwargs={}):
          # w: co, ci
          # x: n, ci
-         x = x.to(next(block.parameters()).device)
          weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
          w_max = get_weight_scale(
              weight, q_group_size=q_config.get("q_group_size", -1))
+         # Clear GPU memory
+         del weight
+         gc.collect()
+         torch.cuda.empty_cache()
+         x = x.to(next(block.parameters()).device)
          with torch.no_grad():
              org_out = block(x, **kwargs)
              if isinstance(org_out, tuple):

@@ -112,7 +138,7 @@ def auto_scale_block(module, module_kwargs,
          ).clamp(min=1e-4).view(-1)
          scales = scales / (scales.max() * scales.min()).sqrt()
          for fc in linears2scale:
-             fc.weight.mul_(scales.view(1, -1))
+             fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
              fc.weight.data = w_quantize_func(
                  fc.weight.data) / (scales.view(1, -1))
          out = block(x, **kwargs)

@@ -143,6 +169,7 @@ def auto_scale_block(module, module_kwargs,
              module2inspect = layers[0]

          scales = _search_module_scale(module2inspect, layers, inp, kwargs)
+         scales = scales.detach().cpu()
          # prev_op_name, [layer_name], scale
          return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), scales)

@@ -205,6 +232,109 @@ def auto_scale_block(module, module_kwargs,
              inp=input_feat['mlp.down_proj'],
          ))
+     elif isinstance(module, BloomBlock):
+         # attention input
+         scales_list.append(_auto_get_scale(
+             prev_op=module.input_layernorm,
+             layers=[module.self_attention.query_key_value],
+             inp=input_feat['self_attention.query_key_value'],
+             module2inspect=module, kwargs=module_kwargs,
+         ))
+         # attn out
+         # Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469
+         """
+         scales_list.append(_auto_get_scale(
+             prev_op=module.self_attention.query_key_value,
+             layers=[module.self_attention.dense],
+             inp=input_feat['self_attention.dense'],
+         ))
+         """
+         # fc1
+         scales_list.append(_auto_get_scale(
+             prev_op=module.post_attention_layernorm,
+             layers=[module.mlp.dense_h_to_4h],
+             inp=input_feat['mlp.dense_h_to_4h'],
+             module2inspect=module, kwargs=module_kwargs,
+         ))
+         # fc2
+         scales_list.append(_auto_get_scale(
+             prev_op=module.mlp.gelu_impl,
+             layers=[module.mlp.dense_4h_to_h],
+             inp=input_feat['mlp.dense_4h_to_h'],
+         ))
+     elif "mpt" in str(module.__class__).lower():
+         # attention input
+         scales_list.append(_auto_get_scale(
+             prev_op=module.norm_1,
+             layers=[module.attn.Wqkv],
+             inp=input_feat['attn.Wqkv'],
+             module2inspect=module.attn,
+             kwargs=module_kwargs,
+         ))
+         # attn out
+         scales_list.append(_auto_get_scale(
+             prev_op=module.attn.Wqkv,
+             layers=[module.attn.out_proj],
+             inp=input_feat['attn.out_proj'],
+         ))
+         # fc1
+         scales_list.append(_auto_get_scale(
+             prev_op=module.norm_2,
+             layers=[module.ffn.up_proj],
+             inp=input_feat['ffn.up_proj'],
+             module2inspect=module.ffn,
+         ))
+         # fc2
+         scales_list.append(_auto_get_scale(
+             prev_op=module.ffn.act,
+             layers=[module.ffn.down_proj],
+             inp=input_feat['ffn.down_proj'],
+         ))
+     elif "falcon" in str(module.__class__).lower():
+         # attn out
+         # Haotian: TBD: need to handle repeated scales for MQ
+         """
+         scales_list.append(_auto_get_scale(
+             prev_op=module.self_attention.query_key_value,
+             layers=[module.self_attention.dense],
+             inp=input_feat['self_attention.dense'],
+         ))
+         """
+         # fc1, as long as it is scaled, everything is screwed up
+         if "falcon-7b" in str(module.__class__).lower():
+             scales_list.append(_auto_get_scale(
+                 prev_op=module.input_layernorm,
+                 layers=[module.mlp.dense_h_to_4h, module.self_attention.query_key_value],
+                 inp=input_feat['self_attention.query_key_value'],
+                 module2inspect=module,
+                 kwargs=module_kwargs,
+             ))
+         elif "falcon-40b" in str(module.__class__).lower():
+             scales_list.append(_auto_get_scale(
+                 prev_op=module.ln_attn,
+                 layers=[module.self_attention.query_key_value],
+                 inp=input_feat['self_attention.query_key_value'],
+                 module2inspect=module,
+                 kwargs=module_kwargs,
+             ))
+             scales_list.append(_auto_get_scale(
+                 prev_op=module.ln_mlp,
+                 layers=[module.mlp.dense_h_to_4h],
+                 inp=input_feat['mlp.dense_h_to_4h'],
+                 module2inspect=module,
+                 kwargs=module_kwargs,
+             ))
+         else:
+             raise NotImplementedError(
+                 "Unknown Falcon architecture, currently only falcon-7b and falcon-40b are supported")
+         # fc2
+         scales_list.append(_auto_get_scale(
+             prev_op=module.mlp.act,
+             layers=[module.mlp.dense_4h_to_h],
+             inp=input_feat['mlp.dense_4h_to_h'],
+         ))
      else:
          raise NotImplementedError(f"{type(module)} not supported yet!")

@@ -215,11 +345,20 @@ def apply_scale(module, scales_list, input_feat_dict=None):
          prev_op = get_op_by_name(module, prev_op_name)
          layers = [get_op_by_name(module, name) for name in layer_names]

+         prev_op.cuda()
+         for layer in layers:
+             layer.cuda()
+         scales.cuda()
+
          if isinstance(prev_op, nn.Linear):
              assert len(layers) == 1
              scale_fc_fc(prev_op, layers[0], scales)
          elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
              scale_ln_fcs(prev_op, layers, scales)
+         elif isinstance(prev_op, nn.GELU) or isinstance(prev_op, BloomGelu):
+             new_module = ScaledActivation(prev_op, scales)
+             set_op_by_name(module, prev_op_name, new_module)
+             scale_gelu_fc(prev_op, layers[0], scales)
          else:
              raise NotImplementedError(
                  f"prev_op {type(prev_op)} not supported yet!")

@@ -229,3 +368,8 @@ def apply_scale(module, scales_list, input_feat_dict=None):
              for layer_name in layer_names:
                  inp = input_feat_dict[layer_name]
                  inp.div_(scales.view(1, -1).to(inp.device))
+
+         prev_op.cpu()
+         for layer in layers:
+             layer.cpu()
+         scales.cpu()
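The new GELU handling above relies on a scale-invariance identity: dividing the activation output by s (what ScaledActivation does) while multiplying the next Linear's input columns by s (what scale_gelu_fc does) leaves the block output unchanged, but shifts magnitude into the weight that will later be quantized. A minimal numerical check of that identity, written as a standalone sketch rather than repo code:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    act, fc = nn.GELU(), nn.Linear(8, 4, bias=False)
    x = torch.randn(2, 8)
    scales = torch.rand(8) + 0.5

    ref = fc(act(x))                                              # original block
    fc_scaled = nn.Linear(8, 4, bias=False)
    fc_scaled.weight.data = fc.weight.data * scales.view(1, -1)   # scale_gelu_fc side
    scaled_out = fc_scaled(act(x) / scales)                       # ScaledActivation side
    print(torch.allclose(ref, scaled_out, atol=1e-6))             # True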
awq/quantize/pre_quant.py  (view file @ b0a0fecf)

@@ -5,6 +5,7 @@ import gc
  import functools
  from collections import defaultdict

+ from transformers.models.bloom.modeling_bloom import BloomForCausalLM
  from transformers.models.opt.modeling_opt import OPTForCausalLM
  from transformers.models.llama.modeling_llama import LlamaForCausalLM

@@ -23,10 +24,32 @@ def get_blocks(model):
          layers = model.model.layers
      elif isinstance(model, OPTForCausalLM):
          layers = model.model.decoder.layers
+     elif isinstance(model, BloomForCausalLM):
+         layers = model.transformer.h
+     elif "mpt" in str(model.__class__).lower():
+         layers = model.transformer.blocks
+     elif "falcon" in str(model.__class__).lower():
+         layers = model.transformer.h
      else:
          raise NotImplementedError(type(model))
      return layers

+ def move_embed(model, device):
+     if isinstance(model, LlamaForCausalLM):
+         model.model.embed_tokens = model.model.embed_tokens.to(device)
+     elif isinstance(model, OPTForCausalLM):
+         model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
+         model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device)
+     elif isinstance(model, BloomForCausalLM):
+         model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
+         model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device)
+     elif "mpt" in str(model.__class__).lower():
+         model.transformer.wte = model.transformer.wte.to(device)
+         model.transformer.emb_drop = model.transformer.emb_drop.to(device)
+     elif "falcon" in str(model.__class__).lower():
+         model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
+     else:
+         raise NotImplementedError(type(model))
+
  @torch.no_grad()
  def run_awq(

@@ -50,6 +73,9 @@ def run_awq(
      inps = []
      layer_kwargs = {}

+     layers[0] = layers[0].cuda()
+     move_embed(model, "cuda")
+
      # get input and kwargs to layer 0
      # with_kwargs is only supported in PyTorch 2.0
      # use this Catcher hack for now

@@ -69,9 +95,13 @@ def run_awq(
          model(samples.to(next(model.parameters()).device))
      except ValueError:  # work with early exit
          pass
+     del samples
      layers[0] = layers[0].module  # restore
      inps = inps[0]

+     layers[0] = layers[0].cpu()
+     move_embed(model, "cpu")
+
      gc.collect()
      torch.cuda.empty_cache()

@@ -83,6 +113,7 @@ def run_awq(
      # solve layer by layer
      for i in tqdm.tqdm(range(len(layers)), desc="Running AWQ..."):
          layer = layers[i]
+         layer = layer.cuda()
          named_linears = get_named_linears(layer)

          # firstly, get input features of all linear layers

@@ -102,20 +133,26 @@ def run_awq(
          inps = layer(inps, **layer_kwargs)[0]
          for h in handles:
              h.remove()
          # now solve for scaling and clipping
          input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

+         # Clear GPU memory
+         torch.cuda.empty_cache()
+
          if auto_scale:  # if it applies, we should also modify the input_feat with scales
              scales_list = auto_scale_block(
                  layer, layer_kwargs,
                  w_bit=w_bit, q_config=q_config,
                  input_feat=input_feat,
              )
-             apply_scale(layer, scales_list, input_feat_dict=input_feat)
+             # apply_scale(layer, scales_list, input_feat_dict=input_feat)
+             apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
              # append prefix to make names global
              awq_results["scale"] += append_str_prefix(scales_list, get_op_name(model, layer) + ".")

+         # Clear GPU memory
+         torch.cuda.empty_cache()
+
          if mse_range:
              clip_list = auto_clip_block(layer,
                                          w_bit=w_bit, q_config=q_config,

@@ -124,6 +161,8 @@ def run_awq(
              # append prefix to make names global
              awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")

+         layer = layer.cpu()
+         # Haotian: check activation replacement
          del input_feat
          gc.collect()
          torch.cuda.empty_cache()
awq/quantize/qmodule.py  (view file @ b0a0fecf)

@@ -4,6 +4,16 @@ import torch.nn as nn
  import f16s4_gemm  # with CUDA kernels

+ class ScaledActivation(nn.Module):
+     def __init__(self, module, scales):
+         super().__init__()
+         self.act = module
+         self.scales = nn.Parameter(scales.data)
+
+     def forward(self, x):
+         return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
+
  class WQLinear(nn.Module):
      def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
          super().__init__()

@@ -83,3 +93,7 @@ class WQLinear(nn.Module):
          out = out + self.bias if self.bias is not None else out
          return out.reshape(out_shape)

+     def extra_repr(self) -> str:
+         return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
+             self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size)
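A brief usage sketch of the new ScaledActivation wrapper defined above (the scales here are dummy ones; in the repo they come from the activation-scale search in auto_scale.py). The added extra_repr similarly makes WQLinear modules self-describing when a quantized model is printed.

    import torch
    import torch.nn as nn
    from awq.quantize.qmodule import ScaledActivation

    act = ScaledActivation(nn.GELU(), torch.ones(4096))  # per-channel scales as a Parameter
    x = torch.randn(1, 8, 4096)
    y = act(x)   # GELU output divided element-wise by the per-channel scales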
awq/quantize/quantizer.py  (view file @ b0a0fecf)

@@ -2,11 +2,48 @@ import torch
  import torch.nn as nn
  from tqdm import tqdm
  import gc
+ from .qmodule import ScaledActivation
+ from ..utils.module import set_op_by_name

+ from transformers.models.bloom.modeling_bloom import BloomBlock

  EMBEDDING_KEYWORDS = ["embed"]
  LM_HEAD_KEYWORDS = ["lm_head", "embed_out", "output"]

+ def scale_activations(module):
+     param = next(module.parameters())
+     dtype = param.dtype
+     device = param.device
+     if isinstance(module, BloomBlock):
+         if isinstance(module.mlp.gelu_impl, ScaledActivation):
+             return
+         c = module.mlp.dense_h_to_4h.out_features
+         act = ScaledActivation(
+             module.mlp.gelu_impl,
+             torch.ones(c, dtype=dtype, device=device)
+         )
+         set_op_by_name(module, "mlp.gelu_impl", act)
+     elif 'mptblock' in str(module.__class__.__name__).lower():
+         if isinstance(module.ffn.act, ScaledActivation):
+             return
+         c = module.ffn.up_proj.out_features
+         act = ScaledActivation(
+             module.ffn.act,
+             torch.ones(c, dtype=dtype, device=device)
+         )
+         set_op_by_name(module, "ffn.act", act)
+     elif 'falcon' in str(module.__class__).lower():
+         if isinstance(module.mlp.act, ScaledActivation):
+             return
+         c = module.mlp.dense_h_to_4h.out_features
+         act = ScaledActivation(
+             module.mlp.act,
+             torch.ones(c, dtype=dtype, device=device)
+         )
+         set_op_by_name(module, "mlp.act", act)
+
  # core quantization method (simulated quantization)
  def pseudo_quantize_tensor(w, n_bit=8,
                             zero_point=True, q_group_size=-1,

@@ -61,7 +98,9 @@ def pseudo_quantize_model_weight(
      for i in tqdm(range(len(layers)), desc="pseudo weight quantization..."):
          named_linears = get_named_linears(layers[i])
          for n, m in named_linears.items():
+             m.cuda()
              m.weight.data = pseudo_quantize_tensor(m.weight.data, n_bit=w_bit, **q_config)
+             m.cpu()

@@ -77,29 +116,21 @@ def real_quantize_model_weight(
      for i in tqdm(range(len(layers)), desc="real weight quantization..." + ("(init only)" if init_only else "")):
          layer = layers[i]
          named_linears = get_named_linears(layer)
+         scale_activations(layer)
+
          for name, module in named_linears.items():
              if init_only:
                  q_linear = WQLinear.from_linear(
                      module, w_bit, q_config['q_group_size'], True)
              else:
+                 module.cuda()
                  module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
                  scales = scales.t().contiguous()
                  zeros = zeros.t().contiguous()
                  q_linear = WQLinear.from_linear(
                      module, w_bit, q_config['q_group_size'], False, scales, zeros)
+                 module.cpu()
+             q_linear.to(next(layer.parameters()).device)
+             set_op_by_name(layer, name, q_linear)
-             levels = name.split('.')
-             if len(levels) > 1:
-                 mod_ = layer
-                 for l_idx in range(len(levels)-1):
-                     if levels[l_idx].isdigit():
-                         mod_ = mod_[int(levels[l_idx])]
-                     else:
-                         mod_ = getattr(mod_, levels[l_idx])
-                 setattr(mod_, levels[-1], q_linear)
-             else:
-                 setattr(layer, name, q_linear)

              torch.cuda.empty_cache()
              gc.collect()
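For context, this is how entry.py drives these helpers, shown as a compact sketch. The q_config keys follow the 'q_group_size' and zero_point usage visible in this file; the concrete values and the pre-built model object are illustrative assumptions.

    from awq.quantize.quantizer import real_quantize_model_weight

    q_config = {"zero_point": True, "q_group_size": 128}  # illustrative values

    # Structure-only pass: swap nn.Linear modules for empty WQLinear shells, so a
    # pre-quantized checkpoint can be loaded into them afterwards.
    real_quantize_model_weight(model, w_bit=4, q_config=q_config, init_only=True)

    # Full pass: compute scales/zeros via pseudo_quantize_tensor and pack real INT4 weights.
    real_quantize_model_weight(model, w_bit=4, q_config=q_config)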
awq/utils/lm_eval_adaptor.py  (view file @ b0a0fecf)

@@ -47,6 +47,10 @@ class LMEvalAdaptor(BaseLM):
              return 2048
          elif 'llama' in self.model_name:
              return 2048  # TODO: did not check this
+         elif 'mpt' in self.model_name:
+             return 2048
+         elif 'falcon' in self.model_name:
+             return 2048
          else:
              print(self.model.config)
              raise NotImplementedError
awq/utils/module.py  (view file @ b0a0fecf)

@@ -8,6 +8,20 @@ def get_op_by_name(module, op_name):
      raise ValueError(f"Cannot find op {op_name} in module {module}")

+ def set_op_by_name(layer, name, new_module):
+     levels = name.split('.')
+     if len(levels) > 1:
+         mod_ = layer
+         for l_idx in range(len(levels)-1):
+             if levels[l_idx].isdigit():
+                 mod_ = mod_[int(levels[l_idx])]
+             else:
+                 mod_ = getattr(mod_, levels[l_idx])
+         setattr(mod_, levels[-1], new_module)
+     else:
+         setattr(layer, name, new_module)
+
  def get_op_name(module, op):
      # get the name of the op relative to the module
      for name, m in module.named_modules():
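A brief usage sketch for the new set_op_by_name helper: it replaces a nested submodule addressed by its dotted name, treating numeric path components as indices. The toy model below is mine, not from the repo.

    import torch.nn as nn
    from awq.utils.module import set_op_by_name

    model = nn.Sequential(nn.Sequential(nn.Linear(16, 16), nn.GELU()))
    # Walk child "0" (by index), then replace its member "1" (the GELU) with a ReLU.
    set_op_by_name(model, "0.1", nn.ReLU())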
awq/utils/utils.py  (new file, 0 → 100644)  (view file @ b0a0fecf)

+ import torch
+ import accelerate
+
+ def get_module_by_name_suffix(model, module_name: str):
+     for name, module in model.named_modules():
+         if name.endswith(module_name):
+             return module
+
+ def simple_dispatch_model(model, device_map):
+     from accelerate.hooks import add_hook_to_module, AlignDevicesHook
+
+     if "" in device_map:
+         d = device_map[""]
+         model = model.to(torch.device(d))
+         model.hf_device_map = device_map
+         return model
+
+     tied_params = accelerate.utils.modeling.find_tied_parameters(model)
+     if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
+         main_device = "cpu"
+     else:
+         main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
+
+     cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"]
+     prev_hook = None
+     for idx, (n, d) in enumerate(cpu_offload_group):
+         m = get_module_by_name_suffix(model, n)
+         _, prev_hook = accelerate.cpu_offload_with_hook(m, execution_device=main_device, prev_module_hook=prev_hook)
+     # set first cpu offload module's prev_module_hook to the last cpu offload module's hook
+     if len(cpu_offload_group) > 1:
+         get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = prev_hook
+
+     for n, d in device_map.items():
+         m = get_module_by_name_suffix(model, n)
+         if d != "cpu":
+             d = torch.device(d)
+             hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
+             add_hook_to_module(m, hook)
+     accelerate.utils.modeling.retie_parameters(model, tied_params)
+     model.hf_device_map = device_map
+
+     return model
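A sketch of how entry.py uses this new helper when loading a pre-quantized checkpoint: a device map from accelerate decides placement, then simple_dispatch_model attaches AlignDevicesHook / CPU-offload hooks instead of going through dispatch_model. The model object and memory budget below are placeholders.

    from accelerate import infer_auto_device_map
    from awq.utils.utils import simple_dispatch_model

    device_map = infer_auto_device_map(
        model,                                    # e.g. built under init_empty_weights()
        max_memory={0: "10GiB", "cpu": "30GiB"},  # illustrative budget, as with --max_memory
        no_split_module_classes=["LlamaDecoderLayer"],
    )
    model = simple_dispatch_model(model, device_map=device_map)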