AutoAWQ commit 71d8e68d
Authored Jun 25, 2023 by Haotian Tang
Parent: 06e299ba

[Major] Add support for BLOOM, MPT and Falcon.

Showing 9 changed files with 433 additions and 43 deletions (+433 / -43)
awq/entry.py                    +7    -5
awq/kernels/gemm_cuda_gen.cu    +232  -15
awq/quantize/auto_clip.py       +1    -1
awq/quantize/auto_scale.py      +114  -5
awq/quantize/pre_quant.py       +10   -2
awq/quantize/qmodule.py         +10   -0
awq/quantize/quantizer.py       +41   -15
awq/utils/lm_eval_adaptor.py    +4    -0
awq/utils/module.py             +14   -0
awq/entry.py

@@ -62,15 +62,18 @@ def build_model_and_enc(model_path):
     print(f"* Building model {model_path}")

     # all hf model
-    config = AutoConfig.from_pretrained(model_path)
-    enc = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+    if "mpt" in config.__class__.__name__.lower():
+        enc = AutoTokenizer.from_pretrained(config.tokenizer_name)
+    else:
+        enc = AutoTokenizer.from_pretrained(model_path, use_fast=False)

     if args.load_quant:  # directly load quantized weights
         # no need to really load the fp16 weights... just to get the model structure
         print("Loading pre-computed quantized weights...")
         with init_empty_weights():
             model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
-                                                         torch_dtype=torch.float16)
+                                                         torch_dtype=torch.float16, trust_remote_code=True)
         real_quantize_model_weight(
             model, w_bit=args.w_bit, q_config=q_config, init_only=True)

         model = load_checkpoint_and_dispatch(
@@ -83,8 +86,7 @@ def build_model_and_enc(model_path):
         kwargs = {"device_map": "balanced", "torch_dtype": torch.float16}
         model = AutoModelForCausalLM.from_pretrained(
-            model_path, config=config, **kwargs)
+            model_path, config=config, trust_remote_code=True, **kwargs)
         if args.run_awq:
             awq_results = run_awq(
                 model, enc,
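For orientation, the loading path these hunks produce can be summarized as below. This is a minimal sketch of the logic, not a verbatim copy of build_model_and_enc; the checkpoint path is a placeholder and error handling is omitted. MPT and Falcon ship their modeling code with the checkpoint, which is why trust_remote_code=True is now passed, and MPT configs name a separate tokenizer repo (config.tokenizer_name), which is why the tokenizer gets its own branch.

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

    model_path = "mosaicml/mpt-7b"  # placeholder; any HF checkpoint path works

    # custom architectures need trust_remote_code for both config and model
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

    # MPT points at an external tokenizer repo; other models load their own slow tokenizer
    if "mpt" in config.__class__.__name__.lower():
        enc = AutoTokenizer.from_pretrained(config.tokenizer_name)
    else:
        enc = AutoTokenizer.from_pretrained(model_path, use_fast=False)

    model = AutoModelForCausalLM.from_pretrained(
        model_path, config=config, torch_dtype=torch.float16, trust_remote_code=True)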
awq/kernels/gemm_cuda_gen.cu

@@ -13,7 +13,7 @@ __pack_half2(const half x, const half y) {
   return (v1 << 16) | v0;
 }

-__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
+__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
 {
   static constexpr uint32_t ZERO = 0x0;
   float C_warp[32];
@@ -24,7 +24,6 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
   __shared__ half zeros_shared[128];

   int j_factors1 = ((OC + 128 - 1) / 128);
-
   int blockIdx_x = 0;
   int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
   int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
@@ -53,6 +52,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
             + (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
             + (((int)blockIdx_y) % j_factors1) * (128 / 8)
             + (((int)threadIdx.x) % (128 / 8)) * 1;
+// Why * 1 in the above line?

   half* A_shared_ptr = A_shared
                      + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
@@ -80,7 +80,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
   // preload s.f. and zeros
   int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
-  if ((k_bound - 1) * 32 + blockIdx_z >= IC) k_bound -= 1;
+  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
   for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
     int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
     __syncthreads();
@@ -95,9 +95,9 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
     }
     // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
-    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / 128 * (OC / 8));
+    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
     uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
-    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / 128 * (OC));
+    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
     /*
     if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
       printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
@@ -107,6 +107,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
     int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);

     for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
+      // TODO: Shang: double check how to get 8.
       // B: 32 x 136 (128+8) float16
       // each warp: 32 x 4
@@ -205,6 +206,204 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
   }
 }

+__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
+{
+  static constexpr uint32_t ZERO = 0x0;
+  float C_warp[32];
+  __shared__ half A_shared[16 * (32 + 8)];
+  __shared__ half B_shared[32 * (64 + 8)];
+
+  __shared__ half scaling_factors_shared[64];
+  __shared__ half zeros_shared[64];
+
+  int j_factors1 = ((OC + 64 - 1) / 64);
+
+  int blockIdx_x = 0;
+  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
+  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
+
+  half A_shared_warp[8];
+  half B_shared_warp[16];
+  for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
+    for (int i = 0; i < 8; ++i) {
+      C_warp[(j_0_4_init * 8) + i] = 0.0;
+    }
+  }
+
+  static constexpr int row_stride_warp = 32 * 8 / 32;
+  static constexpr int row_stride = 2 * 32 * 8 / 64;
+  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
+  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;   // threadIdx.y is warp_id
+  // bool wb_C_flag = (threadIdx.x / 4) < M;
+
+  half* A_ptr = A
+                + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+                + (((int)threadIdx.x) % (32 / 8)) * 8;
+
+  int* B_ptr = B
+            + ((int)threadIdx.y) * (OC / 8) * 4
+            + (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
+            + (((int)blockIdx_y) % j_factors1) * (64 / 8)
+            + (((int)threadIdx.x) % (64 / 8)) * 1;
+// Why * 1 in the above line?
+
+  half* A_shared_ptr = A_shared
+                    + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+                    + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+                    + (((int)threadIdx.x) % (32 / 8)) * 8;
+
+  half* B_shared_ptr = B_shared
+                    + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
+                    + (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
+                    + (((int)threadIdx.x) % (64 / 8)) * 8;
+
+  int* zeros_ptr = zeros
+                + (((int)blockIdx_y) % j_factors1) * (64 / 8)
+                + ((int)threadIdx.x) % (64 / 8);
+
+  half* scaling_factors_ptr = scaling_factors
+                            + (((int)blockIdx_y) % j_factors1) * (64)
+                            + (((int)threadIdx.x) % (64 / 8)) * 8;
+
+  half* C_ptr = C
+              + blockIdx_z * M * OC        // blockIdz.x -> split_k dim
+              + (((int)blockIdx_y) % j_factors1) * 64
+              + ((int)threadIdx.y) * 32
+              + (((int)threadIdx.x) % 4) * 2;
+
+  // preload s.f. and zeros
+  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
+  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
+  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
+    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
+    __syncthreads();
+    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+    if (ld_A_flag)
+    {
+      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
+    }
+    else
+    {
+      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
+    }
+
+    // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
+    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
+    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
+    /*
+    if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
+      printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
+    }
+    */
+    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
+    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
+
+    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
+
+      // B: 32 x 136 (128+8) float16
+      // each warp: 32 x 4
+      // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
+      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
+      // row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
+      uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
+      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+      //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
+      // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
+      // - zero and * scale
+      // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
+      /*
+      if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
+        printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
+      }
+      */
+
+      // write back
+      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
+    }
+    __syncthreads();
+
+    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
+      {
+        unsigned int addr;
+        __asm__ __volatile__(
+          "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+          : "=r"(addr)
+          : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
+        );
+        __asm__ __volatile__(
+          "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
+          "{%0, %1, %2, %3}, [%4];\n"
+          : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
+          : "r"(addr)
+        );
+      }
+
+      for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) {
+        {
+          unsigned int addr;
+          __asm__ __volatile__(
+            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+            : "=r"(addr)
+            : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
+          );
+          __asm__ __volatile__(
+            "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
+            "{%0, %1, %2, %3}, [%4];\n"
+            : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
+            : "r"(addr)
+          );
+        }
+      }
+      for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) {
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+            :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+            :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+        }
+      }
+    }
+  }
+
+// TODO: Shang: Hoist loop invariance.
+  for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
+    for (int local_id = 0; local_id < 8; ++local_id) {
+      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
+      if (row_offset < M)
+      {
+        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
+      }
+    }
+  }
+}

 // in_feats: M, IC [float16]
 // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
 // scaling_factors: IC // G, OC [float16]
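The functional change common to both kernels is that scales and zero points are now indexed by a runtime group size G instead of the hard-coded 128: each block of G input channels shares one scale and one zero per output channel, which is the indexing the kernels express as k_0_0 * 32 / G. A minimal Python sketch of that per-group dequantization, with toy values (names are illustrative, not the kernel's):

    def dequantize_column(q, scales, zeros, G):
        # q: integer weights for one output channel, length IC
        # scales/zeros: one entry per group of G input channels, length IC // G
        return [(q[k] - zeros[k // G]) * scales[k // G] for k in range(len(q))]

    # group size G = 4 -> two groups of scales/zeros for 8 input channels
    print(dequantize_column([3, 7, 0, 15, 8, 8, 8, 8],
                            scales=[0.10, 0.02], zeros=[8, 8], G=4))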
@@ -232,20 +431,38 @@ torch::Tensor gemm_forward_cuda(
     auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
     auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
     auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
+    int group_size = num_in_channels / _scaling_factors.size(0);

-    if (num_out_channels % 64 != 0)
-        throw std::invalid_argument("OC is not multiple of cta_N = 64");
+    if (num_out_channels % 128 != 0)
+        throw std::invalid_argument("OC is not multiple of cta_N = 128");
     if (num_out_channels % 8 != 0)
         throw std::invalid_argument("OC is not multiple of pack_num = 8");
-    int j_factors1 = num_out_channels / 128 / 1;
-    dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
-    // threadIdx.x: 32
-    // threadIdx.y: i_factors[2] * j_factors[2]
-    dim3 threads_per_block(32, 2);
-    gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
-        split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
+    if (group_size % 32 != 0)
+        throw std::invalid_argument("Group size should be a multiple of 32");
+    if (num_out_channels % group_size != 0)
+        throw std::invalid_argument("OC is not multiple of Group size");
+
+    if (num_out_channels % 128 == 0)
+    {
+        int j_factors1 = num_out_channels / 128 / 1;
+        dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
+        // threadIdx.x: 32
+        // threadIdx.y: i_factors[2] * j_factors[2]
+        dim3 threads_per_block(32, 2);
+        gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
+            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
+    }
+    else if (num_out_channels % 64 == 0)
+    {
+        int j_factors1 = num_out_channels / 64 / 1;
+        dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
+        // threadIdx.x: 32
+        // threadIdx.y: i_factors[2] * j_factors[2]
+        dim3 threads_per_block(32, 2);
+        gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
+    }
     return _out_feats.sum(0);
 }
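The host wrapper now derives the group size from the scaling-factor layout (num_in_channels / _scaling_factors.size(0), since scales are stored as [IC // G, OC]) and chooses a kernel by output width: the original 128-wide CTA when OC is a multiple of 128, otherwise the new 64-wide CTA, presumably to cover hidden sizes such as Falcon's 4544 that are multiples of 64 but not 128. A small Python sketch of that selection, with hypothetical names; it mirrors the group-size checks, while the OC guard in the C++ above is stricter in this commit:

    def pick_kernel(num_out_channels, num_in_channels, num_scale_rows):
        # group size is implied by the scaling-factor shape [IC // G, OC]
        group_size = num_in_channels // num_scale_rows
        if group_size % 32 != 0:
            raise ValueError("Group size should be a multiple of 32")
        if num_out_channels % group_size != 0:
            raise ValueError("OC is not multiple of Group size")
        if num_out_channels % 128 == 0:
            return "gemm_forward_4bit_cuda_m16n128k32", group_size
        if num_out_channels % 64 == 0:
            return "gemm_forward_4bit_cuda_m16n64k32", group_size
        raise ValueError("OC must be a multiple of 64")

    print(pick_kernel(4544, 4544, 4544 // 64))    # Falcon-like OC  -> 64-wide kernel
    print(pick_kernel(4096, 4096, 4096 // 128))   # BLOOM/MPT-like OC -> 128-wide kernel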
awq/quantize/auto_clip.py

@@ -22,7 +22,7 @@ def auto_clip_layer(w, input_feat, n_bit, q_config,
     input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
     w = w.reshape(w.shape[0], 1, -1, group_size)

-    oc_batch_size = 256  # prevent OOM
+    oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64  # prevent OOM
     assert w.shape[0] % oc_batch_size == 0
     w_all = w
     best_max_val_all = []
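The clipping search walks the weight in batches along the output-channel axis, and 256 no longer divides the output dimension of every layer in the newly supported models, hence the fallback to 64. A quick check with illustrative widths (4544 is Falcon-7B-like, used here only as an example):

    for oc in (4096, 4544, 18176):
        oc_batch_size = 256 if oc % 256 == 0 else 64
        assert oc % oc_batch_size == 0       # mirrors the assert in auto_clip_layer
        print(oc, "->", oc_batch_size)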
awq/quantize/auto_scale.py

 import torch
 import torch.nn as nn

+from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
 from transformers.models.opt.modeling_opt import OPTDecoderLayer
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm

-from ..utils.module import get_op_by_name, get_op_name
+from .qmodule import ScaledActivation
+from ..utils.module import get_op_by_name, get_op_name, set_op_by_name

 __all__ = ["auto_scale_block", "apply_scale"]
@@ -32,6 +34,13 @@ def scale_ln_fcs(ln, fcs, scales):
     scales = scales.to(ln.weight.device)

+    # debugging start even scales = 1 does not work?
+    """
+    scales = scales * 0
+    scales = scales + 1
+    """
+    # debugging end
+
     ln.weight.div_(scales)
     if hasattr(ln, 'bias') and ln.bias is not None:
         ln.bias.div_(scales)
@@ -50,11 +59,12 @@ def scale_ln_fcs(ln, fcs, scales):
 def scale_fc_fc(fc1, fc2, scales):
     assert isinstance(fc1, nn.Linear)
     assert isinstance(fc2, nn.Linear)
-    assert fc1.out_features == fc2.in_features
+    # assert fc1.out_features == fc2.in_features

     scales = scales.to(fc1.weight.device)

-    fc1.weight.div_(scales.view(-1, 1))
+    # fc1.weight.div_(scales.view(-1, 1))
+    fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
     if fc1.bias is not None:
         fc1.bias.div_(scales.view(-1))
@@ -66,6 +76,17 @@ def scale_fc_fc(fc1, fc2, scales):
         assert torch.isnan(p).sum() == 0


+@torch.no_grad()
+def scale_gelu_fc(gelu, fc, scales):
+    assert isinstance(gelu, nn.GELU) or isinstance(gelu, BloomGelu)
+    assert isinstance(fc, nn.Linear)
+
+    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
+
+    for p in fc.parameters():
+        assert torch.isnan(p).sum() == 0
+
+
 @torch.no_grad()
 def auto_scale_block(module, module_kwargs,
                      w_bit, q_config,
@@ -112,7 +133,7 @@ def auto_scale_block(module, module_kwargs,
         ).clamp(min=1e-4).view(-1)
         scales = scales / (scales.max() * scales.min()).sqrt()
         for fc in linears2scale:
-            fc.weight.mul_(scales.view(1, -1))
+            fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
             fc.weight.data = w_quantize_func(
                 fc.weight.data) / (scales.view(1, -1))
         out = block(x, **kwargs)
@@ -204,7 +225,91 @@ def auto_scale_block(module, module_kwargs,
             layers=[module.mlp.down_proj],
             inp=input_feat['mlp.down_proj'],
         ))
+    elif isinstance(module, BloomBlock):
+        # attention input
+        scales_list.append(_auto_get_scale(
+            prev_op=module.input_layernorm,
+            layers=[module.self_attention.query_key_value],
+            inp=input_feat['self_attention.query_key_value'],
+            module2inspect=module, kwargs=module_kwargs,
+        ))
+        # attn out
+        """
+        scales_list.append(_auto_get_scale(
+            prev_op=module.self_attention.query_key_value,
+            layers=[module.self_attention.dense],
+            inp=input_feat['self_attn.dense'],
+        ))
+        """
+        # fc1
+        scales_list.append(_auto_get_scale(
+            prev_op=module.post_attention_layernorm,
+            layers=[module.mlp.dense_h_to_4h],
+            inp=input_feat['mlp.dense_h_to_4h'],
+            module2inspect=module, kwargs=module_kwargs,
+        ))
+        # fc2
+        scales_list.append(_auto_get_scale(
+            prev_op=module.mlp.gelu_impl,
+            layers=[module.mlp.dense_4h_to_h],
+            inp=input_feat['mlp.dense_4h_to_h'],
+        ))
+    elif "mpt" in str(module.__class__).lower():
+        # attention input
+        scales_list.append(_auto_get_scale(
+            prev_op=module.norm_1,
+            layers=[module.attn.Wqkv],
+            inp=input_feat['attn.Wqkv'],
+            module2inspect=module.attn,
+            kwargs=module_kwargs,
+        ))
+        # attn out
+        scales_list.append(_auto_get_scale(
+            prev_op=module.attn.Wqkv,
+            layers=[module.attn.out_proj],
+            inp=input_feat['attn.out_proj'],
+        ))
+        # fc1
+        scales_list.append(_auto_get_scale(
+            prev_op=module.norm_2,
+            layers=[module.ffn.up_proj],
+            inp=input_feat['ffn.up_proj'],
+            module2inspect=module.ffn,
+        ))
+        # fc2
+        scales_list.append(_auto_get_scale(
+            prev_op=module.ffn.act,
+            layers=[module.ffn.down_proj],
+            inp=input_feat['ffn.down_proj'],
+        ))
+    elif "falcon" in str(module.__class__).lower():
+        # attn out
+        # Haotian: TBD: need to handle repeated scales for MQ
+        """
+        scales_list.append(_auto_get_scale(
+            prev_op=module.self_attention.query_key_value,
+            layers=[module.self_attention.dense],
+            inp=input_feat['self_attention.dense'],
+        ))
+        """
+        # fc1, as long as it is scaled, everything is screwed up
+        scales_list.append(_auto_get_scale(
+            prev_op=module.input_layernorm,
+            layers=[module.mlp.dense_h_to_4h, module.self_attention.query_key_value],
+            inp=input_feat['self_attention.query_key_value'],
+            module2inspect=module,
+            kwargs=module_kwargs,
+        ))
+        # fc2
+        scales_list.append(_auto_get_scale(
+            prev_op=module.mlp.act,
+            layers=[module.mlp.dense_4h_to_h],
+            inp=input_feat['mlp.dense_4h_to_h'],
+        ))
     else:
         raise NotImplementedError(f"{type(module)} not supported yet!")
@@ -220,6 +325,10 @@ def apply_scale(module, scales_list, input_feat_dict=None):
             scale_fc_fc(prev_op, layers[0], scales)
         elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
             scale_ln_fcs(prev_op, layers, scales)
+        elif isinstance(prev_op, nn.GELU) or isinstance(prev_op, BloomGelu):
+            new_module = ScaledActivation(prev_op, scales)
+            set_op_by_name(module, prev_op_name, new_module)
+            scale_gelu_fc(prev_op, layers[0], scales)
         else:
             raise NotImplementedError(
                 f"prev_op {type(prev_op)} not supported yet!")
@@ -228,4 +337,4 @@ def apply_scale(module, scales_list, input_feat_dict=None):
         if input_feat_dict is not None:
             for layer_name in layer_names:
                 inp = input_feat_dict[layer_name]
                 inp.div_(scales.view(1, -1).to(inp.device))
\ No newline at end of file
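The new GELU handling works in two halves: apply_scale replaces the activation with a ScaledActivation that divides its output by the scales, and scale_gelu_fc folds the same scales into the weight of the following linear layer, so the block's output is mathematically unchanged while the linear layer sees rescaled inputs. A self-contained sketch of that identity with toy shapes (standalone nn modules rather than a real BLOOM/MPT block):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    act, fc = nn.GELU(), nn.Linear(8, 4, bias=False)
    x = torch.randn(1, 3, 8)
    ref = fc(act(x))

    scales = torch.rand(8) + 0.5
    fc.weight.data.mul_(scales.view(1, -1))     # scale_gelu_fc: fold scales into the next linear
    out = fc(act(x) / scales)                   # ScaledActivation: divide the activation output
    print(torch.allclose(out, ref, atol=1e-6))  # True: the composition is unchanged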
awq/quantize/pre_quant.py

@@ -5,6 +5,7 @@ import gc
 import functools
 from collections import defaultdict

+from transformers.models.bloom.modeling_bloom import BloomForCausalLM
 from transformers.models.opt.modeling_opt import OPTForCausalLM
 from transformers.models.llama.modeling_llama import LlamaForCausalLM
@@ -23,6 +24,12 @@ def get_blocks(model):
         layers = model.model.layers
     elif isinstance(model, OPTForCausalLM):
         layers = model.model.decoder.layers
+    elif isinstance(model, BloomForCausalLM):
+        layers = model.transformer.h
+    elif "mpt" in str(model.__class__).lower():
+        layers = model.transformer.blocks
+    elif "falcon" in str(model.__class__).lower():
+        layers = model.transformer.h
     else:
         raise NotImplementedError(type(model))
     return layers
@@ -102,7 +109,6 @@ def run_awq(
     inps = layer(inps, **layer_kwargs)[0]
     for h in handles:
         h.remove()
-
     # now solve for scaling and clipping
     input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
@@ -112,7 +118,8 @@ def run_awq(
             w_bit=w_bit, q_config=q_config,
             input_feat=input_feat,
         )
-        apply_scale(layer, scales_list, input_feat_dict=input_feat)
+        # apply_scale(layer, scales_list, input_feat_dict=input_feat)
+        apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
         # append prefix to make names global
         awq_results["scale"] += append_str_prefix(scales_list, get_op_name(model, layer) + ".")
@@ -124,6 +131,7 @@ def run_awq(
         # append prefix to make names global
         awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")

+    # Haotian: check activation replacement
     del input_feat
     gc.collect()
     torch.cuda.empty_cache()
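get_blocks now reaches the decoder layers through each architecture's own attribute path: BLOOM and Falcon keep them under model.transformer.h, MPT under model.transformer.blocks. A minimal sketch of the same dispatch with stand-in classes (the class names below are invented for illustration, not the real transformers classes):

    class FalconForCausalLM:
        def __init__(self):
            self.transformer = type("T", (), {"h": ["block0", "block1"]})()

    class MPTForCausalLM:
        def __init__(self):
            self.transformer = type("T", (), {"blocks": ["block0"]})()

    def get_blocks_sketch(model):
        # mirrors the branches added to get_blocks()
        cls = str(model.__class__).lower()
        if "bloom" in cls or "falcon" in cls:
            return model.transformer.h
        if "mpt" in cls:
            return model.transformer.blocks
        raise NotImplementedError(type(model))

    print(get_blocks_sketch(FalconForCausalLM()))   # ['block0', 'block1']
    print(get_blocks_sketch(MPTForCausalLM()))      # ['block0']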
awq/quantize/qmodule.py

@@ -4,6 +4,16 @@ import torch.nn as nn
 import f16s4_gemm  # with CUDA kernels


+class ScaledActivation(nn.Module):
+    def __init__(self, module, scales):
+        super().__init__()
+        self.act = module
+        self.scales = nn.Parameter(scales.data)
+
+    def forward(self, x):
+        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
+
+
 class WQLinear(nn.Module):
     def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
         super().__init__()
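ScaledActivation divides whatever the wrapped activation produces by a learnable per-channel scale. When quantizer.scale_activations wraps a block's activation with torch.ones, the wrapper starts as an exact no-op, and because the scales are an nn.Parameter they travel with the state dict, so a pre-quantized checkpoint restores the real values later. A quick check of both properties, using a minimal local copy of the class (the real one lives in this file, which also imports the CUDA extension and so may not import on a CPU-only box):

    import torch
    import torch.nn as nn

    class ScaledActivation(nn.Module):
        # minimal stand-in mirroring the class added above
        def __init__(self, module, scales):
            super().__init__()
            self.act = module
            self.scales = nn.Parameter(scales.data)

        def forward(self, x):
            return self.act(x) / self.scales.view(1, 1, -1).to(x.device)

    act = ScaledActivation(nn.GELU(), torch.ones(16))
    x = torch.randn(2, 3, 16)
    print(torch.equal(act(x), nn.GELU()(x)))   # True: identity scales are a no-op
    print(list(act.state_dict().keys()))       # ['scales'] -> saved and loaded with the model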
awq/quantize/quantizer.py

@@ -2,11 +2,48 @@ import torch
 import torch.nn as nn
 from tqdm import tqdm
 import gc
+from .qmodule import ScaledActivation
+from ..utils.module import set_op_by_name

+from transformers.models.bloom.modeling_bloom import BloomBlock

 EMBEDDING_KEYWORDS = ["embed"]
 LM_HEAD_KEYWORDS = ["lm_head", "embed_out", "output"]


+def scale_activations(module):
+    param = next(module.parameters())
+    dtype = param.dtype
+    device = param.device
+    if isinstance(module, BloomBlock):
+        if isinstance(module.mlp.gelu_impl, ScaledActivation):
+            return
+        c = module.mlp.dense_h_to_4h.out_features
+        act = ScaledActivation(
+            module.mlp.gelu_impl,
+            torch.ones(c, dtype=dtype, device=device)
+        )
+        set_op_by_name(module, "mlp.gelu_impl", act)
+    elif 'mptblock' in str(module.__class__.__name__).lower():
+        if isinstance(module.ffn.act, ScaledActivation):
+            return
+        c = module.ffn.up_proj.out_features
+        act = ScaledActivation(
+            module.ffn.act,
+            torch.ones(c, dtype=dtype, device=device)
+        )
+        set_op_by_name(module, "ffn.act", act)
+    elif 'falcon' in str(module.__class__).lower():
+        if isinstance(module.mlp.act, ScaledActivation):
+            return
+        c = module.mlp.dense_h_to_4h.out_features
+        act = ScaledActivation(
+            module.mlp.act,
+            torch.ones(c, dtype=dtype, device=device)
+        )
+        set_op_by_name(module, "mlp.act", act)
+
+
 # core quantization method (simulated quantization)
 def pseudo_quantize_tensor(w, n_bit=8,
                            zero_point=True, q_group_size=-1,
@@ -77,7 +114,8 @@ def real_quantize_model_weight(
     for i in tqdm(range(len(layers)), desc="real weight quantization..." + ("(init only)" if init_only else "")):
         layer = layers[i]
         named_linears = get_named_linears(layer)
+        scale_activations(layer)
         for name, module in named_linears.items():
             if init_only:
                 q_linear = WQLinear.from_linear(
@@ -88,18 +126,6 @@ def real_quantize_model_weight(
                 zeros = zeros.t().contiguous()
                 q_linear = WQLinear.from_linear(
                     module, w_bit, q_config['q_group_size'], False, scales, zeros)
-                levels = name.split('.')
-                if len(levels) > 1:
-                    mod_ = layer
-                    for l_idx in range(len(levels)-1):
-                        if levels[l_idx].isdigit():
-                            mod_ = mod_[int(levels[l_idx])]
-                        else:
-                            mod_ = getattr(mod_, levels[l_idx])
-                    setattr(mod_, levels[-1], q_linear)
-                else:
-                    setattr(layer, name, q_linear)
+                set_op_by_name(layer, name, q_linear)
                 torch.cuda.empty_cache()
                 gc.collect()
\ No newline at end of file
awq/utils/lm_eval_adaptor.py

@@ -47,6 +47,10 @@ class LMEvalAdaptor(BaseLM):
             return 2048
         elif 'llama' in self.model_name:
             return 2048  # TODO: did not check this
+        elif 'mpt' in self.model_name:
+            return 2048
+        elif 'falcon' in self.model_name:
+            return 2048
         else:
             print(self.model.config)
             raise NotImplementedError
awq/utils/module.py

@@ -8,6 +8,20 @@ def get_op_by_name(module, op_name):
     raise ValueError(f"Cannot find op {op_name} in module {module}")


+def set_op_by_name(layer, name, new_module):
+    levels = name.split('.')
+    if len(levels) > 1:
+        mod_ = layer
+        for l_idx in range(len(levels)-1):
+            if levels[l_idx].isdigit():
+                mod_ = mod_[int(levels[l_idx])]
+            else:
+                mod_ = getattr(mod_, levels[l_idx])
+        setattr(mod_, levels[-1], new_module)
+    else:
+        setattr(layer, name, new_module)
+
+
 def get_op_name(module, op):
     # get the name of the op relative to the module
     for name, m in module.named_modules():
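set_op_by_name centralizes the dotted-path module replacement that quantizer.py previously inlined: it walks child modules by name, indexing into containers when a path component is numeric, and swaps in the new module at the leaf. A standalone usage sketch (the helper body is copied here only so the snippet runs on its own):

    import torch.nn as nn

    def set_op_by_name(layer, name, new_module):
        # same logic as the helper added above
        levels = name.split('.')
        if len(levels) > 1:
            mod_ = layer
            for l_idx in range(len(levels) - 1):
                if levels[l_idx].isdigit():
                    mod_ = mod_[int(levels[l_idx])]
                else:
                    mod_ = getattr(mod_, levels[l_idx])
            setattr(mod_, levels[-1], new_module)
        else:
            setattr(layer, name, new_module)

    block = nn.Module()
    block.layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])
    set_op_by_name(block, "layers.1", nn.Identity())   # "1" indexes into the ModuleList
    print(block.layers[1])                              # Identity()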