OpenDAS / AutoAWQ · Commits

Commit b0a0fecf (Unverified)
Merge pull request #41 from mit-han-lab/dev/more_models
Authored on Jul 20, 2023 by Jiaming Tang; committed by GitHub on Jul 20, 2023
Parents: 25e92c4c, ce4a6bb1
Showing 11 changed files with 619 additions and 58 deletions (+619 / -58)
awq/entry.py                  +68   -17
awq/kernels/gemm_cuda_gen.cu  +232  -15
awq/kernels/setup.py          +2    -2
awq/quantize/auto_clip.py     +6    -2
awq/quantize/auto_scale.py    +149  -5
awq/quantize/pre_quant.py     +41   -2
awq/quantize/qmodule.py       +14   -0
awq/quantize/quantizer.py     +46   -15
awq/utils/lm_eval_adaptor.py  +4    -0
awq/utils/module.py           +14   -0
awq/utils/utils.py            +43   -0
awq/entry.py

 from lm_eval import evaluator, tasks
-from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
 import torch
 import argparse
 import os
 import json
-from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_in_model
 from awq.utils.parallel import auto_parallel
 from awq.quantize.pre_quant import run_awq, apply_awq
 from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
 from awq.utils.lm_eval_adaptor import LMEvalAdaptor
+from awq.utils.utils import simple_dispatch_model

 parser = argparse.ArgumentParser()

@@ -20,6 +21,12 @@ parser.add_argument('--num_fewshot', type=int, default=0)
 # model config
 parser.add_argument('--parallel', action='store_true', help="enable model parallelism")
+# max memory to offload larger models to CPU
+parser.add_argument('--max_memory', type=str, nargs='*',
+                    help="List of device_id:max_memory pairs to be parsed into a dictionary; " \
+                    + "Example: 0:10GiB 1:10GiB cpu:30GiB; " \
+                    + "mode details here: " \
+                    + "https://huggingface.co/docs/accelerate/usage_guides/big_modeling")
 parser.add_argument('--auto_parallel', action='store_true', help="automatically set parallel and batch_size")
 # quantization config

@@ -43,6 +50,9 @@ parser.add_argument('--load_awq', type=str, default=None,
                     help="load the awq search results")
 args = parser.parse_args()

+max_memory = [v.split(':') for v in (args.max_memory or [])]
+max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory}
+
 if args.auto_parallel:
     gpu_list = auto_parallel(args)

@@ -62,39 +72,67 @@ def build_model_and_enc(model_path):
     print(f"* Building model {model_path}")

     # all hf model
-    config = AutoConfig.from_pretrained(model_path)
-    enc = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+    if "mpt" in config.__class__.__name__.lower():
+        enc = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
+    else:
+        enc = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)

     if args.load_quant:  # directly load quantized weights
+        # no need to really load the fp16 weights... just to get the model structure
         print("Loading pre-computed quantized weights...")
         with init_empty_weights():
-            model = AutoModelForCausalLM.from_pretrained(model_path, config=config, torch_dtype=torch.float16)
+            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.float16, trust_remote_code=True)
         real_quantize_model_weight(
             model, w_bit=args.w_bit, q_config=q_config, init_only=True)
-        model = load_checkpoint_and_dispatch(
-            model, args.load_quant, device_map="balanced",
+
+        model.tie_weights()
+
+        # Infer device map
+        kwargs = {"max_memory": max_memory} if len(max_memory) else {}
+        device_map = infer_auto_device_map(
+            model,
             # TODO: can we remove this?
-            no_split_module_classes=["OPTDecoderLayer", "LlamaDecoderLayer"]
+            no_split_module_classes=[
+                "OPTDecoderLayer", "LlamaDecoderLayer",
+                "BloomBlock", "MPTBlock", "DecoderLayer"],
+            **kwargs
         )
+        # Load checkpoint in the model
+        load_checkpoint_in_model(
+            model,
+            checkpoint=args.load_quant,
+            device_map=device_map,
+            offload_state_dict=True,
+        )
+        # Dispatch model
+        model = simple_dispatch_model(model, device_map=device_map)

         model.eval()
     else:  # fp16 to quantized
-        kwargs = {"device_map": "balanced", "torch_dtype": torch.float16}
-        model = AutoModelForCausalLM.from_pretrained(model_path, config=config, **kwargs)
+        args.run_awq &= not args.load_awq  # if load_awq, no need to run awq
+        # Init model on CPU:
+        kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path, config=config, trust_remote_code=True, **kwargs)

         model.eval()

         if args.run_awq:
             assert args.dump_awq, "Please save the awq results with --dump_awq"

             awq_results = run_awq(
                 model, enc,
                 w_bit=args.w_bit, q_config=q_config,
                 n_samples=128, seqlen=512,
             )
             if args.dump_awq:
+                dirpath = os.path.dirname(args.dump_awq)
+                os.makedirs(dirpath, exist_ok=True)
+
                 torch.save(awq_results, args.dump_awq)
                 print("AWQ results saved at", args.dump_awq)

             exit(0)

         if args.load_awq:
             print("Loading pre-computed AWQ results from", args.load_awq)
             awq_results = torch.load(args.load_awq, map_location="cpu")

@@ -113,12 +151,26 @@ def build_model_and_enc(model_path):
                 model, w_bit=args.w_bit, q_config=q_config)
             if args.dump_quant:
+                dirpath = os.path.dirname(args.dump_quant)
+                os.makedirs(dirpath, exist_ok=True)
+
                 print(f"Saving the quantized model at {args.dump_quant}...")
                 torch.save(model.cpu().state_dict(), args.dump_quant)
                 exit(0)
         else:
             raise NotImplementedError

+    # Move the model to GPU (as much as possible) for LM evaluation
+    kwargs = {"max_memory": max_memory} if len(max_memory) else {}
+    device_map = infer_auto_device_map(
+        model,
+        # TODO: can we remove this?
+        no_split_module_classes=[
+            "OPTDecoderLayer", "LlamaDecoderLayer",
+            "BloomBlock", "MPTBlock", "DecoderLayer"],
+        **kwargs
+    )
+    model = dispatch_model(model, device_map=device_map)
+
     return model, enc

@@ -136,11 +188,10 @@ def main():
     # a hack here to auto set model group
     model, enc = build_model_and_enc(args.model_path)

-    lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
     if args.tasks is not None:
         task_names = args.tasks.split(",")

+        lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
         results = evaluator.simple_evaluate(
             model=lm_eval_model,
             tasks=task_names,
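Editor's note: a quick, self-contained sketch of what the new --max_memory plumbing above does. The two parsing lines are taken verbatim from the diff; the flag values are made-up examples.

    # Minimal sketch of the --max_memory parsing added above (example values are hypothetical).
    args_max_memory = ["0:10GiB", "1:10GiB", "cpu:30GiB"]

    max_memory = [v.split(':') for v in (args_max_memory or [])]
    max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory}
    # {0: '10GiB', 1: '10GiB', 'cpu': '30GiB'} -> usable as infer_auto_device_map(..., max_memory=...)
    print(max_memory)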
awq/kernels/gemm_cuda_gen.cu

@@ -13,7 +13,7 @@ __pack_half2(const half x, const half y) {
   return (v1 << 16) | v0;
 }

-__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
+__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
 {
   static constexpr uint32_t ZERO = 0x0;
   float C_warp[32];

@@ -24,7 +24,6 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
   __shared__ half zeros_shared[128];

   int j_factors1 = ((OC + 128 - 1) / 128);
   int blockIdx_x = 0;
   int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
   int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);

@@ -53,6 +52,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
              + (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
              + (((int)blockIdx_y) % j_factors1) * (128 / 8)
              + (((int)threadIdx.x) % (128 / 8)) * 1;
+// Why * 1 in the above line?

   half* A_shared_ptr = A_shared
                      + ((int)threadIdx.y) * row_stride_warp * (32 + 8)

@@ -80,7 +80,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
   // preload s.f. and zeros
   int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
-  if ((k_bound - 1) * 32 + blockIdx_z >= IC) k_bound -= 1;
+  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
   for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
     int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
     __syncthreads();

@@ -95,9 +95,9 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
     }

     // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
-    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / 128 * (OC / 8));
+    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
     uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
-    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / 128 * (OC));
+    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
     /*
     if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
       printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);

@@ -107,6 +107,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
     int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);

     for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
+      // TODO: Shang: double check how to get 8.
       // B: 32 x 136 (128+8) float16
       // each warp: 32 x 4

@@ -205,6 +206,204 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
   }
 }

+__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
+{
+  static constexpr uint32_t ZERO = 0x0;
+  float C_warp[32];
+  __shared__ half A_shared[16 * (32 + 8)];
+  __shared__ half B_shared[32 * (64 + 8)];
+
+  __shared__ half scaling_factors_shared[64];
+  __shared__ half zeros_shared[64];
+
+  int j_factors1 = ((OC + 64 - 1) / 64);
+
+  int blockIdx_x = 0;
+  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
+  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
+
+  half A_shared_warp[8];
+  half B_shared_warp[16];
+  for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
+    for (int i = 0; i < 8; ++i) {
+      C_warp[(j_0_4_init * 8) + i] = 0.0;
+    }
+  }
+
+  static constexpr int row_stride_warp = 32 * 8 / 32;
+  static constexpr int row_stride = 2 * 32 * 8 / 64;
+  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
+  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;   // threadIdx.y is warp_id
+  // bool wb_C_flag = (threadIdx.x / 4) < M;
+
+  half* A_ptr = A
+              + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+              + (((int)threadIdx.x) % (32 / 8)) * 8;
+
+  int* B_ptr = B
+             + ((int)threadIdx.y) * (OC / 8) * 4
+             + (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
+             + (((int)blockIdx_y) % j_factors1) * (64 / 8)
+             + (((int)threadIdx.x) % (64 / 8)) * 1;
+// Why * 1 in the above line?
+
+  half* A_shared_ptr = A_shared
+                     + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+                     + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+                     + (((int)threadIdx.x) % (32 / 8)) * 8;
+
+  half* B_shared_ptr = B_shared
+                     + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
+                     + (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
+                     + (((int)threadIdx.x) % (64 / 8)) * 8;
+
+  int* zeros_ptr = zeros
+                 + (((int)blockIdx_y) % j_factors1) * (64 / 8)
+                 + ((int)threadIdx.x) % (64 / 8);
+
+  half* scaling_factors_ptr = scaling_factors
+                            + (((int)blockIdx_y) % j_factors1) * (64)
+                            + (((int)threadIdx.x) % (64 / 8)) * 8;
+
+  half* C_ptr = C
+              + blockIdx_z * M * OC        // blockIdz.x -> split_k dim
+              + (((int)blockIdx_y) % j_factors1) * 64
+              + ((int)threadIdx.y) * 32
+              + (((int)threadIdx.x) % 4) * 2;
+
+  // preload s.f. and zeros
+  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
+  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
+  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
+    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
+    __syncthreads();
+    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+    if (ld_A_flag)
+    {
+      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
+    }
+    else
+    {
+      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
+    }
+
+    // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
+    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
+    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
+    /*
+    if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
+      printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
+    }
+    */
+    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
+    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
+
+    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
+
+      // B: 32 x 136 (128+8) float16
+      // each warp: 32 x 4
+      // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
+      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
+      // row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
+      uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
+      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+      //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
+
+      // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
+      // - zero and * scale
+      // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
+      /*
+      if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
+        printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
+      }
+      */
+
+      // write back
+      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
+    }
+    __syncthreads();
+
+    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
+      {
+        unsigned int addr;
+        __asm__ __volatile__(
+          "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+          : "=r"(addr)
+          : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
+        );
+        __asm__ __volatile__(
+          "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
+          "{%0, %1, %2, %3}, [%4];\n"
+          : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
+          : "r"(addr)
+        );
+      }
+
+      for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) {
+        {
+          unsigned int addr;
+          __asm__ __volatile__(
+            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+            : "=r"(addr)
+            : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
+          );
+          __asm__ __volatile__(
+            "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
+            "{%0, %1, %2, %3}, [%4];\n"
+            : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
+            : "r"(addr)
+          );
+        }
+      }
+      for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) {
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+            :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+            :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
        }
+      }
+    }
+  }
+
+// TODO: Shang: Hoist loop invariance.
+  for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
+    for (int local_id = 0; local_id < 8; ++local_id) {
+      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
+      if (row_offset < M)
+      {
+        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
+      }
+    }
+  }
+}
+
 // in_feats: M, IC [float16]
 // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
 // scaling_factors: IC // G, OC [float16]

@@ -232,20 +431,38 @@ torch::Tensor gemm_forward_cuda(
   auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
   auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
   auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
+  int group_size = num_in_channels / _scaling_factors.size(0);

-  if (num_out_channels % 128 != 0)
-    throw std::invalid_argument("OC is not multiple of cta_N = 128");
+  if (num_out_channels % 64 != 0)
+    throw std::invalid_argument("OC is not multiple of cta_N = 64");
   if (num_out_channels % 8 != 0)
     throw std::invalid_argument("OC is not multiple of pack_num = 8");
-  int j_factors1 = num_out_channels / 128 / 1;
-  dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
-  // threadIdx.x: 32
-  // threadIdx.y: i_factors[2] * j_factors[2]
-  dim3 threads_per_block(32, 2);
-  gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
-    split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
+  if (group_size % 32 != 0)
+    throw std::invalid_argument("Group size should be a multiple of 32");
+  if (num_out_channels % group_size != 0)
+    throw std::invalid_argument("OC is not multiple of Group size");
+
+  if (num_out_channels % 128 == 0)
+  {
+    int j_factors1 = num_out_channels / 128 / 1;
+    dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
+    // threadIdx.x: 32
+    // threadIdx.y: i_factors[2] * j_factors[2]
+    dim3 threads_per_block(32, 2);
+    gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
+      group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
+  }
+  else if (num_out_channels % 64 == 0)
+  {
+    int j_factors1 = num_out_channels / 64 / 1;
+    dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
+    // threadIdx.x: 32
+    // threadIdx.y: i_factors[2] * j_factors[2]
+    dim3 threads_per_block(32, 2);
+    gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+      group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
+  }
   return _out_feats.sum(0);
 }
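Editor's note: to make the new group-size handling in gemm_forward_cuda easier to follow, here is a hedged Python-side sketch of the same shape bookkeeping. The layer sizes are hypothetical; the checks and the 128-vs-64 CTA selection mirror the ones added in the host wrapper above (scaling factors are stored as IC // G rows of OC values).

    import torch

    IC, OC, group_size = 4096, 11008, 128                    # hypothetical layer sizes
    scaling_factors = torch.empty(IC // group_size, OC, dtype=torch.float16)

    # mirrors: int group_size = num_in_channels / _scaling_factors.size(0);
    inferred_group_size = IC // scaling_factors.shape[0]
    assert inferred_group_size % 32 == 0, "Group size should be a multiple of 32"
    assert OC % inferred_group_size == 0, "OC is not multiple of Group size"
    assert OC % 64 == 0, "OC is not multiple of cta_N = 64"

    # kernel selection in the new host code: 128-wide CTA when possible, else the new 64-wide kernel
    cta_n = 128 if OC % 128 == 0 else 64
    print(inferred_group_size, cta_n)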
awq/kernels/setup.py

@@ -3,7 +3,7 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtensio
 extra_compile_args = {
     "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"],
-    "nvcc": ["-O3", "-std=c++17", "-keep"],
+    "nvcc": ["-O3", "-std=c++17"],
 }

 setup(

@@ -18,4 +18,4 @@ setup(
     ],
     cmdclass={"build_ext": BuildExtension},
     install_requires=["torch"],
-)
\ No newline at end of file
+)
awq/quantize/auto_clip.py

@@ -22,7 +22,7 @@ def auto_clip_layer(w, input_feat, n_bit, q_config,
     input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
     w = w.reshape(w.shape[0], 1, -1, group_size)

-    oc_batch_size = 256  # prevent OOM
+    oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64  # prevent OOM
     assert w.shape[0] % oc_batch_size == 0
     w_all = w
     best_max_val_all = []

@@ -73,11 +73,13 @@ def auto_clip_block(module,
     clip_list = []
     for name in named_linears:
         # due to qk bmm, it is hard to clip precisely
-        if any([_ in name for _ in ["q_", "k_"]]):
+        if any([_ in name for _ in ["q_", "k_", "query", "key", "Wqkv"]]):
             continue
+        named_linears[name].cuda()
         max_val = auto_clip_layer(
             named_linears[name].weight, input_feat[name], n_bit=w_bit, q_config=q_config)
         clip_list.append((name, max_val))
+        named_linears[name].cpu()
     return clip_list

@@ -86,8 +88,10 @@ def apply_clip(module, clip_list):
     from ..utils.module import get_op_by_name
     for name, max_val in clip_list:
         layer = get_op_by_name(module, name)
+        layer.cuda()
         max_val = max_val.to(layer.weight.device)
         org_shape = layer.weight.shape
         layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
         layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
         layer.weight.data = layer.weight.data.reshape(org_shape)
+        layer.cpu()
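Editor's note: a tiny illustration of the new oc_batch_size fallback above. The widths are examples only (4544 is a Falcon-7B-like hidden size that is not divisible by 256, which is exactly the case the fallback handles):

    for out_channels in (4096, 4544):
        oc_batch_size = 256 if out_channels % 256 == 0 else 64   # prevent OOM
        assert out_channels % oc_batch_size == 0
        print(out_channels, oc_batch_size)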
awq/quantize/auto_scale.py

+import gc
 import torch
 import torch.nn as nn

 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
 from transformers.models.opt.modeling_opt import OPTDecoderLayer
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm

-from ..utils.module import get_op_by_name, get_op_name
+from .qmodule import ScaledActivation
+from ..utils.module import get_op_by_name, get_op_name, set_op_by_name

 __all__ = ["auto_scale_block", "apply_scale"]

@@ -32,6 +35,13 @@ def scale_ln_fcs(ln, fcs, scales):
     scales = scales.to(ln.weight.device)

+    # debugging start even scales = 1 does not work?
+    """
+    scales = scales * 0
+    scales = scales + 1
+    """
+    # debugging end
+
     ln.weight.div_(scales)
     if hasattr(ln, 'bias') and ln.bias is not None:
         ln.bias.div_(scales)

@@ -50,11 +60,12 @@ def scale_ln_fcs(ln, fcs, scales):
 def scale_fc_fc(fc1, fc2, scales):
     assert isinstance(fc1, nn.Linear)
     assert isinstance(fc2, nn.Linear)
-    assert fc1.out_features == fc2.in_features
+    # assert fc1.out_features == fc2.in_features

     scales = scales.to(fc1.weight.device)

-    fc1.weight.div_(scales.view(-1, 1))
+    # fc1.weight.div_(scales.view(-1, 1))
+    fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
     if fc1.bias is not None:
         fc1.bias.div_(scales.view(-1))

@@ -66,6 +77,17 @@ def scale_fc_fc(fc1, fc2, scales):
         assert torch.isnan(p).sum() == 0


+@torch.no_grad()
+def scale_gelu_fc(gelu, fc, scales):
+    assert isinstance(gelu, nn.GELU) or isinstance(gelu, BloomGelu)
+    assert isinstance(fc, nn.Linear)
+
+    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
+
+    for p in fc.parameters():
+        assert torch.isnan(p).sum() == 0
+
+
 @torch.no_grad()
 def auto_scale_block(module, module_kwargs,
                      w_bit, q_config,

@@ -86,11 +108,15 @@ def auto_scale_block(module, module_kwargs,
     def _search_module_scale(block, linears2scale: list, x, kwargs={}):
         # w: co, ci
         # x: n, ci
-        x = x.to(next(block.parameters()).device)
         weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
         w_max = get_weight_scale(
             weight, q_group_size=q_config.get("q_group_size", -1))
+        # Clear GPU memory
+        del weight
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        x = x.to(next(block.parameters()).device)
         with torch.no_grad():
             org_out = block(x, **kwargs)
             if isinstance(org_out, tuple):

@@ -112,7 +138,7 @@ def auto_scale_block(module, module_kwargs,
             ).clamp(min=1e-4).view(-1)
             scales = scales / (scales.max() * scales.min()).sqrt()
             for fc in linears2scale:
-                fc.weight.mul_(scales.view(1, -1))
+                fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
                 fc.weight.data = w_quantize_func(
                     fc.weight.data) / (scales.view(1, -1))
             out = block(x, **kwargs)

@@ -143,6 +169,7 @@ def auto_scale_block(module, module_kwargs,
             module2inspect = layers[0]

         scales = _search_module_scale(module2inspect, layers, inp, kwargs)
+        scales = scales.detach().cpu()
         # prev_op_name, [layer_name], scale
         return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), scales)

@@ -204,7 +231,110 @@ def auto_scale_block(module, module_kwargs,
             layers=[module.mlp.down_proj],
             inp=input_feat['mlp.down_proj'],
         ))
+    elif isinstance(module, BloomBlock):
+        # attention input
+        scales_list.append(_auto_get_scale(
+            prev_op=module.input_layernorm,
+            layers=[module.self_attention.query_key_value],
+            inp=input_feat['self_attention.query_key_value'],
+            module2inspect=module, kwargs=module_kwargs,
+        ))
+        # attn out
+        # Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469
+        """
+        scales_list.append(_auto_get_scale(
+            prev_op=module.self_attention.query_key_value,
+            layers=[module.self_attention.dense],
+            inp=input_feat['self_attention.dense'],
+        ))
+        """
+        # fc1
+        scales_list.append(_auto_get_scale(
+            prev_op=module.post_attention_layernorm,
+            layers=[module.mlp.dense_h_to_4h],
+            inp=input_feat['mlp.dense_h_to_4h'],
+            module2inspect=module, kwargs=module_kwargs,
+        ))
+        # fc2
+        scales_list.append(_auto_get_scale(
+            prev_op=module.mlp.gelu_impl,
+            layers=[module.mlp.dense_4h_to_h],
+            inp=input_feat['mlp.dense_4h_to_h'],
+        ))
+    elif "mpt" in str(module.__class__).lower():
+        # attention input
+        scales_list.append(_auto_get_scale(
+            prev_op=module.norm_1,
+            layers=[module.attn.Wqkv],
+            inp=input_feat['attn.Wqkv'],
+            module2inspect=module.attn,
+            kwargs=module_kwargs,
+        ))
+        # attn out
+        scales_list.append(_auto_get_scale(
+            prev_op=module.attn.Wqkv,
+            layers=[module.attn.out_proj],
+            inp=input_feat['attn.out_proj'],
+        ))
+        # fc1
+        scales_list.append(_auto_get_scale(
+            prev_op=module.norm_2,
+            layers=[module.ffn.up_proj],
+            inp=input_feat['ffn.up_proj'],
+            module2inspect=module.ffn,
+        ))
+        # fc2
+        scales_list.append(_auto_get_scale(
+            prev_op=module.ffn.act,
+            layers=[module.ffn.down_proj],
+            inp=input_feat['ffn.down_proj'],
+        ))
+    elif "falcon" in str(module.__class__).lower():
+        # attn out
+        # Haotian: TBD: need to handle repeated scales for MQ
+        """
+        scales_list.append(_auto_get_scale(
+            prev_op=module.self_attention.query_key_value,
+            layers=[module.self_attention.dense],
+            inp=input_feat['self_attention.dense'],
+        ))
+        """
+        # fc1, as long as it is scaled, everything is screwed up
+        if "falcon-7b" in str(module.__class__).lower():
+            scales_list.append(_auto_get_scale(
+                prev_op=module.input_layernorm,
+                layers=[module.mlp.dense_h_to_4h, module.self_attention.query_key_value],
+                inp=input_feat['self_attention.query_key_value'],
+                module2inspect=module,
+                kwargs=module_kwargs,
+            ))
+        elif "falcon-40b" in str(module.__class__).lower():
+            scales_list.append(_auto_get_scale(
+                prev_op=module.ln_attn,
+                layers=[module.self_attention.query_key_value],
+                inp=input_feat['self_attention.query_key_value'],
+                module2inspect=module,
+                kwargs=module_kwargs,
+            ))
+            scales_list.append(_auto_get_scale(
+                prev_op=module.ln_mlp,
+                layers=[module.mlp.dense_h_to_4h],
+                inp=input_feat['mlp.dense_h_to_4h'],
+                module2inspect=module,
+                kwargs=module_kwargs,
+            ))
+        else:
+            raise NotImplementedError("Unknown Falcon architecture, currently only falcon-7b and falcon-40b are supported")
+        # fc2
+        scales_list.append(_auto_get_scale(
+            prev_op=module.mlp.act,
+            layers=[module.mlp.dense_4h_to_h],
+            inp=input_feat['mlp.dense_4h_to_h'],
+        ))
     else:
         raise NotImplementedError(f"{type(module)} not supported yet!")

@@ -214,12 +344,21 @@ def apply_scale(module, scales_list, input_feat_dict=None):
     for prev_op_name, layer_names, scales in scales_list:
         prev_op = get_op_by_name(module, prev_op_name)
         layers = [get_op_by_name(module, name) for name in layer_names]
+
+        prev_op.cuda()
+        for layer in layers:
+            layer.cuda()
+        scales.cuda()
+
         if isinstance(prev_op, nn.Linear):
             assert len(layers) == 1
             scale_fc_fc(prev_op, layers[0], scales)
         elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
             scale_ln_fcs(prev_op, layers, scales)
+        elif isinstance(prev_op, nn.GELU) or isinstance(prev_op, BloomGelu):
+            new_module = ScaledActivation(prev_op, scales)
+            set_op_by_name(module, prev_op_name, new_module)
+            scale_gelu_fc(prev_op, layers[0], scales)
         else:
             raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")

@@ -229,3 +368,8 @@ def apply_scale(module, scales_list, input_feat_dict=None):
             for layer_name in layer_names:
                 inp = input_feat_dict[layer_name]
                 inp.div_(scales.view(1, -1).to(inp.device))
+
+        prev_op.cpu()
+        for layer in layers:
+            layer.cpu()
+        scales.cpu()
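Editor's note: the scale normalization kept in _search_module_scale (the @@ -112 hunk above) can be checked in isolation. The input values below are made up; only the clamp/normalize lines come from the diff.

    import torch

    scales = torch.tensor([0.25, 1.0, 4.0, 16.0]).clamp(min=1e-4).view(-1)   # stand-in per-channel statistic
    scales = scales / (scales.max() * scales.min()).sqrt()
    assert torch.isclose(scales.max() * scales.min(), torch.tensor(1.0))     # geometric "center" of the scales is now 1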
awq/quantize/pre_quant.py

@@ -5,6 +5,7 @@ import gc
 import functools
 from collections import defaultdict

+from transformers.models.bloom.modeling_bloom import BloomForCausalLM
 from transformers.models.opt.modeling_opt import OPTForCausalLM
 from transformers.models.llama.modeling_llama import LlamaForCausalLM

@@ -23,10 +24,32 @@ def get_blocks(model):
         layers = model.model.layers
     elif isinstance(model, OPTForCausalLM):
         layers = model.model.decoder.layers
+    elif isinstance(model, BloomForCausalLM):
+        layers = model.transformer.h
+    elif "mpt" in str(model.__class__).lower():
+        layers = model.transformer.blocks
+    elif "falcon" in str(model.__class__).lower():
+        layers = model.transformer.h
     else:
         raise NotImplementedError(type(model))
     return layers


 def move_embed(model, device):
     if isinstance(model, LlamaForCausalLM):
         model.model.embed_tokens = model.model.embed_tokens.to(device)
     elif isinstance(model, OPTForCausalLM):
         model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
+        model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device)
+    elif isinstance(model, BloomForCausalLM):
+        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
+        model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device)
+    elif "mpt" in str(model.__class__).lower():
+        model.transformer.wte = model.transformer.wte.to(device)
+        model.transformer.emb_drop = model.transformer.emb_drop.to(device)
+    elif "falcon" in str(model.__class__).lower():
+        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
     else:
         raise NotImplementedError(type(model))


 @torch.no_grad()
 def run_awq(

@@ -50,6 +73,9 @@ def run_awq(
     inps = []
     layer_kwargs = {}

+    layers[0] = layers[0].cuda()
+    move_embed(model, "cuda")
+
     # get input and kwargs to layer 0
     # with_kwargs is only supported in PyTorch 2.0
     # use this Catcher hack for now

@@ -69,9 +95,13 @@ def run_awq(
         model(samples.to(next(model.parameters()).device))
     except ValueError:  # work with early exit
         pass
+    del samples
     layers[0] = layers[0].module  # restore
     inps = inps[0]

+    layers[0] = layers[0].cpu()
+    move_embed(model, "cpu")
+
     gc.collect()
     torch.cuda.empty_cache()

@@ -83,6 +113,7 @@ def run_awq(
     # solve layer by layer
     for i in tqdm.tqdm(range(len(layers)), desc="Running AWQ..."):
         layer = layers[i]
+        layer = layer.cuda()
         named_linears = get_named_linears(layer)

         # firstly, get input features of all linear layers

@@ -102,19 +133,25 @@ def run_awq(
         inps = layer(inps, **layer_kwargs)[0]
         for h in handles:
             h.remove()
         # now solve for scaling and clipping
         input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

+        # Clear GPU memory
+        torch.cuda.empty_cache()
+
         if auto_scale:  # if it applies, we should also modify the input_feat with scales
             scales_list = auto_scale_block(
                 layer, layer_kwargs,
                 w_bit=w_bit, q_config=q_config,
                 input_feat=input_feat,
             )
-            apply_scale(layer, scales_list, input_feat_dict=input_feat)
+            # apply_scale(layer, scales_list, input_feat_dict=input_feat)
+            apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
             # append prefix to make names global
             awq_results["scale"] += append_str_prefix(scales_list, get_op_name(model, layer) + ".")

+        # Clear GPU memory
+        torch.cuda.empty_cache()
+
         if mse_range:
             clip_list = auto_clip_block(layer,

@@ -124,6 +161,8 @@ def run_awq(
             # append prefix to make names global
             awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")

+        layer = layer.cpu()
+        # Haotian: check activation replacement
         del input_feat
         gc.collect()
         torch.cuda.empty_cache()
awq/quantize/qmodule.py

@@ -4,6 +4,16 @@ import torch.nn as nn
 import f16s4_gemm  # with CUDA kernels


+class ScaledActivation(nn.Module):
+    def __init__(self, module, scales):
+        super().__init__()
+        self.act = module
+        self.scales = nn.Parameter(scales.data)
+
+    def forward(self, x):
+        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
+
+
 class WQLinear(nn.Module):
     def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
         super().__init__()

@@ -83,3 +93,7 @@ class WQLinear(nn.Module):
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
+
+    def extra_repr(self) -> str:
+        return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
+            self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
+        )
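Editor's note: a standalone check of the ScaledActivation behaviour added above. The class body is copied from the diff so the snippet runs without building the f16s4_gemm extension.

    import torch
    import torch.nn as nn

    class ScaledActivation(nn.Module):          # copied from awq/quantize/qmodule.py above
        def __init__(self, module, scales):
            super().__init__()
            self.act = module
            self.scales = nn.Parameter(scales.data)

        def forward(self, x):
            return self.act(x) / self.scales.view(1, 1, -1).to(x.device)

    act = ScaledActivation(nn.GELU(), torch.full((8,), 2.0))
    x = torch.randn(1, 4, 8)
    assert torch.allclose(act(x), nn.GELU()(x) / 2.0)   # output is the wrapped activation divided per channel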
awq/quantize/quantizer.py

@@ -2,11 +2,48 @@ import torch
 import torch.nn as nn
 from tqdm import tqdm
 import gc
+from .qmodule import ScaledActivation
+from ..utils.module import set_op_by_name
+
+from transformers.models.bloom.modeling_bloom import BloomBlock

 EMBEDDING_KEYWORDS = ["embed"]
 LM_HEAD_KEYWORDS = ["lm_head", "embed_out", "output"]


+def scale_activations(module):
+    param = next(module.parameters())
+    dtype = param.dtype
+    device = param.device
+    if isinstance(module, BloomBlock):
+        if isinstance(module.mlp.gelu_impl, ScaledActivation):
+            return
+        c = module.mlp.dense_h_to_4h.out_features
+        act = ScaledActivation(
+            module.mlp.gelu_impl,
+            torch.ones(c, dtype=dtype, device=device)
+        )
+        set_op_by_name(module, "mlp.gelu_impl", act)
+    elif 'mptblock' in str(module.__class__.__name__).lower():
+        if isinstance(module.ffn.act, ScaledActivation):
+            return
+        c = module.ffn.up_proj.out_features
+        act = ScaledActivation(
+            module.ffn.act,
+            torch.ones(c, dtype=dtype, device=device)
+        )
+        set_op_by_name(module, "ffn.act", act)
+    elif 'falcon' in str(module.__class__).lower():
+        if isinstance(module.mlp.act, ScaledActivation):
+            return
+        c = module.mlp.dense_h_to_4h.out_features
+        act = ScaledActivation(
+            module.mlp.act,
+            torch.ones(c, dtype=dtype, device=device)
+        )
+        set_op_by_name(module, "mlp.act", act)
+
+
 # core quantization method (simulated quantization)
 def pseudo_quantize_tensor(w, n_bit=8,
                            zero_point=True, q_group_size=-1,

@@ -61,7 +98,9 @@ def pseudo_quantize_model_weight(
     for i in tqdm(range(len(layers)), desc="pseudo weight quantization..."):
         named_linears = get_named_linears(layers[i])
         for n, m in named_linears.items():
+            m.cuda()
             m.weight.data = pseudo_quantize_tensor(m.weight.data, n_bit=w_bit, **q_config)
+            m.cpu()


 @torch.no_grad()

@@ -77,29 +116,21 @@ def real_quantize_model_weight(
     for i in tqdm(range(len(layers)), desc="real weight quantization..." + ("(init only)" if init_only else "")):
         layer = layers[i]
         named_linears = get_named_linears(layer)
+        scale_activations(layer)

         for name, module in named_linears.items():
             if init_only:
                 q_linear = WQLinear.from_linear(module, w_bit, q_config['q_group_size'], True)
             else:
+                module.cuda()
                 module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
                 scales = scales.t().contiguous()
                 zeros = zeros.t().contiguous()
                 q_linear = WQLinear.from_linear(module, w_bit, q_config['q_group_size'], False, scales, zeros)
-            levels = name.split('.')
-            if len(levels) > 1:
-                mod_ = layer
-                for l_idx in range(len(levels)-1):
-                    if levels[l_idx].isdigit():
-                        mod_ = mod_[int(levels[l_idx])]
-                    else:
-                        mod_ = getattr(mod_, levels[l_idx])
-                setattr(mod_, levels[-1], q_linear)
-            else:
-                setattr(layer, name, q_linear)
+                module.cpu()
+            q_linear.to(next(layer.parameters()).device)
+            set_op_by_name(layer, name, q_linear)
             torch.cuda.empty_cache()
-            gc.collect()
\ No newline at end of file
+            gc.collect()
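Editor's note: pseudo_quantize_tensor's body is not part of this diff, so the following is only an illustrative re-implementation of asymmetric, group-wise quantize/dequantize with the same kind of signature (n_bit, q_group_size); it is a sketch, not the repository's code.

    import torch

    def pseudo_quantize_sketch(w: torch.Tensor, n_bit: int = 4, q_group_size: int = 128) -> torch.Tensor:
        org_shape = w.shape
        w = w.reshape(-1, q_group_size)                       # one row per group of input channels
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** n_bit - 1
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp(0, max_int)
        w_q = torch.clamp(torch.round(w / scales) + zeros, 0, max_int)
        return ((w_q - zeros) * scales).reshape(org_shape)    # dequantize back ("pseudo" quantization)

    w = torch.randn(64, 256)
    err = (w - pseudo_quantize_sketch(w, n_bit=4, q_group_size=128)).abs().max()
    print(err)   # error stays within roughly half a quantization step per group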
awq/utils/lm_eval_adaptor.py

@@ -47,6 +47,10 @@ class LMEvalAdaptor(BaseLM):
             return 2048
         elif 'llama' in self.model_name:
             return 2048  # TODO: did not check this
+        elif 'mpt' in self.model_name:
+            return 2048
+        elif 'falcon' in self.model_name:
+            return 2048
         else:
             print(self.model.config)
             raise NotImplementedError
awq/utils/module.py

@@ -8,6 +8,20 @@ def get_op_by_name(module, op_name):
     raise ValueError(f"Cannot find op {op_name} in module {module}")


+def set_op_by_name(layer, name, new_module):
+    levels = name.split('.')
+    if len(levels) > 1:
+        mod_ = layer
+        for l_idx in range(len(levels)-1):
+            if levels[l_idx].isdigit():
+                mod_ = mod_[int(levels[l_idx])]
+            else:
+                mod_ = getattr(mod_, levels[l_idx])
+        setattr(mod_, levels[-1], new_module)
+    else:
+        setattr(layer, name, new_module)
+
+
 def get_op_name(module, op):
     # get the name of the op relative to the module
     for name, m in module.named_modules():
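Editor's note: a quick usage check for the set_op_by_name helper added above (assumes the awq package from this repo is importable).

    import torch.nn as nn
    from awq.utils.module import set_op_by_name

    block = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.ReLU(), nn.Linear(4, 4)))
    set_op_by_name(block, "1.1", nn.Identity())   # digit components index into Sequential/ModuleList levels
    assert isinstance(block[1][1], nn.Identity)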
awq/utils/utils.py (new file, mode 0 → 100644; all 43 lines added)

import torch
import accelerate


def get_module_by_name_suffix(model, module_name: str):
    for name, module in model.named_modules():
        if name.endswith(module_name):
            return module


def simple_dispatch_model(model, device_map):
    from accelerate.hooks import add_hook_to_module, AlignDevicesHook

    if "" in device_map:
        d = device_map[""]
        model = model.to(torch.device(d))
        model.hf_device_map = device_map
        return model

    tied_params = accelerate.utils.modeling.find_tied_parameters(model)
    if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
        main_device = "cpu"
    else:
        main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]

    cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"]
    prev_hook = None
    for idx, (n, d) in enumerate(cpu_offload_group):
        m = get_module_by_name_suffix(model, n)
        _, prev_hook = accelerate.cpu_offload_with_hook(m, execution_device=main_device, prev_module_hook=prev_hook)
    # set first cpu offload module's prev_module_hook to the last cpu offload module's hook
    if len(cpu_offload_group) > 1:
        get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = prev_hook

    for n, d in device_map.items():
        m = get_module_by_name_suffix(model, n)
        if d != "cpu":
            d = torch.device(d)
            hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
            add_hook_to_module(m, hook)
    accelerate.utils.modeling.retie_parameters(model, tied_params)
    model.hf_device_map = device_map

    return model
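Editor's note: hypothetical usage of the new simple_dispatch_model helper. With a single "" entry the function takes the early-return path and just moves the whole model; accelerate must be installed and the awq package importable.

    import torch.nn as nn
    from awq.utils.utils import simple_dispatch_model

    model = nn.ModuleDict({"embed": nn.Embedding(10, 8), "head": nn.Linear(8, 10)})
    device_map = {"": "cpu"}                      # device_map keys are module-name suffixes
    model = simple_dispatch_model(model, device_map)
    assert model.hf_device_map == device_map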