ModelZoo / Yuan2.0-M32_pytorch · Commits

Commit 6a583c2f
Authored Aug 21, 2024 by chenych
Parent: 7d576a9a

    update dtk to 24.04.1 and modify README

Changes: 329
Showing 20 changed files with 5217 additions and 0 deletions (+5217, -0)
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/q_gemm.cuh  (+33, -0)
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/q_gemm_kernel.cuh  (+487, -0)
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/q_gemm_kernel_gptq.cuh  (+223, -0)
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/q_matrix.cu  (+627, -0)
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/q_matrix.cuh  (+73, -0)
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/quant/qdq_2.cuh  (+103, -0)
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/quant/qdq_3.cuh  (+169, -0)
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/quant/qdq_4.cuh  (+227, -0)
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/quant/qdq_5.cuh  (+207, -0)
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/quant/qdq_6.cuh  (+44, -0)
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/quant/qdq_8.cuh  (+38, -0)
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/quant/qdq_util.cuh  (+51, -0)
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/util.cuh  (+42, -0)
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/ext.cpp  (+134, -0)
3rd_party/AutoGPTQ/autogptq_extension/marlin/marlin_cuda.cpp  (+80, -0)
3rd_party/AutoGPTQ/autogptq_extension/marlin/marlin_cuda_kernel.cu  (+855, -0)
3rd_party/AutoGPTQ/autogptq_extension/marlin/marlin_cuda_kernel.cuh  (+20, -0)
3rd_party/AutoGPTQ/autogptq_extension/marlin/marlin_repack.cu  (+93, -0)
3rd_party/AutoGPTQ/autogptq_extension/marlin/marlin_repack.cuh  (+12, -0)
3rd_party/AutoGPTQ/autogptq_extension/qigen/generate.py  (+1699, -0)
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/q_gemm.cuh  (new file, mode 100644)

#ifndef _q_gemm_cuh
#define _q_gemm_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>

#include "q_matrix.cuh"

void gemm_half_q_half_cuda
(
    cublasHandle_t cublas_handle,
    const half* a,
    QMatrix* b,
    half* c,
    int size_m,
    int size_n,
    int size_k,
    bool clear = false,
    half* reconstruct = NULL,
    bool force_cuda = false
);

void clear_tensor_cuda
(
    half* c,
    int size_m,
    int size_n
);

#endif
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/q_gemm_kernel.cuh  (new file, mode 100644)
#include "compat.cuh"
#include <cuda_runtime.h>
#include <cuda_fp16.h>
__forceinline__
__device__
half2
dot22_8
(
half2
(
&
dq
)[
4
],
const
half
*
a_ptr
,
const
half2
g_result
,
const
half
qs_h
)
{
half2
result
=
{};
const
half2
*
a2_ptr
=
(
const
half2
*
)
a_ptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
result
=
__hfma2
(
dq
[
i
],
*
a2_ptr
++
,
result
);
return
__hfma2
(
result
,
__halves2half2
(
qs_h
,
qs_h
),
g_result
);
}
__forceinline__
__device__
half2
dot22_16
(
half2
(
&
dq
)[
8
],
const
half
*
a_ptr
,
const
half2
g_result
,
const
half
qs_h
)
{
half2
result
=
{};
const
half2
*
a2_ptr
=
(
const
half2
*
)
a_ptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
result
=
__hfma2
(
dq
[
i
],
*
a2_ptr
++
,
result
);
return
__hfma2
(
result
,
__halves2half2
(
qs_h
,
qs_h
),
g_result
);
}
__forceinline__
__device__
half2
dot22_32
(
half2
(
&
dq
)[
16
],
const
half
*
a_ptr
,
const
half2
g_result
,
const
half
qs_h
)
{
half2
result
=
{};
const
half2
*
a2_ptr
=
(
const
half2
*
)
a_ptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
16
;
i
+=
1
)
result
=
__hfma2
(
dq
[
i
],
*
a2_ptr
++
,
result
);
return
__hfma2
(
result
,
__halves2half2
(
qs_h
,
qs_h
),
g_result
);
}
__forceinline__
__device__
float
dot22_8_f
(
half2
(
&
dq
)[
4
],
const
half
*
a_ptr
,
const
float
g_result
,
const
float
qs_f
)
{
half2
result
=
{};
const
half2
*
a2_ptr
=
(
const
half2
*
)
a_ptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
result
=
__hfma2
(
dq
[
i
],
*
a2_ptr
++
,
result
);
float
result_f
=
__half2float
(
__low2half
(
result
))
+
__half2float
(
__high2half
(
result
));
return
fma
(
result_f
,
qs_f
,
g_result
);
}
__forceinline__
__device__
float
dot22_16_f
(
half2
(
&
dq
)[
8
],
const
half
*
a_ptr
,
const
float
g_result
,
const
float
qs_f
)
{
half2
result
=
{};
const
half2
*
a2_ptr
=
(
const
half2
*
)
a_ptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
result
=
__hfma2
(
dq
[
i
],
*
a2_ptr
++
,
result
);
float
result_f
=
__half2float
(
__low2half
(
result
))
+
__half2float
(
__high2half
(
result
));
return
fma
(
result_f
,
qs_f
,
g_result
);
}
__forceinline__
__device__
float
dot22_32_f
(
half2
(
&
dq
)[
16
],
const
half
*
a_ptr
,
const
float
g_result
,
const
float
qs_f
)
{
half2
result
=
{};
const
half2
*
a2_ptr
=
(
const
half2
*
)
a_ptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
16
;
i
+=
1
)
result
=
__hfma2
(
dq
[
i
],
*
a2_ptr
++
,
result
);
float
result_f
=
__half2float
(
__low2half
(
result
))
+
__half2float
(
__high2half
(
result
));
return
fma
(
result_f
,
qs_f
,
g_result
);
}
typedef
void
(
*
fp_gemm_half_q_half_kernel
)
(
const
half
*
,
const
uint32_t
*
,
const
uint32_t
*
,
const
half
*
,
half
*
,
const
int
,
const
int
,
const
int
,
const
int
,
const
int
,
const
uint16_t
*
,
const
int
,
const
int
,
const
int
,
const
int
,
const
int
,
const
int
,
const
bool
);
template
<
bool
first_block
,
int
m_count
>
__global__
void
gemm_half_q_half_kernel
(
const
half
*
__restrict__
a
,
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint32_t
*
__restrict__
b_q_scale
,
const
half
*
__restrict__
b_q_scale_max
,
half
*
__restrict__
c
,
const
int
size_m
,
const
int
size_n
,
const
int
size_k
,
const
int
groups
,
const
int
groupsize
,
const
uint16_t
*
__restrict__
b_q_perm
,
const
int
rows_8
,
const
int
rows_6
,
const
int
rows_5
,
const
int
rows_4
,
const
int
rows_3
,
const
int
rows_2
,
const
bool
clear
)
{
MatrixView_half
a_
(
a
,
size_m
,
size_k
);
MatrixView_half_rw
c_
(
c
,
size_m
,
size_n
);
MatrixView_q4_row
b_q_scale_
(
b_q_scale
,
groups
,
size_n
);
int
t
=
threadIdx
.
x
;
// Block
int
offset_n
=
blockIdx
.
x
*
BLOCK_KN_SIZE
*
4
;
int
offset_m
=
blockIdx
.
y
*
m_count
;
int
offset_k
=
blockIdx
.
z
*
BLOCK_KN_SIZE
;
int
end_n
=
min
(
offset_n
+
BLOCK_KN_SIZE
*
4
,
size_n
);
int
end_m
=
min
(
offset_m
+
m_count
,
size_m
);
int
end_k
=
min
(
offset_k
+
BLOCK_KN_SIZE
,
size_k
);
int
n
=
offset_n
+
t
*
4
;
// Preload block_a
__shared__
half
block_a
[
m_count
][
BLOCK_KN_SIZE
];
if
(
offset_k
+
t
<
end_k
)
{
for
(
int
m
=
0
;
m
<
m_count
;
++
m
)
{
const
half
*
a_ptr
=
a_
.
item_ptr
(
offset_m
+
m
,
0
);
half
*
block_a_ptr
=
block_a
[
m
];
half
a0
=
a_ptr
[
b_q_perm
[
offset_k
+
t
]];
block_a_ptr
[
t
]
=
a0
;
}
}
// Clear
if
(
n
>=
size_n
)
return
;
if
(
clear
&&
blockIdx
.
z
==
0
)
// && (threadIdx.x & 1) == 0)
{
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
*
((
uint64_t
*
)
c_
.
item_ptr
(
offset_m
+
m
,
n
))
=
0
;
}
__syncthreads
();
// Find initial group
int
group
=
offset_k
/
groupsize
;
// Preload scales
float
scales
[
MAX_GROUPS_IN_BLOCK
][
4
];
int
groups_in_block
=
DIVIDE
((
end_k
-
offset_k
),
groupsize
);
for
(
int
g
=
0
;
g
<
groups_in_block
;
g
++
)
{
int
qscales
[
4
];
b_q_scale_
.
item4
(
qscales
,
group
+
g
,
n
);
qscales
[
0
]
++
;
qscales
[
1
]
++
;
qscales
[
2
]
++
;
qscales
[
3
]
++
;
float
maxscale
=
__half2float
(
b_q_scale_max
[
group
+
g
]);
scales
[
g
][
0
]
=
__int2float_rn
(
qscales
[
0
]
*
qscales
[
0
])
*
maxscale
;
scales
[
g
][
1
]
=
__int2float_rn
(
qscales
[
1
]
*
qscales
[
1
])
*
maxscale
;
scales
[
g
][
2
]
=
__int2float_rn
(
qscales
[
2
]
*
qscales
[
2
])
*
maxscale
;
scales
[
g
][
3
]
=
__int2float_rn
(
qscales
[
3
]
*
qscales
[
3
])
*
maxscale
;
}
// a, b offset
int
pre_rows_8
=
min
(
rows_8
,
offset_k
);
int
pre_rows_6
=
offset_k
>
rows_8
?
min
(
rows_6
,
offset_k
)
-
rows_8
:
0
;
int
pre_rows_5
=
offset_k
>
rows_6
?
min
(
rows_5
,
offset_k
)
-
rows_6
:
0
;
int
pre_rows_4
=
offset_k
>
rows_5
?
min
(
rows_4
,
offset_k
)
-
rows_5
:
0
;
int
pre_rows_3
=
offset_k
>
rows_4
?
min
(
rows_3
,
offset_k
)
-
rows_4
:
0
;
int
pre_rows_2
=
offset_k
>
rows_3
?
min
(
rows_2
,
offset_k
)
-
rows_3
:
0
;
int
qk
=
0
;
qk
+=
pre_rows_8
/
32
*
8
;
qk
+=
pre_rows_6
/
32
*
6
;
qk
+=
pre_rows_5
/
32
*
5
;
qk
+=
pre_rows_4
/
32
*
4
;
qk
+=
pre_rows_3
/
32
*
3
;
qk
+=
pre_rows_2
/
32
*
2
;
const
uint32_t
*
b_ptr
=
b_q_weight
+
qk
*
size_n
+
n
;
const
half
*
a_ptr
=
&
block_a
[
0
][
0
];
int
a_stride
=
BLOCK_KN_SIZE
;
// Initial group
int
scales_idx
=
0
;
float
qs_f0
=
scales
[
scales_idx
][
0
];
float
qs_f1
=
scales
[
scales_idx
][
1
];
float
qs_f2
=
scales
[
scales_idx
][
2
];
float
qs_f3
=
scales
[
scales_idx
][
3
];
int
nextgroup
=
offset_k
+
groupsize
;
// Column result
float
block_c
[
m_count
][
4
]
=
{};
// Dequantize groups
int
k
=
offset_k
;
while
(
k
<
rows_8
&&
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
scales_idx
++
;
qs_f0
=
scales
[
scales_idx
][
0
];
qs_f1
=
scales
[
scales_idx
][
1
];
qs_f2
=
scales
[
scales_idx
][
2
];
qs_f3
=
scales
[
scales_idx
][
3
];
nextgroup
+=
groupsize
;
}
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
int4
load_int4
[
2
];
load_int4
[
0
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
load_int4
[
1
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
half2
dq
[
4
][
4
];
dequant_8bit_8
(
load_int4
[
0
].
x
,
load_int4
[
1
].
x
,
dq
[
0
],
size_n
);
dequant_8bit_8
(
load_int4
[
0
].
y
,
load_int4
[
1
].
y
,
dq
[
1
],
size_n
);
dequant_8bit_8
(
load_int4
[
0
].
z
,
load_int4
[
1
].
z
,
dq
[
2
],
size_n
);
dequant_8bit_8
(
load_int4
[
0
].
w
,
load_int4
[
1
].
w
,
dq
[
3
],
size_n
);
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
block_c
[
m
][
0
]
=
dot22_8_f
(
dq
[
0
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
0
],
qs_f0
);
block_c
[
m
][
1
]
=
dot22_8_f
(
dq
[
1
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
1
],
qs_f1
);
block_c
[
m
][
2
]
=
dot22_8_f
(
dq
[
2
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
2
],
qs_f2
);
block_c
[
m
][
3
]
=
dot22_8_f
(
dq
[
3
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
3
],
qs_f3
);
}
a_ptr
+=
8
;
}
k
+=
32
;
}
while
(
k
<
rows_6
&&
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
scales_idx
++
;
qs_f0
=
scales
[
scales_idx
][
0
];
qs_f1
=
scales
[
scales_idx
][
1
];
qs_f2
=
scales
[
scales_idx
][
2
];
qs_f3
=
scales
[
scales_idx
][
3
];
nextgroup
+=
groupsize
;
}
#pragma unroll
for
(
int
j
=
0
;
j
<
2
;
j
++
)
{
int4
load_int4
[
3
];
load_int4
[
0
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
load_int4
[
1
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
load_int4
[
2
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
half2
dq
[
4
][
8
];
dequant_6bit_16
(
load_int4
[
0
].
x
,
load_int4
[
1
].
x
,
load_int4
[
2
].
x
,
dq
[
0
],
size_n
);
dequant_6bit_16
(
load_int4
[
0
].
y
,
load_int4
[
1
].
y
,
load_int4
[
2
].
y
,
dq
[
1
],
size_n
);
dequant_6bit_16
(
load_int4
[
0
].
z
,
load_int4
[
1
].
z
,
load_int4
[
2
].
z
,
dq
[
2
],
size_n
);
dequant_6bit_16
(
load_int4
[
0
].
w
,
load_int4
[
1
].
w
,
load_int4
[
2
].
w
,
dq
[
3
],
size_n
);
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
block_c
[
m
][
0
]
=
dot22_16_f
(
dq
[
0
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
0
],
qs_f0
);
block_c
[
m
][
1
]
=
dot22_16_f
(
dq
[
1
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
1
],
qs_f1
);
block_c
[
m
][
2
]
=
dot22_16_f
(
dq
[
2
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
2
],
qs_f2
);
block_c
[
m
][
3
]
=
dot22_16_f
(
dq
[
3
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
3
],
qs_f3
);
}
a_ptr
+=
16
;
}
k
+=
32
;
}
while
(
k
<
rows_5
&&
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
scales_idx
++
;
qs_f0
=
scales
[
scales_idx
][
0
];
qs_f1
=
scales
[
scales_idx
][
1
];
qs_f2
=
scales
[
scales_idx
][
2
];
qs_f3
=
scales
[
scales_idx
][
3
];
nextgroup
+=
groupsize
;
}
#pragma unroll
for
(
int
j
=
0
;
j
<
1
;
j
++
)
{
int4
load_int4
[
5
];
load_int4
[
0
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
load_int4
[
1
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
load_int4
[
2
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
load_int4
[
3
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
load_int4
[
4
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
half2
dq
[
4
][
16
];
dequant_5bit_32
(
load_int4
[
0
].
x
,
load_int4
[
1
].
x
,
load_int4
[
2
].
x
,
load_int4
[
3
].
x
,
load_int4
[
4
].
x
,
dq
[
0
],
size_n
);
dequant_5bit_32
(
load_int4
[
0
].
y
,
load_int4
[
1
].
y
,
load_int4
[
2
].
y
,
load_int4
[
3
].
y
,
load_int4
[
4
].
y
,
dq
[
1
],
size_n
);
dequant_5bit_32
(
load_int4
[
0
].
z
,
load_int4
[
1
].
z
,
load_int4
[
2
].
z
,
load_int4
[
3
].
z
,
load_int4
[
4
].
z
,
dq
[
2
],
size_n
);
dequant_5bit_32
(
load_int4
[
0
].
w
,
load_int4
[
1
].
w
,
load_int4
[
2
].
w
,
load_int4
[
3
].
w
,
load_int4
[
4
].
w
,
dq
[
3
],
size_n
);
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
block_c
[
m
][
0
]
=
dot22_32_f
(
dq
[
0
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
0
],
qs_f0
);
block_c
[
m
][
1
]
=
dot22_32_f
(
dq
[
1
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
1
],
qs_f1
);
block_c
[
m
][
2
]
=
dot22_32_f
(
dq
[
2
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
2
],
qs_f2
);
block_c
[
m
][
3
]
=
dot22_32_f
(
dq
[
3
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
3
],
qs_f3
);
}
a_ptr
+=
32
;
}
k
+=
32
;
}
while
(
k
<
rows_4
&&
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
scales_idx
++
;
qs_f0
=
scales
[
scales_idx
][
0
];
qs_f1
=
scales
[
scales_idx
][
1
];
qs_f2
=
scales
[
scales_idx
][
2
];
qs_f3
=
scales
[
scales_idx
][
3
];
nextgroup
+=
groupsize
;
}
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
int4
load_int4
[
1
];
load_int4
[
0
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
half2
dq
[
4
][
4
];
dequant_4bit_8
(
load_int4
[
0
].
x
,
dq
[
0
],
size_n
);
dequant_4bit_8
(
load_int4
[
0
].
y
,
dq
[
1
],
size_n
);
dequant_4bit_8
(
load_int4
[
0
].
z
,
dq
[
2
],
size_n
);
dequant_4bit_8
(
load_int4
[
0
].
w
,
dq
[
3
],
size_n
);
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
block_c
[
m
][
0
]
=
dot22_8_f
(
dq
[
0
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
0
],
qs_f0
);
block_c
[
m
][
1
]
=
dot22_8_f
(
dq
[
1
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
1
],
qs_f1
);
block_c
[
m
][
2
]
=
dot22_8_f
(
dq
[
2
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
2
],
qs_f2
);
block_c
[
m
][
3
]
=
dot22_8_f
(
dq
[
3
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
3
],
qs_f3
);
}
a_ptr
+=
8
;
}
k
+=
32
;
}
while
(
k
<
rows_3
&&
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
scales_idx
++
;
qs_f0
=
scales
[
scales_idx
][
0
];
qs_f1
=
scales
[
scales_idx
][
1
];
qs_f2
=
scales
[
scales_idx
][
2
];
qs_f3
=
scales
[
scales_idx
][
3
];
nextgroup
+=
groupsize
;
}
#pragma unroll
for
(
int
j
=
0
;
j
<
1
;
j
++
)
{
int4
load_int4
[
3
];
load_int4
[
0
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
load_int4
[
1
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
load_int4
[
2
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
half2
dq
[
4
][
16
];
dequant_3bit_32
(
load_int4
[
0
].
x
,
load_int4
[
1
].
x
,
load_int4
[
2
].
x
,
dq
[
0
],
size_n
);
dequant_3bit_32
(
load_int4
[
0
].
y
,
load_int4
[
1
].
y
,
load_int4
[
2
].
y
,
dq
[
1
],
size_n
);
dequant_3bit_32
(
load_int4
[
0
].
z
,
load_int4
[
1
].
z
,
load_int4
[
2
].
z
,
dq
[
2
],
size_n
);
dequant_3bit_32
(
load_int4
[
0
].
w
,
load_int4
[
1
].
w
,
load_int4
[
2
].
w
,
dq
[
3
],
size_n
);
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
block_c
[
m
][
0
]
=
dot22_32_f
(
dq
[
0
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
0
],
qs_f0
);
block_c
[
m
][
1
]
=
dot22_32_f
(
dq
[
1
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
1
],
qs_f1
);
block_c
[
m
][
2
]
=
dot22_32_f
(
dq
[
2
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
2
],
qs_f2
);
block_c
[
m
][
3
]
=
dot22_32_f
(
dq
[
3
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
3
],
qs_f3
);
}
a_ptr
+=
32
;
}
k
+=
32
;
}
while
(
k
<
rows_2
&&
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
scales_idx
++
;
qs_f0
=
scales
[
scales_idx
][
0
];
qs_f1
=
scales
[
scales_idx
][
1
];
qs_f2
=
scales
[
scales_idx
][
2
];
qs_f3
=
scales
[
scales_idx
][
3
];
nextgroup
+=
groupsize
;
}
#pragma unroll
for
(
int
j
=
0
;
j
<
2
;
j
++
)
{
int4
load_int4
[
1
];
load_int4
[
0
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
half2
dq
[
4
][
8
];
dequant_2bit_16
(
load_int4
[
0
].
x
,
dq
[
0
],
size_n
);
dequant_2bit_16
(
load_int4
[
0
].
y
,
dq
[
1
],
size_n
);
dequant_2bit_16
(
load_int4
[
0
].
z
,
dq
[
2
],
size_n
);
dequant_2bit_16
(
load_int4
[
0
].
w
,
dq
[
3
],
size_n
);
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
block_c
[
m
][
0
]
=
dot22_16_f
(
dq
[
0
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
0
],
qs_f0
);
block_c
[
m
][
1
]
=
dot22_16_f
(
dq
[
1
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
1
],
qs_f1
);
block_c
[
m
][
2
]
=
dot22_16_f
(
dq
[
2
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
2
],
qs_f2
);
block_c
[
m
][
3
]
=
dot22_16_f
(
dq
[
3
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
3
],
qs_f3
);
}
a_ptr
+=
16
;
}
k
+=
32
;
}
// Accumulate column sums in c
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
half2
*
out
=
(
half2
*
)
c_
.
item_ptr
(
offset_m
+
m
,
n
);
half2
result01
=
__halves2half2
(
__float2half_rn
(
block_c
[
m
][
0
]),
__float2half_rn
(
block_c
[
m
][
1
]));
half2
result23
=
__halves2half2
(
__float2half_rn
(
block_c
[
m
][
2
]),
__float2half_rn
(
block_c
[
m
][
3
]));
atomicAdd
(
out
,
result01
);
atomicAdd
(
out
+
1
,
result23
);
}
}
fp_gemm_half_q_half_kernel
pick_gemm_half_q_half_kernel
(
bool
first_block
,
const
int
m_count
)
{
#if BLOCK_M_SIZE_MAX >= 1
if
(
m_count
==
1
)
return
gemm_half_q_half_kernel
<
true
,
1
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if
(
m_count
==
2
)
return
gemm_half_q_half_kernel
<
true
,
2
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if
(
m_count
==
3
)
return
gemm_half_q_half_kernel
<
true
,
3
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if
(
m_count
==
4
)
return
gemm_half_q_half_kernel
<
true
,
4
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if
(
m_count
==
5
)
return
gemm_half_q_half_kernel
<
true
,
5
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if
(
m_count
==
6
)
return
gemm_half_q_half_kernel
<
true
,
6
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if
(
m_count
==
7
)
return
gemm_half_q_half_kernel
<
true
,
7
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if
(
m_count
==
8
)
return
gemm_half_q_half_kernel
<
true
,
8
>
;
#endif
return
NULL
;
}
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/q_gemm_kernel_gptq.cuh  (new file, mode 100644)
#include "compat.cuh"
__forceinline__
__device__
half2
dot22_8
(
half2
(
&
dq
)[
4
],
const
half
*
a_ptr
,
const
half2
g_result
)
{
half2
result
=
{};
const
half2
*
a2_ptr
=
(
const
half2
*
)
a_ptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
result
=
__hfma2
(
dq
[
i
],
*
a2_ptr
++
,
result
);
return
__hadd2
(
result
,
g_result
);
}
__forceinline__
__device__
float
dot22_8_f
(
half2
(
&
dq
)[
4
],
const
half
*
a_ptr
)
{
half2
result
=
{};
const
half2
*
a2_ptr
=
(
const
half2
*
)
a_ptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
result
=
__hfma2
(
dq
[
i
],
*
a2_ptr
++
,
result
);
return
__half2float
(
__low2half
(
result
))
+
__half2float
(
__high2half
(
result
));
}
typedef
void
(
*
fp_gemm_half_q_half_gptq_kernel
)
(
const
half
*
,
const
uint32_t
*
,
const
uint32_t
*
,
const
half
*
,
half
*
,
const
int
,
const
int
,
const
int
,
const
int
,
const
int
,
const
uint16_t
*
,
const
int
,
const
bool
);
template
<
bool
first_block
,
int
m_count
>
__global__
void
gemm_half_q_half_gptq_kernel
(
const
half
*
__restrict__
a
,
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint32_t
*
__restrict__
b_gptq_qzeros
,
const
half
*
__restrict__
b_gptq_scales
,
half
*
__restrict__
c
,
const
int
size_m
,
const
int
size_n
,
const
int
size_k
,
const
int
groups
,
const
int
groupsize
,
const
uint16_t
*
__restrict__
b_q_perm
,
const
int
rows_4
,
const
bool
clear
)
{
MatrixView_half
a_
(
a
,
size_m
,
size_k
);
MatrixView_half_rw
c_
(
c
,
size_m
,
size_n
);
MatrixView_q4_row
b_gptq_qzeros_
(
b_gptq_qzeros
,
groups
,
size_n
);
MatrixView_half
b_gptq_scales_
(
b_gptq_scales
,
groups
,
size_n
);
int
t
=
threadIdx
.
x
;
// Block
int
offset_n
=
blockIdx
.
x
*
BLOCK_KN_SIZE
*
4
;
int
offset_m
=
blockIdx
.
y
*
m_count
;
int
offset_k
=
blockIdx
.
z
*
BLOCK_KN_SIZE
;
int
end_n
=
min
(
offset_n
+
BLOCK_KN_SIZE
*
4
,
size_n
);
int
end_m
=
min
(
offset_m
+
m_count
,
size_m
);
int
end_k
=
min
(
offset_k
+
BLOCK_KN_SIZE
,
size_k
);
int
n
=
offset_n
+
t
*
4
;
// Preload block_a
__shared__
half
block_a
[
m_count
][
BLOCK_KN_SIZE
];
if
(
offset_k
+
t
<
end_k
)
{
for
(
int
m
=
0
;
m
<
m_count
;
++
m
)
{
const
half
*
a_ptr
=
a_
.
item_ptr
(
offset_m
+
m
,
0
);
half
*
block_a_ptr
=
block_a
[
m
];
half
a0
;
if
(
b_q_perm
)
a0
=
a_ptr
[
b_q_perm
[
offset_k
+
t
]];
else
a0
=
a_ptr
[
offset_k
+
t
];
block_a_ptr
[
t
]
=
a0
;
}
}
// Zero output
if
(
n
>=
size_n
)
return
;
if
(
clear
&&
blockIdx
.
z
==
0
)
// && (threadIdx.x & 1) == 0)
{
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
*
((
uint64_t
*
)
c_
.
item_ptr
(
offset_m
+
m
,
n
))
=
0
;
}
__syncthreads
();
// Find initial group
int
group
=
offset_k
/
groupsize
;
int
nextgroup
=
offset_k
+
groupsize
;
// a, b offset
int
qk
=
offset_k
/
(
32
/
4
);
const
uint32_t
*
b_ptr
=
b_q_weight
+
qk
*
size_n
+
n
;
const
half
*
a_ptr
=
&
block_a
[
0
][
0
];
int
a_stride
=
BLOCK_KN_SIZE
;
// Initial group
int
zeros
[
4
];
float
scales
[
4
];
half2
z1z16
[
4
][
2
];
half2
y1y16
[
4
][
2
];
b_gptq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_gptq_scales_
.
item4_f
(
scales
,
group
,
n
);
// Avoid zeros overflow with & 0x0f.
dequant_4bit_8_prep_zero
((
zeros
[
0
]
+
1
)
&
0x0f
,
z1z16
[
0
],
y1y16
[
0
]);
dequant_4bit_8_prep_zero
((
zeros
[
1
]
+
1
)
&
0x0f
,
z1z16
[
1
],
y1y16
[
1
]);
dequant_4bit_8_prep_zero
((
zeros
[
2
]
+
1
)
&
0x0f
,
z1z16
[
2
],
y1y16
[
2
]);
dequant_4bit_8_prep_zero
((
zeros
[
3
]
+
1
)
&
0x0f
,
z1z16
[
3
],
y1y16
[
3
]);
// __syncthreads();
// Column result
float
block_c
[
m_count
][
4
]
=
{};
// Dequantize and multiply
int
k
=
offset_k
;
while
(
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
nextgroup
+=
groupsize
;
b_gptq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_gptq_scales_
.
item4_f
(
scales
,
group
,
n
);
// Avoid zeros overflow with & 0x0f.
dequant_4bit_8_prep_zero
((
zeros
[
0
]
+
1
)
&
0x0f
,
z1z16
[
0
],
y1y16
[
0
]);
dequant_4bit_8_prep_zero
((
zeros
[
1
]
+
1
)
&
0x0f
,
z1z16
[
1
],
y1y16
[
1
]);
dequant_4bit_8_prep_zero
((
zeros
[
2
]
+
1
)
&
0x0f
,
z1z16
[
2
],
y1y16
[
2
]);
dequant_4bit_8_prep_zero
((
zeros
[
3
]
+
1
)
&
0x0f
,
z1z16
[
3
],
y1y16
[
3
]);
}
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
const
int4
*
b_ptr4
=
(
int4
*
)
b_ptr
;
int4
load_int4
=
*
b_ptr4
;
half2
dq
[
4
][
4
];
dequant_4bit_8_gptq
(
load_int4
.
x
,
dq
[
0
],
z1z16
[
0
],
y1y16
[
0
],
size_n
,
false
);
dequant_4bit_8_gptq
(
load_int4
.
y
,
dq
[
1
],
z1z16
[
1
],
y1y16
[
1
],
size_n
,
false
);
dequant_4bit_8_gptq
(
load_int4
.
z
,
dq
[
2
],
z1z16
[
2
],
y1y16
[
2
],
size_n
,
false
);
dequant_4bit_8_gptq
(
load_int4
.
w
,
dq
[
3
],
z1z16
[
3
],
y1y16
[
3
],
size_n
,
false
);
#pragma unroll
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
block_c
[
m
][
0
]
=
fma
(
dot22_8_f
(
dq
[
0
],
a_ptr
+
m
*
a_stride
),
scales
[
0
],
block_c
[
m
][
0
]);
block_c
[
m
][
1
]
=
fma
(
dot22_8_f
(
dq
[
1
],
a_ptr
+
m
*
a_stride
),
scales
[
1
],
block_c
[
m
][
1
]);
block_c
[
m
][
2
]
=
fma
(
dot22_8_f
(
dq
[
2
],
a_ptr
+
m
*
a_stride
),
scales
[
2
],
block_c
[
m
][
2
]);
block_c
[
m
][
3
]
=
fma
(
dot22_8_f
(
dq
[
3
],
a_ptr
+
m
*
a_stride
),
scales
[
3
],
block_c
[
m
][
3
]);
}
b_ptr
+=
size_n
;
a_ptr
+=
8
;
}
k
+=
32
;
}
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
half2
*
out
=
(
half2
*
)
c_
.
item_ptr
(
offset_m
+
m
,
n
);
half2
result01
=
__halves2half2
(
__float2half_rn
(
block_c
[
m
][
0
]),
__float2half_rn
(
block_c
[
m
][
1
]));
half2
result23
=
__halves2half2
(
__float2half_rn
(
block_c
[
m
][
2
]),
__float2half_rn
(
block_c
[
m
][
3
]));
atomicAdd
(
out
,
result01
);
atomicAdd
(
out
+
1
,
result23
);
}
}
fp_gemm_half_q_half_gptq_kernel
pick_gemm_half_q_half_gptq_kernel
(
bool
first_block
,
const
int
m_count
)
{
#if BLOCK_M_SIZE_MAX >= 1
if
(
m_count
==
1
)
return
gemm_half_q_half_gptq_kernel
<
true
,
1
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if
(
m_count
==
2
)
return
gemm_half_q_half_gptq_kernel
<
true
,
2
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if
(
m_count
==
3
)
return
gemm_half_q_half_gptq_kernel
<
true
,
3
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if
(
m_count
==
4
)
return
gemm_half_q_half_gptq_kernel
<
true
,
4
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if
(
m_count
==
5
)
return
gemm_half_q_half_gptq_kernel
<
true
,
5
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if
(
m_count
==
6
)
return
gemm_half_q_half_gptq_kernel
<
true
,
6
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if
(
m_count
==
7
)
return
gemm_half_q_half_gptq_kernel
<
true
,
7
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if
(
m_count
==
8
)
return
gemm_half_q_half_gptq_kernel
<
true
,
8
>
;
#endif
return
NULL
;
}
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/q_matrix.cu  (new file, mode 100644)
#include "q_matrix.cuh"
#include "matrix_view.cuh"
#include "util.cuh"
#include "quant/qdq_2.cuh"
#include "quant/qdq_3.cuh"
#include "quant/qdq_4.cuh"
#include "quant/qdq_5.cuh"
#include "quant/qdq_6.cuh"
#include "quant/qdq_8.cuh"
#define BLOCK_KN_SIZE 128
#define THREADS_X 32
#define THREADS_Y 32
// Shuffle quantized data on load
__global__
void
shuffle_kernel
(
uint32_t
*
__restrict__
b_q_weight
,
const
int
size_k
,
const
int
size_n
,
const
int
rows_8
,
const
int
rows_6
,
const
int
rows_5
,
const
int
rows_4
,
const
int
rows_3
,
const
int
rows_2
)
{
int
n
=
blockIdx
.
x
*
THREADS_X
+
threadIdx
.
x
;
if
(
n
>=
size_n
)
return
;
int
k
=
0
;
uint32_t
*
b_ptr
=
b_q_weight
+
n
;
while
(
k
<
rows_8
)
{
shuffle_8bit_4
(
b_ptr
,
size_n
);
b_ptr
+=
1
*
size_n
;
k
+=
4
;
}
while
(
k
<
rows_6
)
{
shuffle_6bit_16
(
b_ptr
,
size_n
);
b_ptr
+=
3
*
size_n
;
k
+=
16
;
}
while
(
k
<
rows_5
)
{
shuffle_5bit_32
(
b_ptr
,
size_n
);
b_ptr
+=
5
*
size_n
;
k
+=
32
;
}
while
(
k
<
rows_4
)
{
shuffle_4bit_8
(
b_ptr
,
size_n
);
b_ptr
+=
1
*
size_n
;
k
+=
8
;
}
while
(
k
<
rows_3
)
{
shuffle_3bit_32
(
b_ptr
,
size_n
);
b_ptr
+=
3
*
size_n
;
k
+=
32
;
}
while
(
k
<
rows_2
)
{
shuffle_2bit_16
(
b_ptr
,
size_n
);
b_ptr
+=
1
*
size_n
;
k
+=
16
;
}
}
// QMatrix constructor
QMatrix
::
QMatrix
(
const
int
_device
,
const
int
_height
,
const
int
_width
,
const
int
_groups
,
uint32_t
*
_q_weight
,
uint16_t
*
_q_perm
,
uint16_t
*
_q_invperm
,
uint32_t
*
_q_scale
,
half
*
_q_scale_max
,
uint16_t
*
_q_groups
,
uint32_t
*
_gptq_qzeros
,
half
*
_gptq_scales
,
uint32_t
*
_gptq_g_idx
,
half
*
_temp_dq
)
:
device
(
_device
),
height
(
_height
),
width
(
_width
),
groups
(
_groups
),
temp_dq
(
_temp_dq
)
{
cudaSetDevice
(
device
);
failed
=
false
;
cuda_q_weight
=
_q_weight
;
cuda_q_perm
=
_q_perm
;
cuda_q_invperm
=
_q_invperm
;
cuda_q_scale
=
_q_scale
;
cuda_q_scale_max
=
_q_scale_max
;
cuda_q_groups
=
_q_groups
;
cuda_gptq_qzeros
=
_gptq_qzeros
;
cuda_gptq_scales
=
_gptq_scales
;
is_gptq
=
(
_gptq_qzeros
!=
NULL
);
groupsize
=
1
;
while
(
groupsize
*
groups
<
height
)
groupsize
*=
2
;
// Create group map
rows_8
=
0
;
rows_6
=
0
;
rows_5
=
0
;
rows_4
=
0
;
rows_3
=
0
;
rows_2
=
0
;
if
(
!
is_gptq
)
{
uint16_t
*
cpu_q_groups
=
(
uint16_t
*
)
calloc
(
groups
*
2
,
sizeof
(
uint16_t
));
cudaMemcpy
(
cpu_q_groups
,
cuda_q_groups
,
groups
*
2
*
sizeof
(
uint16_t
),
cudaMemcpyDeviceToHost
);
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
int
bits
=
cpu_q_groups
[
i
*
2
];
if
(
bits
==
8
)
rows_8
+=
groupsize
;
if
(
bits
==
6
)
rows_6
+=
groupsize
;
if
(
bits
==
5
)
rows_5
+=
groupsize
;
if
(
bits
==
4
)
rows_4
+=
groupsize
;
if
(
bits
==
3
)
rows_3
+=
groupsize
;
if
(
bits
==
2
)
rows_2
+=
groupsize
;
}
free
(
cpu_q_groups
);
rows_6
+=
rows_8
;
rows_5
+=
rows_6
;
rows_4
+=
rows_5
;
rows_3
+=
rows_4
;
rows_2
+=
rows_3
;
}
else
{
rows_4
=
height
;
rows_3
=
height
;
rows_2
=
height
;
if
(
_gptq_g_idx
)
{
if
(
!
make_sequential
(
_gptq_g_idx
))
{
failed
=
true
;
//printf("FAIL\n");
return
;
}
}
}
// Shuffle quantized data
dim3
blockDim
,
gridDim
;
blockDim
.
x
=
THREADS_X
;
blockDim
.
y
=
1
;
gridDim
.
x
=
DIVIDE
(
width
,
THREADS_X
);
gridDim
.
y
=
1
;
shuffle_kernel
<<<
gridDim
,
blockDim
>>>
(
cuda_q_weight
,
height
,
width
,
rows_8
,
rows_6
,
rows_5
,
rows_4
,
rows_3
,
rows_2
);
}
QMatrix
::~
QMatrix
()
{
}
// Reconstruct b[k,n] (GPTQ)
__global__
void
reconstruct_gptq_kernel
(
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint16_t
*
__restrict__
b_q_perm
,
const
uint32_t
*
__restrict__
b_gptq_qzeros
,
const
half
*
__restrict__
b_gptq_scales
,
//const uint16_t* __restrict__ b_q_groups,
const
int
size_k
,
const
int
size_n
,
const
int
groupsize
,
const
int
groups
,
half
*
__restrict__
b
,
const
int
rows_4
)
{
MatrixView_half_rw
b_
(
b
,
size_k
,
size_n
);
MatrixView_q4_row
b_gptq_qzeros_
(
b_gptq_qzeros
,
groups
,
size_n
);
MatrixView_half
b_gptq_scales_
(
b_gptq_scales
,
groups
,
size_n
);
int
offset_k
=
BLOCK_KN_SIZE
*
blockIdx
.
y
;
int
offset_n
=
BLOCK_KN_SIZE
*
blockIdx
.
x
*
4
;
int
end_k
=
min
(
offset_k
+
BLOCK_KN_SIZE
,
size_k
);
// Preload remapping table
__shared__
uint16_t
perm
[
BLOCK_KN_SIZE
];
int
t
=
threadIdx
.
x
;
if
(
b_q_perm
)
{
if
(
offset_k
+
t
<
size_k
)
perm
[
t
]
=
b_q_perm
[
offset_k
+
t
];
}
// Column
int
n
=
offset_n
+
t
*
4
;
if
(
n
>=
size_n
)
return
;
// Find initial group
int
group
=
offset_k
/
groupsize
;
int
nextgroup
=
offset_k
+
groupsize
;
// b offset
int
qk
=
offset_k
/
(
32
/
4
);
const
uint32_t
*
b_ptr
=
b_q_weight
+
qk
*
size_n
+
n
;
// Initial zeros/scale
int
zeros
[
4
];
half2
scales
[
4
];
half2
z1z16
[
4
][
2
];
half2
y1y16
[
4
][
2
];
b_gptq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_gptq_scales_
.
item4_h2
(
scales
,
group
,
n
);
// Avoid zeros overflow with & 0x0f.
dequant_4bit_8_prep_zero
((
zeros
[
0
]
+
1
)
&
0x0f
,
z1z16
[
0
],
y1y16
[
0
]);
dequant_4bit_8_prep_zero
((
zeros
[
1
]
+
1
)
&
0x0f
,
z1z16
[
1
],
y1y16
[
1
]);
dequant_4bit_8_prep_zero
((
zeros
[
2
]
+
1
)
&
0x0f
,
z1z16
[
2
],
y1y16
[
2
]);
dequant_4bit_8_prep_zero
((
zeros
[
3
]
+
1
)
&
0x0f
,
z1z16
[
3
],
y1y16
[
3
]);
__syncthreads
();
int
k
=
offset_k
;
int
lk
=
0
;
while
(
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
nextgroup
+=
groupsize
;
b_gptq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_gptq_scales_
.
item4_h2
(
scales
,
group
,
n
);
// Avoid zeros overflow with & 0x0f.
dequant_4bit_8_prep_zero
((
zeros
[
0
]
+
1
)
&
0x0f
,
z1z16
[
0
],
y1y16
[
0
]);
dequant_4bit_8_prep_zero
((
zeros
[
1
]
+
1
)
&
0x0f
,
z1z16
[
1
],
y1y16
[
1
]);
dequant_4bit_8_prep_zero
((
zeros
[
2
]
+
1
)
&
0x0f
,
z1z16
[
2
],
y1y16
[
2
]);
dequant_4bit_8_prep_zero
((
zeros
[
3
]
+
1
)
&
0x0f
,
z1z16
[
3
],
y1y16
[
3
]);
}
for
(
int
p
=
0
;
p
<
4
;
p
++
)
{
half2
dq
[
4
][
4
];
const
int4
*
b_ptr4
=
(
int4
*
)
b_ptr
;
int4
load_int4
=
*
b_ptr4
;
dequant_4bit_8_gptq
(
load_int4
.
x
,
dq
[
0
],
z1z16
[
0
],
y1y16
[
0
],
size_n
,
false
);
dequant_4bit_8_gptq
(
load_int4
.
y
,
dq
[
1
],
z1z16
[
1
],
y1y16
[
1
],
size_n
,
false
);
dequant_4bit_8_gptq
(
load_int4
.
z
,
dq
[
2
],
z1z16
[
2
],
y1y16
[
2
],
size_n
,
false
);
dequant_4bit_8_gptq
(
load_int4
.
w
,
dq
[
3
],
z1z16
[
3
],
y1y16
[
3
],
size_n
,
false
);
b_ptr
+=
size_n
;
//half* dqh = (half*)dq;
if
(
b_q_perm
)
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
for
(
int
v
=
0
;
v
<
4
;
v
++
)
dq
[
v
][
j
]
=
__hmul2
(
scales
[
v
],
dq
[
v
][
j
]);
b_
.
set4
(
perm
[
lk
++
],
n
,
__low2half
(
dq
[
0
][
j
]),
__low2half
(
dq
[
1
][
j
]),
__low2half
(
dq
[
2
][
j
]),
__low2half
(
dq
[
3
][
j
]));
b_
.
set4
(
perm
[
lk
++
],
n
,
__high2half
(
dq
[
0
][
j
]),
__high2half
(
dq
[
1
][
j
]),
__high2half
(
dq
[
2
][
j
]),
__high2half
(
dq
[
3
][
j
]));
}
}
else
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
for
(
int
v
=
0
;
v
<
4
;
v
++
)
dq
[
v
][
j
]
=
__hmul2
(
scales
[
v
],
dq
[
v
][
j
]);
b_
.
set4
(
offset_k
+
lk
++
,
n
,
__low2half
(
dq
[
0
][
j
]),
__low2half
(
dq
[
1
][
j
]),
__low2half
(
dq
[
2
][
j
]),
__low2half
(
dq
[
3
][
j
]));
b_
.
set4
(
offset_k
+
lk
++
,
n
,
__high2half
(
dq
[
0
][
j
]),
__high2half
(
dq
[
1
][
j
]),
__high2half
(
dq
[
2
][
j
]),
__high2half
(
dq
[
3
][
j
]));
}
}
}
k
+=
32
;
}
}
// Reconstruct b[k,n]
__global__
void
reconstruct_kernel
(
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint16_t
*
__restrict__
b_q_perm
,
const
uint32_t
*
__restrict__
b_q_scale
,
const
half
*
__restrict__
b_q_scale_max
,
//const uint16_t* __restrict__ b_q_groups,
const
int
size_k
,
const
int
size_n
,
const
int
groupsize
,
const
int
groups
,
half
*
__restrict__
b
,
const
int
rows_8
,
const
int
rows_6
,
const
int
rows_5
,
const
int
rows_4
,
const
int
rows_3
,
const
int
rows_2
)
{
MatrixView_half_rw
b_
(
b
,
size_k
,
size_n
);
MatrixView_q4_row
b_q_scale_
(
b_q_scale
,
groups
,
size_n
);
int
offset_k
=
BLOCK_KN_SIZE
*
blockIdx
.
y
;
int
offset_n
=
BLOCK_KN_SIZE
*
blockIdx
.
x
;
// Preload remapping table
int
t
=
threadIdx
.
x
;
__shared__
uint16_t
perm
[
BLOCK_KN_SIZE
];
if
(
offset_k
+
t
<
size_k
)
perm
[
t
]
=
b_q_perm
[
offset_k
+
t
];
// Column
int
n
=
offset_n
+
t
;
if
(
n
>=
size_n
)
return
;
// Find initial group
int
group
=
offset_k
/
groupsize
;
int
pre_rows_8
=
min
(
rows_8
,
offset_k
);
int
pre_rows_6
=
offset_k
>
rows_8
?
min
(
rows_6
,
offset_k
)
-
rows_8
:
0
;
int
pre_rows_5
=
offset_k
>
rows_6
?
min
(
rows_5
,
offset_k
)
-
rows_6
:
0
;
int
pre_rows_4
=
offset_k
>
rows_5
?
min
(
rows_4
,
offset_k
)
-
rows_5
:
0
;
int
pre_rows_3
=
offset_k
>
rows_4
?
min
(
rows_3
,
offset_k
)
-
rows_4
:
0
;
int
pre_rows_2
=
offset_k
>
rows_3
?
min
(
rows_2
,
offset_k
)
-
rows_3
:
0
;
int
qk
=
0
;
qk
+=
pre_rows_8
/
32
*
8
;
qk
+=
pre_rows_6
/
32
*
6
;
qk
+=
pre_rows_5
/
32
*
5
;
qk
+=
pre_rows_4
/
32
*
4
;
qk
+=
pre_rows_3
/
32
*
3
;
qk
+=
pre_rows_2
/
32
*
2
;
const
uint32_t
*
b_ptr
=
b_q_weight
+
qk
*
size_n
+
n
;
half
qs_h
=
dq_scale
(
b_q_scale_
.
item
(
group
,
n
),
b_q_scale_max
[
group
]);
half2
qs_h2
=
__halves2half2
(
qs_h
,
qs_h
);
int
nextgroup
=
offset_k
+
groupsize
;
int
end_k
=
min
(
offset_k
+
BLOCK_KN_SIZE
,
size_k
);
int
k
=
offset_k
;
int
lk
=
0
;
__syncthreads
();
while
(
k
<
rows_8
&&
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
qs_h
=
dq_scale
(
b_q_scale_
.
item
(
group
,
n
),
b_q_scale_max
[
group
]);
nextgroup
+=
groupsize
;
qs_h2
=
__halves2half2
(
qs_h
,
qs_h
);
}
for
(
int
p
=
0
;
p
<
4
;
p
++
)
{
half2
dq
[
4
];
uint32_t
q_0
=
*
b_ptr
;
b_ptr
+=
size_n
;
uint32_t
q_1
=
*
b_ptr
;
b_ptr
+=
size_n
;
dequant_8bit_8
(
q_0
,
q_1
,
dq
,
size_n
);
for
(
int
j
=
0
;
j
<
4
;
j
++
)
dq
[
j
]
=
__hmul2
(
dq
[
j
],
qs_h2
);
half
*
dqh
=
(
half
*
)
dq
;
for
(
int
j
=
0
;
j
<
8
;
j
++
)
b_
.
set
(
perm
[
lk
++
],
n
,
dqh
[
j
]);
}
k
+=
32
;
}
while
(
k
<
rows_6
&&
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
qs_h
=
dq_scale
(
b_q_scale_
.
item
(
group
,
n
),
b_q_scale_max
[
group
]);
nextgroup
+=
groupsize
;
qs_h2
=
__halves2half2
(
qs_h
,
qs_h
);
}
for
(
int
p
=
0
;
p
<
2
;
p
++
)
{
half2
dq
[
8
];
uint32_t
q_0
=
*
b_ptr
;
b_ptr
+=
size_n
;
uint32_t
q_1
=
*
b_ptr
;
b_ptr
+=
size_n
;
uint32_t
q_2
=
*
b_ptr
;
b_ptr
+=
size_n
;
dequant_6bit_16
(
q_0
,
q_1
,
q_2
,
dq
,
size_n
);
for
(
int
j
=
0
;
j
<
8
;
j
++
)
dq
[
j
]
=
__hmul2
(
dq
[
j
],
qs_h2
);
half
*
dqh
=
(
half
*
)
dq
;
for
(
int
j
=
0
;
j
<
16
;
j
++
)
b_
.
set
(
perm
[
lk
++
],
n
,
dqh
[
j
]);
}
k
+=
32
;
}
while
(
k
<
rows_5
&&
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
qs_h
=
dq_scale
(
b_q_scale_
.
item
(
group
,
n
),
b_q_scale_max
[
group
]);
nextgroup
+=
groupsize
;
qs_h2
=
__halves2half2
(
qs_h
,
qs_h
);
}
for
(
int
p
=
0
;
p
<
1
;
p
++
)
{
half2
dq
[
16
];
uint32_t
q_0
=
*
b_ptr
;
b_ptr
+=
size_n
;
uint32_t
q_1
=
*
b_ptr
;
b_ptr
+=
size_n
;
uint32_t
q_2
=
*
b_ptr
;
b_ptr
+=
size_n
;
uint32_t
q_3
=
*
b_ptr
;
b_ptr
+=
size_n
;
uint32_t
q_4
=
*
b_ptr
;
b_ptr
+=
size_n
;
dequant_5bit_32
(
q_0
,
q_1
,
q_2
,
q_3
,
q_4
,
dq
,
size_n
);
for
(
int
j
=
0
;
j
<
16
;
j
++
)
dq
[
j
]
=
__hmul2
(
dq
[
j
],
qs_h2
);
half
*
dqh
=
(
half
*
)
dq
;
for
(
int
j
=
0
;
j
<
32
;
j
++
)
b_
.
set
(
perm
[
lk
++
],
n
,
dqh
[
j
]);
}
k
+=
32
;
}
while
(
k
<
rows_4
&&
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
qs_h
=
dq_scale
(
b_q_scale_
.
item
(
group
,
n
),
b_q_scale_max
[
group
]);
nextgroup
+=
groupsize
;
qs_h2
=
__halves2half2
(
qs_h
,
qs_h
);
}
for
(
int
p
=
0
;
p
<
4
;
p
++
)
{
half2
dq
[
4
];
uint32_t
q_0
=
*
b_ptr
;
b_ptr
+=
size_n
;
dequant_4bit_8
(
q_0
,
dq
,
size_n
);
for
(
int
j
=
0
;
j
<
4
;
j
++
)
dq
[
j
]
=
__hmul2
(
dq
[
j
],
qs_h2
);
half
*
dqh
=
(
half
*
)
dq
;
for
(
int
j
=
0
;
j
<
8
;
j
++
)
b_
.
set
(
perm
[
lk
++
],
n
,
dqh
[
j
]);
}
k
+=
32
;
}
while
(
k
<
rows_3
&&
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
qs_h
=
dq_scale
(
b_q_scale_
.
item
(
group
,
n
),
b_q_scale_max
[
group
]);
nextgroup
+=
groupsize
;
qs_h2
=
__halves2half2
(
qs_h
,
qs_h
);
}
for
(
int
p
=
0
;
p
<
1
;
p
++
)
{
half2
dq
[
16
];
uint32_t
q_0
=
*
b_ptr
;
b_ptr
+=
size_n
;
uint32_t
q_1
=
*
b_ptr
;
b_ptr
+=
size_n
;
uint32_t
q_2
=
*
b_ptr
;
b_ptr
+=
size_n
;
dequant_3bit_32
(
q_0
,
q_1
,
q_2
,
dq
,
size_n
);
for
(
int
j
=
0
;
j
<
16
;
j
++
)
dq
[
j
]
=
__hmul2
(
dq
[
j
],
qs_h2
);
half
*
dqh
=
(
half
*
)
dq
;
for
(
int
j
=
0
;
j
<
32
;
j
++
)
b_
.
set
(
perm
[
lk
++
],
n
,
dqh
[
j
]);
}
k
+=
32
;
}
while
(
k
<
rows_2
&&
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
qs_h
=
dq_scale
(
b_q_scale_
.
item
(
group
,
n
),
b_q_scale_max
[
group
]);
nextgroup
+=
groupsize
;
qs_h2
=
__halves2half2
(
qs_h
,
qs_h
);
}
for
(
int
p
=
0
;
p
<
2
;
p
++
)
{
half2
dq
[
8
];
uint32_t
q_0
=
*
b_ptr
;
b_ptr
+=
size_n
;
dequant_2bit_16
(
q_0
,
dq
,
size_n
);
for
(
int
j
=
0
;
j
<
8
;
j
++
)
dq
[
j
]
=
__hmul2
(
dq
[
j
],
qs_h2
);
half
*
dqh
=
(
half
*
)
dq
;
for
(
int
j
=
0
;
j
<
16
;
j
++
)
b_
.
set
(
perm
[
lk
++
],
n
,
dqh
[
j
]);
}
k
+=
32
;
}
}
void
QMatrix
::
reconstruct
(
half
*
out
)
{
dim3
blockDim
,
gridDim
;
blockDim
.
x
=
BLOCK_KN_SIZE
;
blockDim
.
y
=
1
;
gridDim
.
y
=
DIVIDE
(
height
,
BLOCK_KN_SIZE
);
if
(
!
is_gptq
)
{
gridDim
.
x
=
DIVIDE
(
width
,
BLOCK_KN_SIZE
);
reconstruct_kernel
<<<
gridDim
,
blockDim
>>>
(
cuda_q_weight
,
cuda_q_perm
,
cuda_q_scale
,
cuda_q_scale_max
,
//cuda_q_groups,
height
,
width
,
groupsize
,
groups
,
out
,
rows_8
,
rows_6
,
rows_5
,
rows_4
,
rows_3
,
rows_2
);
}
else
{
gridDim
.
x
=
DIVIDE
(
width
,
BLOCK_KN_SIZE
*
4
);
reconstruct_gptq_kernel
<<<
gridDim
,
blockDim
>>>
(
cuda_q_weight
,
cuda_q_perm
,
cuda_gptq_qzeros
,
cuda_gptq_scales
,
//const uint16_t* __restrict__ b_q_groups,
height
,
width
,
groupsize
,
groups
,
out
,
rows_4
);
}
}
__global__
void
make_sequential_kernel
(
const
uint32_t
*
__restrict__
w
,
uint32_t
*
__restrict__
w_new
,
const
uint16_t
*
__restrict__
q_perm
,
const
int
w_height
,
const
int
w_width
)
{
const
uint64_t
*
w2
=
(
uint64_t
*
)
w
;
uint64_t
*
w_new2
=
(
uint64_t
*
)
w_new
;
int
w2_stride
=
w_width
>>
1
;
int
w2_column
=
THREADS_X
*
blockIdx
.
x
+
threadIdx
.
x
;
if
(
w2_column
>=
w2_stride
)
return
;
int
w_new2_row
=
blockIdx
.
y
;
int
q_perm_idx
=
w_new2_row
<<
3
;
uint64_t
dst
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
int
source_row
=
q_perm
[
q_perm_idx
++
];
int
w2_row
=
source_row
>>
3
;
int
w2_subrow
=
source_row
&
0x07
;
int
w2_row_shift
=
w2_subrow
<<
2
;
int
wnew2_row_shift
=
i
<<
2
;
uint64_t
src
=
w2
[
w2_row
*
w2_stride
+
w2_column
];
src
>>=
w2_row_shift
;
src
&=
0x0000000f0000000f
;
src
<<=
wnew2_row_shift
;
dst
|=
src
;
}
w_new2
[
w_new2_row
*
w2_stride
+
w2_column
]
=
dst
;
}
bool
QMatrix
::
make_sequential
(
const
uint32_t
*
cpu_g_idx
)
{
uint32_t
*
cuda_new_qweight
=
NULL
;
cudaError_t
err
=
cudaMalloc
(
&
cuda_new_qweight
,
height
/
8
*
width
*
sizeof
(
uint32_t
));
if
(
err
!=
cudaSuccess
)
{
cudaError_t
cuda_status
=
cudaGetLastError
();
// Clear error
return
false
;
}
uint32_t
*
cpu_g_idx_map
=
(
uint32_t
*
)
calloc
(
groups
,
sizeof
(
uint32_t
));
uint32_t
*
cpu_x_map
=
(
uint32_t
*
)
malloc
(
height
*
sizeof
(
uint32_t
));
uint32_t
*
cpu_x_map_inv
=
(
uint32_t
*
)
malloc
(
height
*
sizeof
(
uint32_t
));
// Group histogram
for
(
int
i
=
0
;
i
<
height
;
i
++
)
cpu_g_idx_map
[
cpu_g_idx
[
i
]]
++
;
// Group map
for
(
int
i
=
0
,
acc
=
0
;
i
<
groups
;
i
++
)
{
short
tmp
=
cpu_g_idx_map
[
i
];
cpu_g_idx_map
[
i
]
=
acc
;
acc
+=
tmp
;
}
// X map (inverse)
for
(
int
row
=
0
;
row
<
height
;
row
++
)
{
uint32_t
target_group
=
cpu_g_idx
[
row
];
uint32_t
target_row
=
cpu_g_idx_map
[
target_group
];
cpu_g_idx_map
[
target_group
]
++
;
cpu_x_map_inv
[
row
]
=
target_row
;
}
// X map
for
(
int
row
=
0
;
row
<
height
;
row
++
)
cpu_x_map
[
cpu_x_map_inv
[
row
]]
=
row
;
// Reduce to uint16_t
uint16_t
*
cpu_x_map16
=
(
uint16_t
*
)
cpu_x_map
;
uint16_t
*
cpu_x_map_inv16
=
(
uint16_t
*
)
cpu_x_map_inv
;
for
(
int
row
=
0
;
row
<
height
;
row
++
)
cpu_x_map16
[
row
]
=
(
uint16_t
)
cpu_x_map
[
row
];
for
(
int
row
=
0
;
row
<
height
;
row
++
)
cpu_x_map_inv16
[
row
]
=
(
uint16_t
)
cpu_x_map_inv
[
row
];
// Move to CUDA
cudaMemcpyAsync
(
cuda_q_perm
,
cpu_x_map16
,
height
*
sizeof
(
uint16_t
),
cudaMemcpyHostToDevice
);
cudaMemcpyAsync
(
cuda_q_invperm
,
cpu_x_map_inv16
,
height
*
sizeof
(
uint16_t
),
cudaMemcpyHostToDevice
);
// Rearrange rows in w
dim3
blockDim
,
gridDim
;
blockDim
.
x
=
THREADS_X
;
blockDim
.
y
=
1
;
gridDim
.
x
=
DIVIDE
(
width
,
THREADS_X
);
gridDim
.
y
=
height
/
8
;
make_sequential_kernel
<<<
gridDim
,
blockDim
>>>
(
cuda_q_weight
,
cuda_new_qweight
,
cuda_q_perm
,
height
/
8
,
width
);
// Replace qweights
cudaMemcpyAsync
(
cuda_q_weight
,
cuda_new_qweight
,
height
/
8
*
width
*
sizeof
(
uint32_t
),
cudaMemcpyDeviceToDevice
);
// Cleanup
cudaDeviceSynchronize
();
cudaFree
(
cuda_new_qweight
);
free
(
cpu_g_idx_map
);
free
(
cpu_x_map
);
free
(
cpu_x_map_inv
);
return
true
;
}
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/q_matrix.cuh  (new file, mode 100644)

#ifndef _q_matrix_cuh
#define _q_matrix_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

#define MAX_SUPERGROUPS 16

class QMatrix
{
public:

    int device;
    bool is_gptq;

    int height;
    int width;
    int groups;
    int groupsize;

    int rows_8;
    int rows_6;
    int rows_5;
    int rows_4;
    int rows_3;
    int rows_2;

    uint32_t* cuda_q_weight = NULL;
    uint16_t* cuda_q_perm = NULL;
    uint16_t* cuda_q_invperm = NULL;
    uint32_t* cuda_q_scale = NULL;
    half* cuda_q_scale_max = NULL;
    uint16_t* cuda_q_groups = NULL;
    uint32_t* cuda_gptq_qzeros = NULL;
    half* cuda_gptq_scales = NULL;

    half* temp_dq;

    bool failed;

    QMatrix
    (
        const int _device,
        const int _height,
        const int _width,
        const int _groups,
        uint32_t* _q_weight,
        uint16_t* _q_perm,
        uint16_t* _q_invperm,
        uint32_t* _q_scale,
        half* _q_scale_max,
        uint16_t* _q_groups,
        uint32_t* _gptq_qzeros,
        half* _gptq_scales,
        uint32_t* _gptq_g_idx,
        half* _temp_dq
    );

    ~QMatrix();

    void reconstruct(half* out);
    bool make_sequential(const uint32_t* cpu_g_idx);

private:
};

#endif
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/quant/qdq_2.cuh  (new file, mode 100644)

#ifndef _qdq_2_cuh
#define _qdq_2_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_2BIT == 1

// Permutation:
//
// ffddbb99 77553311 eeccaa88 66442200

__forceinline__ __device__ void shuffle_2bit_16
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0];
    uint32_t qb = 0;

    #pragma unroll
    for (int i = 0; i < 8; i++)
    {
        uint32_t qa0 = qa & 0x03;
        uint32_t qa1 = (qa & 0x0c) >> 2;
        qa >>= 4;
        qb |= (qa1 << (i * 2 + 16));
        qb |= (qa0 << (i * 2));
    }
    q[0] = qb;
}

__forceinline__ __device__ void dequant_2bit_16
(
    const uint32_t q_0,
    half2 (&dq)[8],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y4_  = __float2half_rn(1.0f /  4.0f);
    const half y16_ = __float2half_rn(1.0f / 16.0f);
    const half y64_ = __float2half_rn(1.0f / 64.0f);
    const half2 y4  = __halves2half2(y4_,  y4_);
    const half2 y16 = __halves2half2(y16_, y16_);
    const half2 y64 = __halves2half2(y64_, y64_);
    const half z1_  = __float2half_rn(-1024.0f         - 2.0f);
    const half z4_  = __float2half_rn(-1024.0f /  4.0f - 2.0f);
    const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f);
    const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f);
    const half2 z1  = __halves2half2(z1_,  z1_);
    const half2 z4  = __halves2half2(z4_,  z4_);
    const half2 z16 = __halves2half2(z16_, z16_);
    const half2 z64 = __halves2half2(z64_, z64_);

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) *  4 + 1024
    half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
    half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
    qa >>= 8;
    half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8])      + 1024
    half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) *  4 + 1024
    half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
    half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024

    dq[0] = __hadd2(q0.as_half2, z1);
    dq[1] = __hfma2(q1.as_half2, y4,  z4);
    dq[2] = __hfma2(q2.as_half2, y16, z16);
    dq[3] = __hfma2(q3.as_half2, y64, z64);
    dq[4] = __hadd2(q4.as_half2, z1);
    dq[5] = __hfma2(q5.as_half2, y4,  z4);
    dq[6] = __hfma2(q6.as_half2, y16, z16);
    dq[7] = __hfma2(q7.as_half2, y64, z64);
}

#else

__forceinline__ __device__ void shuffle_2bit_16
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_2bit_16
(
    const uint32_t q_0,
    half2 (&dq)[8],
    int stride
)
{
    half dqh[16];
    for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2);

    for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

#endif

#endif
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/quant/qdq_3.cuh  (new file, mode 100644)

#ifndef _qdq_3_cuh
#define _qdq_3_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_3BIT == 1

// Permutation:
//
// v9997775 55333111 u8886664 44222000 (u, v lsb)
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk

__forceinline__ __device__ void shuffle_3bit_32
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0 * stride];
    uint32_t qb = q[1 * stride];
    uint32_t qc = q[2 * stride];

    // qa: aa999888 77766655 54443332 22111000
    // qb: lkkkjjji iihhhggg fffeeedd dcccbbba
    // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll

    uint32_t qd = qc >> 26;
    qc <<= 4;
    qc |= qb >> 28;
    qb <<= 2;
    qb |= qa >> 30;

    // qa: ..999888 77766655 54443332 22111000
    // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
    // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
    // qd: vvvuuu

    uint32_t za = 0;
    uint32_t zb = 0;
    uint32_t zc = 0;

    for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
    for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
    for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }

    // za: 9997775 55333111 8886664 44222000
    // zb: jjjhhhf ffdddbbb iiiggge eecccaaa
    // zc: tttrrrp ppnnnlll sssqqqo oommmkkk
    // qd: vvvuuu

    za |= ((qd & 0x01) >> 0) << 15;
    zb |= ((qd & 0x02) >> 1) << 15;
    zc |= ((qd & 0x04) >> 2) << 15;
    za |= ((qd & 0x08) >> 3) << 31;
    zb |= ((qd & 0x10) >> 4) << 31;
    zc |= ((qd & 0x20) >> 5) << 31;

    // za: v9997775 55333111 u8886664 44222000 (u, v lsb)
    // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
    // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk

    q[0 * stride] = za;
    q[1 * stride] = zb;
    q[2 * stride] = zc;
}

__forceinline__ __device__ void dequant_3bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    half2 (&dq)[16],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y8_  = __float2half_rn(1.0f /  8.0f);
    const half y64_ = __float2half_rn(1.0f / 64.0f);
    const half2 y8  = __halves2half2(y8_,  y8_);
    const half2 y64 = __halves2half2(y64_, y64_);
    const half z1_  = __float2half_rn(-1024.0f         - 4.0f);
    const half z8_  = __float2half_rn(-1024.0f /  8.0f - 4.0f);
    const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f);
    const half2 z1  = __halves2half2(z1_,  z1_);
    const half2 z8  = __halves2half2(z8_,  z8_);
    const half2 z64 = __halves2half2(z64_, z64_);

    uint32_t qa = q_0;
    uint32_t qb = q_1;
    uint32_t qc = q_2;

    half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) *  8 + 1024
    qa >>= 6;
    half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5])      + 1024
    half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) *  8 + 1024
    half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
    qa >>= 9;
    qa &= 0x00010001;
    half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11])      + 1024
    half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) *  8 + 1024
    qb >>= 6;
    half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15])      + 1024
    half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) *  8 + 1024
    half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
    qb >>= 8;
    qb &= 0x00020002;
    half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21])      + 1024
    half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) *  8 + 1024
    qc >>= 6;
    half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25])      + 1024
    half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) *  8 + 1024
    half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
    qc >>= 7;
    qc &= 0x00040004;
    half2_uint32 q15((qa | qb | qc) | c0);

    dq[ 0] = __hadd2( q0.as_half2, z1);
    dq[ 1] = __hfma2( q1.as_half2, y8,  z8);
    dq[ 2] = __hadd2( q2.as_half2, z1);
    dq[ 3] = __hfma2( q3.as_half2, y8,  z8);
    dq[ 4] = __hfma2( q4.as_half2, y64, z64);
    dq[ 5] = __hadd2( q5.as_half2, z1);
    dq[ 6] = __hfma2( q6.as_half2, y8,  z8);
    dq[ 7] = __hadd2( q7.as_half2, z1);
    dq[ 8] = __hfma2( q8.as_half2, y8,  z8);
    dq[ 9] = __hfma2( q9.as_half2, y64, z64);
    dq[10] = __hadd2(q10.as_half2, z1);
    dq[11] = __hfma2(q11.as_half2, y8,  z8);
    dq[12] = __hadd2(q12.as_half2, z1);
    dq[13] = __hfma2(q13.as_half2, y8,  z8);
    dq[14] = __hfma2(q14.as_half2, y64, z64);
    dq[15] = __hadd2(q15.as_half2, z1);
}

#else

__forceinline__ __device__ void shuffle_3bit_32
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_3bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    half2 (&dq)[16],
    int stride
)
{
    half dqh[32];
    for (int i = 0; i < 10; i++) dqh[i] = dq_ns(exb(q_0, i * 3, 0x07), 4);
    dqh[10] = dq_ns(exb(q_1, q_0, 30, 0x07), 4);
    for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb(q_1, i * 3 + 1, 0x07), 4);
    dqh[21] = dq_ns(exb(q_2, q_1, 31, 0x07), 4);
    for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb(q_2, i * 3 + 2, 0x07), 4);

    for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

#endif

#endif
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/quant/qdq_4.cuh
0 → 100644
View file @
6a583c2f
#ifndef _qdq_4_cuh
#define _qdq_4_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_4BIT == 1
// Permutation:
//
// 77775555 33331111 66664444 22220000
__forceinline__
__device__
void
shuffle_4bit_8
(
uint32_t
*
q
,
int
stride
)
{
uint32_t
qa
=
q
[
0
];
uint32_t
qb
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
uint32_t
qa0
=
qa
&
0x0f
;
uint32_t
qa1
=
(
qa
&
0xf0
)
>>
4
;
qa
>>=
8
;
qb
|=
(
qa1
<<
(
i
*
4
+
16
));
qb
|=
(
qa0
<<
(
i
*
4
));
}
q
[
0
]
=
qb
;
}
__forceinline__
__device__
void
dequant_4bit_8
(
const
uint32_t
q_0
,
half2
(
&
dq
)[
4
],
int
stride
)
{
const
uint32_t
c0
=
0x64006400
;
const
half
y16_
=
__float2half_rn
(
1.0
f
/
16.0
f
);
const
half2
y16
=
__halves2half2
(
y16_
,
y16_
);
const
half
z1_
=
__float2half_rn
(
-
1024.0
f
-
8.0
f);
    const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
    const half2 z1  = __halves2half2(z1_,  z1_);
    const half2 z16 = __halves2half2(z16_, z16_);

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
    half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
    qa >>= 8;
    half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
    half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024

    dq[0] = __hadd2(q0.as_half2, z1);
    dq[1] = __hfma2(q1.as_half2, y16, z16);
    dq[2] = __hadd2(q2.as_half2, z1);
    dq[3] = __hfma2(q3.as_half2, y16, z16);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
    const uint32_t zero,
    const half scale,
    half2 (&z1z16)[2],
    half2 (&y1y16)[2]
)
{
    half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
    half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

    half2 scale2 = __half2half2(scale);

    z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
    z1z16[1] = __hmul2(scale2, __half2half2(z16));

    const half y1 = __float2half_rn(1.0f);
    const half y16 = __float2half_rn(1.0f / 16.0f);

    y1y16[0] = __hmul2(scale2, __half2half2(y1));
    y1y16[1] = __hmul2(scale2, __half2half2(y16));
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
    const uint32_t zero,
    half2 (&z1z16)[2],
    half2 (&y1y16)[2]
)
{
    half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
    half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

    z1z16[0] = __half2half2(z1.as_half);
    z1z16[1] = __half2half2(z16);

    const half y1 = __float2half_rn(1.0f);
    const half y16 = __float2half_rn(1.0f / 16.0f);

    y1y16[0] = __half2half2(y1);
    y1y16[1] = __half2half2(y16);
}

__forceinline__ __device__ void dequant_4bit_8_gptq
(
    const uint32_t q_0,
    half2 (&dq)[4],
    half2 (&z1z16)[2],
    half2 (&y1y16)[2],
    int stride,
    bool scaled
)
{
    const uint32_t c0 = 0x64006400;

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
    half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
    qa >>= 8;
    half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
    half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )

    if (scaled)
    {
        dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]);  // half2( q[0] * s - z * s, q[1] * s - z * s)
        dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]);  // half2( q[2] * s - z * s, q[3] * s - z * s)
        dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
        dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
    }
    else
    {
        dq[0] = __hadd2(q0.as_half2,           z1z16[0]);  // half2( q[0] - z, q[1] - z )
        dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]);  // half2( q[2] - z, q[3] - z )
        dq[2] = __hadd2(q2.as_half2,           z1z16[0]);  // half2( q[4] - z, q[5] - z )
        dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);  // half2( q[6] - z, q[7] - z )
    }
}

#else

__forceinline__ __device__ void shuffle_4bit_8
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_4bit_8
(
    const uint32_t q_0,
    half2 (&dq)[4],
    int stride
)
{
    half dqh[8];
    for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);

    for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
    const uint32_t zero,
    const half scale,
    half2 (&z1)[2],
    half2 (&y1)[2]
)
{
    half z = __int2half_rn(-((int)zero));
    z = __hmul(z, scale);
    z1[0] = __half2half2(z);
    y1[0] = __half2half2(scale);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
    const uint32_t zero,
    half2 (&z1)[2],
    half2 (&y1)[2]
)
{
    half z = __int2half_rn(-((int)zero));
    z1[0] = __half2half2(z);
}

__forceinline__ __device__ void dequant_4bit_8_gptq
(
    const uint32_t q_0,
    half2 (&dq)[4],
    half2 (&z1)[2],
    half2 (&y1)[2],
    int stride,
    bool scaled
)
{
    half2 dqh2[8];

    uint32_t qa = q_0;
    for (int i = 0; i < 4; i++)
    {
        half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
        half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
        dqh2[i] = __halves2half2(d0, d1);
    }

    if (scaled)
    {
        dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
        dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
        dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
        dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
    }
    else
    {
        dq[0] = __hadd2(dqh2[0], z1[0]);
        dq[1] = __hadd2(dqh2[1], z1[0]);
        dq[2] = __hadd2(dqh2[2], z1[0]);
        dq[3] = __hadd2(dqh2[3], z1[0]);
    }
}

#endif

#endif
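Note on the constant folding used above: the dequant routines OR each 4-bit value into the fp16 bit pattern 0x6400 (which is 1024.0 in fp16), so the low nibble becomes 1024 + q and the shifted nibble becomes 1024 + 16*q; the z1/z16/y16 constants then remove the 1024 bias and the "-8" zero point in a single add or fma. The following is a minimal, host-side sanity check of that identity in plain float arithmetic; it is an illustration only and not part of the extension (names such as exb_ref below are not from the repo).

#include <cassert>

int main()
{
    const float z1  = -1024.0f - 8.0f;          // same value as z1_ above
    const float z16 = -1024.0f / 16.0f - 8.0f;  // same value as z16_ above
    const float y16 = 1.0f / 16.0f;             // same value as y16_ above

    for (int q = 0; q < 16; q++)
    {
        float lo = (1024.0f + q)              + z16 * 0.0f + z1;  // low-nibble path (__hadd2)
        float hi = (1024.0f + 16.0f * q) * y16 + z16;             // high-nibble path (__hfma2)
        assert(lo == float(q - 8));
        assert(hi == float(q - 8));
    }
    return 0;
}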
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/quant/qdq_5.cuh 0 → 100644

#ifndef _qdq_5_cuh
#define _qdq_5_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_5BIT == 1

// Permutation:
//
// v5555533 33311111  u4444422 22200000  (u, v lsb)
// vbbbbb99 99977777  uaaaaa88 88866666
// vhhhhhff fffddddd  ugggggee eeeccccc
// vnnnnnll llljjjjj  ummmmmkk kkkiiiii
// vtttttrr rrrppppp  usssssqq qqqooooo

__forceinline__ __device__ void shuffle_5bit_32
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0 * stride];
    uint32_t qb = q[1 * stride];
    uint32_t qc = q[2 * stride];
    uint32_t qd = q[3 * stride];
    uint32_t qe = q[4 * stride];

    // qa: 66555554 44443333 32222211 11100000
    // qb: ccccbbbb baaaaa99 99988888 77777666
    // qc: jiiiiihh hhhggggg fffffeee eedddddc
    // qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj
    // qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp

    uint32_t qf = qe >> 22;
    qe <<= 8;
    qe |= qd >> 24;
    qd <<= 6;
    qd |= qc >> 26;
    qc <<= 4;
    qc |= qb >> 28;
    qb <<= 2;
    qb |= qa >> 30;

    // qa:   555554 44443333 32222211 11100000
    // qb:   bbbbba aaaa9999 98888877 77766666
    // qc:   hhhhhg ggggffff feeeeedd dddccccc
    // qd:   nnnnnm mmmmllll lkkkkkjj jjjiiiii
    // qe:   ttttts ssssrrrr rqqqqqpp pppooooo
    // qf:                         vv vvvuuuuu

    uint32_t za = 0;
    uint32_t zb = 0;
    uint32_t zc = 0;
    uint32_t zd = 0;
    uint32_t ze = 0;

    for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); }
    for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); }
    for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); }
    for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); }
    for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); }

    // za:  5555533 33311111   4444422 22200000
    // zb:  bbbbb99 99977777   aaaaa88 88866666
    // zc:  hhhhhff fffddddd   gggggee eeeccccc
    // zd:  nnnnnll llljjjjj   mmmmmkk kkkiiiii
    // ze:  tttttrr rrrppppp   sssssqq qqqooooo
    // qf:                          vv vvvuuuuu

    za |= ((qf & 0x001) >> 0) << 15;
    zb |= ((qf & 0x002) >> 1) << 15;
    zc |= ((qf & 0x004) >> 2) << 15;
    zd |= ((qf & 0x008) >> 3) << 15;
    ze |= ((qf & 0x010) >> 4) << 15;
    za |= ((qf & 0x020) >> 5) << 31;
    zb |= ((qf & 0x040) >> 6) << 31;
    zc |= ((qf & 0x080) >> 7) << 31;
    zd |= ((qf & 0x100) >> 8) << 31;
    ze |= ((qf & 0x200) >> 9) << 31;

    // za: v5555533 33311111  u4444422 22200000  (u, v lsb)
    // zb: vbbbbb99 99977777  uaaaaa88 88866666
    // zc: vhhhhhff fffddddd  ugggggee eeeccccc
    // zd: vnnnnnll llljjjjj  ummmmmkk kkkiiiii
    // ze: vtttttrr rrrppppp  usssssqq qqqooooo

    q[0 * stride] = za;
    q[1 * stride] = zb;
    q[2 * stride] = zc;
    q[3 * stride] = zd;
    q[4 * stride] = ze;
}

__forceinline__ __device__ void dequant_5bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    const uint32_t q_3,
    const uint32_t q_4,
    half2 (&dq)[16],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y32_ = __float2half_rn(1.0f / 32.0f);
    const half2 y32 = __halves2half2(y32_, y32_);
    const half z1_  = __float2half_rn(-1024.0f         - 16.0f);
    const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f);
    const half2 z1  = __halves2half2(z1_,  z1_);
    const half2 z32 = __halves2half2(z32_, z32_);

    uint32_t qa = q_0;
    uint32_t qb = q_1;
    uint32_t qc = q_2;
    uint32_t qd = q_3;
    uint32_t qe = q_4;

    half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024
    qa >>= 10;
    half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5])      + 1024
    qa >>= 5;
    qa &= 0x00010001;
    half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7])      + 1024
    half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024
    qb >>= 10;
    half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11])      + 1024
    qb >>= 4;
    qb &= 0x00020002;
    half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13])      + 1024
    half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024
    qc >>= 10;
    half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17])      + 1024
    qc >>= 3;
    qc &= 0x00040004;
    half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19])      + 1024
    half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024
    qd >>= 10;
    half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23])      + 1024
    qd >>= 2;
    qd &= 0x00080008;
    half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25])      + 1024
    half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024
    qe >>= 10;
    half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29])      + 1024
    qe >>= 1;
    qe &= 0x00100010;
    half2_uint32 q15((qa | qb | qc | qd | qe) | c0);

    dq[ 0] = __hadd2( q0.as_half2, z1);
    dq[ 1] = __hfma2( q1.as_half2, y32, z32);
    dq[ 2] = __hadd2( q2.as_half2, z1);
    dq[ 3] = __hadd2( q3.as_half2, z1);
    dq[ 4] = __hfma2( q4.as_half2, y32, z32);
    dq[ 5] = __hadd2( q5.as_half2, z1);
    dq[ 6] = __hadd2( q6.as_half2, z1);
    dq[ 7] = __hfma2( q7.as_half2, y32, z32);
    dq[ 8] = __hadd2( q8.as_half2, z1);
    dq[ 9] = __hadd2( q9.as_half2, z1);
    dq[10] = __hfma2(q10.as_half2, y32, z32);
    dq[11] = __hadd2(q11.as_half2, z1);
    dq[12] = __hadd2(q12.as_half2, z1);
    dq[13] = __hfma2(q13.as_half2, y32, z32);
    dq[14] = __hadd2(q14.as_half2, z1);
    dq[15] = __hadd2(q15.as_half2, z1);
}

#else

__forceinline__ __device__ void shuffle_5bit_32
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_5bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    const uint32_t q_3,
    const uint32_t q_4,
    half2 (&dq)[16],
    int stride
)
{
    half dqh[32];
    for (int i = 0; i < 6; i++) dqh[     i] = dq_ns(exb(     q_0, i * 5    , 0x1f), 16);
                                dqh[ 6    ] = dq_ns(exb(q_1, q_0,        30, 0x1f), 16);
    for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb(     q_1, i * 5 + 3, 0x1f), 16);
                                dqh[12    ] = dq_ns(exb(q_2, q_1,        28, 0x1f), 16);
    for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb(     q_2, i * 5 + 1, 0x1f), 16);
                                dqh[19    ] = dq_ns(exb(q_3, q_2,        31, 0x1f), 16);
    for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb(     q_3, i * 5 + 4, 0x1f), 16);
                                dqh[25    ] = dq_ns(exb(q_4, q_3,        29, 0x1f), 16);
    for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb(     q_4, i * 5 + 2, 0x1f), 16);

    for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

#endif

#endif
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/quant/qdq_6.cuh 0 → 100644

#ifndef _qdq_6_cuh
#define _qdq_6_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_6BIT == 1

// Not implemented

#else

__forceinline__ __device__ void shuffle_6bit_16
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_6bit_16
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    half2 (&dq)[8],
    int stride
)
{
    half dqh[16];
    for (int i = 0; i < 5; i++) dqh[     i] = dq_ns(exb(     q_0, i * 6    , 0x3f), 32);
                                dqh[ 5    ] = dq_ns(exb(q_1, q_0,        30, 0x3f), 32);
    for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb(     q_1, i * 6 + 4, 0x3f), 32);
                                dqh[10    ] = dq_ns(exb(q_2, q_1,        28, 0x3f), 32);
    for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb(     q_2, i * 6 + 2, 0x3f), 32);

    for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

#endif

#endif
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/quant/qdq_8.cuh 0 → 100644

#ifndef _qdq_8_cuh
#define _qdq_8_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_8BIT == 1

// Not implemented

#else

__forceinline__ __device__ void shuffle_8bit_4
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_8bit_8
(
    const uint32_t q_0,
    const uint32_t q_1,
    half2 (&dq)[4],
    int stride
)
{
    half dqh[8];
    for (int i = 0; i < 4; i++) dqh[i    ] = dq_ns(exb(q_0, i * 8, 0xff), 128);
    for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128);

    for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

#endif

#endif
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/quant/qdq_util.cuh 0 → 100644

#ifndef _qdq_util_cuh
#define _qdq_util_cuh

union half2_uint32
{
    uint32_t as_uint32;
    half2 as_half2;
    __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
    __device__ half2_uint32(half2 val) : as_half2(val) {}
};

union half_uint16
{
    uint16_t as_uint16;
    half as_half;
    __device__ half_uint16(uint16_t val) : as_uint16(val) {}
    __device__ half_uint16(half val) : as_half(val) {}
};

// Max_scale premultiplied by 1/256

__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
{
    int qs_i = qs + 1;
    half qs_h = __int2half_rn(qs_i * qs_i);
    qs_h = __hmul(qs_h, max_scale);
    return qs_h;
}

__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
{
    return __hmul(__int2half_rn(q - qzero), scale);
}

__forceinline__ __device__ half dq_ns(const int q, const int qzero)
{
    //return __hsub(__int2half_rn(q), __int2half_rn(qzero));
    return __int2half_rn(q - qzero);
}

__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
{
    return (int)((q >> shift) & mask);
}

__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
{
    return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}

#endif
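The two-word overload of exb() above uses __funnelshift_rc(q0, q1, shift), which returns the low 32 bits of the 64-bit value (q1:q0) shifted right by shift, so a packed field that straddles a word boundary is recovered in one instruction. As a rough host-side illustration (an assumption-laden sketch, not code from this repo: exb_ref and the test values are made up here), the same result can be reproduced with an ordinary 64-bit shift:

#include <cstdint>
#include <cassert>

// Reference for the two-word extraction: take bits [shift, shift+width) of (q1:q0).
static int exb_ref(uint32_t q1, uint32_t q0, int shift, int mask)
{
    uint64_t wide = (uint64_t(q1) << 32) | q0;
    return int((wide >> shift) & uint64_t(mask));
}

int main()
{
    // A 5-bit field 0b10110 split as two bits at the top of q0 and three bits at the bottom of q1.
    uint32_t q0 = 0x2u << 30;   // low part of the field: 0b10 in bits 31..30
    uint32_t q1 = 0x5u;         // high part of the field: 0b101 in bits 2..0
    assert(exb_ref(q1, q0, 30, 0x1f) == 0x16); // 0b10110
    return 0;
}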
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/cuda/util.cuh 0 → 100644

#define DIVIDE(x, size) (((x) + (size) - 1) / (size))

#define DBGS(__x) printf("%s\n", __x)
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)

#define DBGX(__x) printf("%s: %x\n", #__x, __x)
#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y)
#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z)

#define DBGF(__x) printf("%s: %f\n", #__x, __x)
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)

#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x))
#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y))
#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z))

#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y))
#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z))

__forceinline__ __device__ half dq_scale_(const int qs, const half max_scale)
{
    half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f));
    qs_h = __hmul(qs_h, qs_h);
    qs_h = __hmul(qs_h, max_scale);
    return qs_h;
}

__forceinline__ __device__ float clamp(float x, float a, float b)
{
    return fmaxf(a, fminf(b, x));
}

#define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); }
inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true)
{
    if (code != cudaSuccess)
    {
        fprintf(stderr, "CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line);
        if (abort) exit(code);
    }
}
3rd_party/AutoGPTQ/autogptq_extension/exllamav2/ext.cpp 0 → 100644

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

#include "config.h"

#include "cuda/q_matrix.cuh"
#include "cuda/q_gemm.cuh"

#include "cpp/util.h"

// Some decluttering macros

#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")


// Quant matrix

uintptr_t make_q_matrix
(
    torch::Tensor q_weight,
    torch::Tensor q_perm,
    torch::Tensor q_invperm,
    torch::Tensor q_scale,
    torch::Tensor q_scale_max,
    torch::Tensor q_groups,
    torch::Tensor gptq_qzeros,
    torch::Tensor gptq_scales,
    torch::Tensor gptq_g_idx,
    torch::Tensor temp_dq
)
{
    TORCH_CHECK_DTYPE(q_weight, kInt);
    TORCH_CHECK_DTYPE_OPT(q_perm, kShort);
    TORCH_CHECK_DTYPE_OPT(q_invperm, kShort);
    TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
    TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
    TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
    TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
    TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
    TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);

    TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1);

    int device = q_weight.device().index();
    int width = q_weight.size(1);
    int groups;
    int height;

    if (!q_scale.device().is_meta())
    {
        TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8);
        TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1);
        groups = q_scale.size(0);
        height = q_invperm.size(0);
    }
    else
    {
        TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8);
        TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1);
        groups = gptq_qzeros.size(0);
        height = q_weight.size(0) * 8;
    }

    TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer")

    QMatrix* m = new QMatrix
    (
        device,
        height,
        width,
        groups,
        (uint32_t*) q_weight.data_ptr(),
        q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(),
        q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(),
        q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
        q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
        q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
        gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
        gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
        gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
        (half*) temp_dq.data_ptr()
    );

    return reinterpret_cast<uintptr_t> (m);
}

void gemm_half_q_half
(
    torch::Tensor a,
    uintptr_t b,
    torch::Tensor c,
    bool force_cuda
)
{
    QMatrix* qm = reinterpret_cast<QMatrix*> (b);

    TORCH_CHECK_DTYPE(a, kHalf);
    TORCH_CHECK_DTYPE(c, kHalf);
    TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
    TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes")
    TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes")

    const at::cuda::OptionalCUDAGuard device_guard(device_of(a));

    gemm_half_q_half_cuda
    (
        at::cuda::getCurrentCUDABlasHandle(),
        (const half*) a.data_ptr(),
        qm,
        (half*) c.data_ptr(),
        c.size(0), // m
        c.size(1), // n
        a.size(1), // k
        true,
        NULL,
        force_cuda
    );
}

// Bindings

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
    m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");
}
3rd_party/AutoGPTQ/autogptq_extension/marlin/marlin_cuda.cpp 0 → 100644

/*
 * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>
#include <torch/python.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include "marlin_cuda_kernel.cuh"
#include "marlin_repack.cuh"

const int ERR_PROB_SHAPE = 1;
const int ERR_KERN_SHAPE = 2;

void mul(
  const torch::Tensor& A,
  const torch::Tensor& B,
        torch::Tensor& C,
  const torch::Tensor& s,
        torch::Tensor& workspace,
  int thread_k = -1,
  int thread_n = -1,
  int sms = -1,
  int max_par = 8
) {
  int prob_m = A.size(0);
  int prob_n = C.size(1);
  int prob_k = A.size(1);
  int groupsize = (s.size(0) == 1) ? -1 : prob_k / s.size(0);
  if (groupsize != -1 && groupsize * s.size(0) != prob_k)
    AT_ERROR("k=", prob_k, " not compatible with ", s.size(0), " groups.");
  if (workspace.numel() < prob_n / 128 * max_par)
    AT_ERROR("workspace must be of size at least ", prob_n / 128 * max_par, ".");
  int dev = A.get_device();
  int err = marlin_cuda(
    A.data_ptr(),
    B.data_ptr(),
    C.data_ptr(),
    s.data_ptr(),
    prob_m, prob_n, prob_k,
    workspace.data_ptr(),
    groupsize,
    dev,
    at::cuda::getCurrentCUDAStream(dev),
    thread_k,
    thread_n,
    sms,
    max_par
  );
  if (err == ERR_PROB_SHAPE) {
    AT_ERROR(
      "Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")",
      " not compatible with thread_k=", thread_k, ", thread_n=", thread_n, "."
    );
  } else if (err == ERR_KERN_SHAPE) {
    AT_ERROR(
      "No kernel implementation for thread_k=", thread_k, ", thread_n=", thread_n, ", groupsize=", groupsize, "."
    );
  }
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("mul", &mul, "Marlin FP16xINT4 matmul.");
  m.def("gptq_repack", &gptq_repack, "Repack GPTQ checkpoints for Marlin.");
}
3rd_party/AutoGPTQ/autogptq_extension/marlin/marlin_cuda_kernel.cu 0 → 100644

/*
 * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MARLIN_CUDA_KERNEL_CUH
#define MARLIN_CUDA_KERNEL_CUH

#include <cuda.h>
#include <cuda_fp16.h>
#include <assert.h>
#include <iostream>

#include "marlin_cuda_kernel.cuh"

constexpr int ceildiv(int a, int b) {
  return (a + b - 1) / b;
}

// Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core
// operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee this.
template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) {
    return elems[i];
  }
};

using I4 = Vec<int, 4>;

// Matrix fragments for tensor core instructions; their precise layout is documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales

// Predicated asynchronous global->shared copy; used for inputs A where we apply predication to handle batchsizes that
// are not multiples of 16.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
    "{\n"
    "   .reg .pred p;\n"
    "   setp.ne.b32 p, %0, 0;\n"
    "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
    "}\n" :: "r"((int) pred), "r"(smem), "l"(glob_ptr), "n"(BYTES)
  );
#else
  assert(0);
#endif
}

// Asynchronous global->shared copy with a chache hint indicating that the values may be evicted immediately; used for
// quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need
// for inputs A and outputs C.
__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
    "{\n"
    "   .reg .b64 p;\n"
    "   createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
    "   cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
    "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)
  );
#else
  assert(0);
#endif
}

// Async copy fence.
__device__ inline void cp_async_fence() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  asm volatile("cp.async.commit_group;\n" ::);
#else
  assert(0);
#endif
}

// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  asm volatile("cp.async.wait_group %0;\n" :: "n"(n));
#else
  assert(0);
#endif
}

// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation.
__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  float* c = reinterpret_cast<float*>(&frag_c);
  asm volatile(
    "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
    "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
    : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
    :  "r"(a[0]),  "r"(a[1]),  "r"(a[2]),  "r"(a[3]),  "r"(b[0]),  "r"(b[1]),
       "f"(c[0]),  "f"(c[1]),  "f"(c[2]),  "f"(c[3])
  );
#else
  assert(0);
#endif
}

// Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
    "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
    : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem)
  );
#else
  assert(0);
#endif
}

// Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to
// automatically recognize it in all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile(
    "lop3.b32 %0, %1, %2, %3, %4;\n"
    : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)
  );
  return res;
}

// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values.
// We mostly follow the strategy in the link below, with some small changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant(int q) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`.
  const int SUB = 0x64086408;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd480d480;
  FragB frag_b;
  frag_b[0] = __hsub2(
    *reinterpret_cast<half2*>(&lo),
    *reinterpret_cast<const half2*>(&SUB)
  );
  frag_b[1] = __hfma2(
    *reinterpret_cast<half2*>(&hi),
    *reinterpret_cast<const half2*>(&MUL),
    *reinterpret_cast<const half2*>(&ADD)
  );
  return frag_b;
}

// Multiply dequantized values by the corresponding quantization scale; used only for grouped quantization.
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
  half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
  frag_b[0] = __hmul2(frag_b[0], s);
  frag_b[1] = __hmul2(frag_b[1], s);
}

// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  if (threadIdx.x == 0) {
    int state = -1;
    do
      // Guarantee that subsequent writes by this threadblock will be visible globally.
      asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
    while (state != count);
  }
  __syncthreads();
#else
  assert(0);
#endif
}

// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      lock[0] = 0;
      return;
    }
    int val = 1;
    // Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier.
    asm volatile ("fence.acq_rel.gpu;\n");
    asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val));
  }
#else
  assert(0);
#endif
}

template <
  const int threads, // number of threads in a threadblock
  const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock
  const int thread_n_blocks, // same for n dimension (output)
  const int thread_k_blocks, // same for k dimension (reduction)
  const int stages, // number of stages for the async global->shared fetch pipeline
  const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale
>
__global__ void Marlin(
  const int4* __restrict__ A, // fp16 input matrix of shape mxk
  const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
        int4* __restrict__ C, // fp16 output buffer of shape mxn
  const int4* __restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
  int  prob_m, // batch dimension m
  int  prob_n, // output dimension n
  int  prob_k, // reduction dimension k
  int* locks // extra global storage for barrier synchronization
) {
  // Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple
  // column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM example:
  //   0 1 3
  //   0 2 3
  //   1 2 4
  // While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs
  // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as
  // possible.

  // For larger GEMMs we run multiple batchsize 64 versions in parallel for a better partitioning with less reductions
  int parallel = 1;
  if (prob_m > 16 * thread_m_blocks) {
    parallel = prob_m / (16 * thread_m_blocks);
    prob_m = 16 * thread_m_blocks;
  }

  int k_tiles = prob_k / 16 / thread_k_blocks;
  int n_tiles = prob_n / 16 / thread_n_blocks;
  int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
  // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case
  // where a stripe starts in the middle of group.
  if (group_blocks != -1)
    iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks));

  int slice_row = (iters * blockIdx.x) % k_tiles;
  int slice_col_par = (iters * blockIdx.x) / k_tiles;
  int slice_col = slice_col_par;
  int slice_iters; // number of threadblock tiles in the current slice
  int slice_count = 0; // total number of active threadblocks in the current slice
  int slice_idx; // index of threadblock in current slice; numbered bottom to top

  // We can easily implement parallel problem execution by just remapping indices and advancing global pointers
  if (slice_col_par >= n_tiles) {
    A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
    C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
    locks += (slice_col_par / n_tiles) * n_tiles;
    slice_col = slice_col_par % n_tiles;
  }

  // Compute all information about the current slice which is required for synchronization.
  auto init_slice = [&] () {
    slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
    if (slice_iters < 0 || slice_col_par >= n_tiles * parallel)
      slice_iters = 0;
    if (slice_iters == 0)
      return;
    if (slice_row + slice_iters > k_tiles)
      slice_iters = k_tiles - slice_row;
    slice_count = 1;
    slice_idx = 0;
    int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
    if (col_first <= k_tiles * (slice_col_par + 1)) {
      int col_off = col_first - k_tiles * slice_col_par;
      slice_count = ceildiv(k_tiles - col_off, iters);
      if (col_off > 0)
        slice_count++;
      int delta_first = iters * blockIdx.x - col_first;
      if (delta_first < 0 || (col_off == 0 && delta_first == 0))
        slice_idx = slice_count - 1;
      else {
        slice_idx = slice_count - 1 - delta_first / iters;
        if (col_off > 0)
          slice_idx--;
      }
    }
    if (slice_col == n_tiles) {
      A += 16 * thread_m_blocks * prob_k / 8;
      C += 16 * thread_m_blocks * prob_n / 8;
      locks += n_tiles;
      slice_col = 0;
    }
  };
  init_slice();

  int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
  // We typically use `constexpr` to indicate that this value is a compile-time constant
  constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory
  constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory
  int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
  constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes
  constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads
  constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile
  constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
  constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // number of shared write iterations for a tile

  int b_gl_stride = 16 * prob_n / 32;
  constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
  int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
  constexpr int b_sh_wr_delta = threads;
  constexpr int b_sh_rd_delta = threads;
  constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
  constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;

  int s_gl_stride = prob_n / 8;
  constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
  constexpr int s_sh_stage = s_sh_stride;
  int s_gl_rd_delta = s_gl_stride;

  // Global A read index of current thread.
  int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
  a_gl_rd += a_gl_rd_delta_o * slice_row;
  // Shared write index of current thread.
  int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
  // Shared read index.
  int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
  a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));

  int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
  b_gl_rd += b_sh_stride * slice_col;
  b_gl_rd += b_gl_rd_delta_o * slice_row;
  int b_sh_wr = threadIdx.x;
  int b_sh_rd = threadIdx.x;

  int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x;
  int s_sh_wr = threadIdx.x;
  int s_sh_rd;
  // We use a different scale layout for grouped and column-wise quantization as we scale a `half2` tile in column-major
  // layout in the former and in row-major in the latter case.
  if (group_blocks != -1)
    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;
  else
    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4;

  // Precompute which thread should not read memory in which iterations; this is needed if there are more threads than
  // required for a certain tilesize or when the batchsize is not a multiple of 16.
  bool a_sh_wr_pred[a_sh_wr_iters];
  #pragma unroll
  for (int i = 0; i < a_sh_wr_iters; i++)
    a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
  bool s_sh_wr_pred = threadIdx.x < s_sh_stride;

  // To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank
  // conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of
  // the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based
  // on NSight-Compute) that each warp must also write a consecutive memory segment?
  auto transform_a = [&] (int i) {
    int row = i / a_gl_rd_delta_o;
    return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
  };
  // Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory
  // accesses are static, we simply precompute both transformed reads and writes.
  int a_sh_wr_trans[a_sh_wr_iters];
  #pragma unroll
  for (int i = 0; i < a_sh_wr_iters; i++)
    a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
  int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
  #pragma unroll
  for (int i = 0; i < b_sh_wr_iters; i++) {
    #pragma unroll
    for (int j = 0; j < thread_m_blocks; j++)
      a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
  }

  // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependicies between
  // subsequent accesses with a tile by maintining multiple pointers (we have enough registers), a tiny optimization.
  const int4* B_ptr[b_sh_wr_iters];
  #pragma unroll
  for (int i = 0; i < b_sh_wr_iters; i++)
    B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;

  extern __shared__ int4 sh[];
  // Shared memory storage for global fetch pipelines.
  int4* sh_a = sh;
  int4* sh_b = sh_a + (stages * a_sh_stage);
  int4* sh_s = sh_b + (stages * b_sh_stage);
  // Register storage for double buffer of shared memory reads.
  FragA frag_a[2][thread_m_blocks];
  I4 frag_b_quant[2];
  FragC frag_c[thread_m_blocks][4][2];
  FragS frag_s[2][4];

  // Zero accumulators.
  auto zero_accums = [&] () {
    #pragma unroll
    for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
      reinterpret_cast<float*>(frag_c)[i] = 0;
  };

  // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location.
  auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) {
    if (pred) {
      int4* sh_a_stage = sh_a + a_sh_stage * pipe;
      #pragma unroll
      for (int i = 0; i < a_sh_wr_iters; i++) {
        cp_async4_pred(
          &sh_a_stage[a_sh_wr_trans[i]],
          &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
          a_sh_wr_pred[i]
        );
      }
      int4* sh_b_stage = sh_b + b_sh_stage * pipe;
      #pragma unroll
      for (int i = 0; i < b_sh_wr_iters; i++) {
        cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
        B_ptr[i] += b_gl_rd_delta_o;
      }
      // Only fetch scales if this tile starts a new group
      if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
        int4* sh_s_stage = sh_s + s_sh_stage * pipe;
        if (s_sh_wr_pred)
          cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
        s_gl_rd += s_gl_rd_delta;
      }
    }
    // Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point.
    cp_async_fence();
  };

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&] () {
    // We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when
    // it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten).
    cp_async_wait<stages - 2>();
    __syncthreads();
  };

  // Load the next sub-tile from the current location in the shared memory pipe into the current register buffer.
  auto fetch_to_registers = [&] (int k, int pipe) {
    // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a
    // significant bottleneck, while some theoretically better attempts have lead to bad instruction ordering by the
    // compiler and correspondingly a noticable drop in performance.
    if (group_blocks != -1) {
      int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));
      reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
    }
    int4* sh_a_stage = sh_a + a_sh_stage * pipe;
    #pragma unroll
    for (int i = 0; i < thread_m_blocks; i++)
      ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
    int4* sh_b_stage = sh_b + b_sh_stage * pipe;
    frag_b_quant[k % 2] = *reinterpret_cast<I4*>(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
  };

  // Execute the actual tensor core matmul of a sub-tile.
  auto matmul = [&] (int k) {
    // We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations.
    #pragma unroll
    for (int j = 0; j < 4; j++) {
      int b_quant = frag_b_quant[k % 2][j];
      int b_quant_shift = b_quant >> 8;
      FragB frag_b0 = dequant(b_quant);
      // If there are no groups, we can just scale the final output once and can avoid doing so for each weight.
      if (group_blocks != -1)
        scale(frag_b0, frag_s[k % 2][j], 0);
      FragB frag_b1 = dequant(b_quant_shift);
      if (group_blocks != -1)
        scale(frag_b1, frag_s[k % 2][j], 1);
      #pragma unroll
      for (int i = 0; i < thread_m_blocks; i++) {
        mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
        mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
      }
    }
  };

  // Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n
  // dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output
  // location; which we have to reduce over in the end. We do in shared memory.
  auto thread_block_reduce = [&] () {
    constexpr int red_off = threads / b_sh_stride / 2;
    if (red_off >= 1) {
      int red_idx = threadIdx.x / b_sh_stride;
      constexpr int red_sh_stride = b_sh_stride * 4 * 2;
      constexpr int red_sh_delta = b_sh_stride;
      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);

      // Parallel logarithmic shared memory reduction. We make sure to avoid any unnecessary read or write iterations,
      // e.g., for two warps we write only once by warp 1 and read only once by warp 0.
      #pragma unroll
      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
        #pragma unroll
        for (int i = red_off; i > 0; i /= 2) {
          if (i <= red_idx && red_idx < 2 * i) {
            #pragma unroll
            for (int j = 0; j < 4 * 2; j++) {
              int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
              if (i < red_off) {
                float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
                #pragma unroll
                for (int k = 0; k < 4; k++)
                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k];
              }
              sh[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
            }
          }
          __syncthreads();
        }
        if (red_idx == 0) {
          #pragma unroll
          for (int i = 0; i < 4 * 2; i++) {
            float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
            #pragma unroll
            for (int j = 0; j < 4; j++)
              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += c_rd[j];
          }
        }
        __syncthreads();
      }
    }
  };

  // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over
  // the results. As the striped partioning minimizes the number of such reductions and our outputs are usually rather
  // small, we perform this reduction serially in L2 cache.
  auto global_reduce = [&] (bool first = false, bool last = false) {
    // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step.
    // To do this, we write out results in FP16 (but still reduce with FP32 compute).
    constexpr int active_threads = 32 * thread_n_blocks / 4;
    if (threadIdx.x < active_threads) {
      int c_gl_stride = prob_n / 8;
      int c_gl_wr_delta_o = 8 * c_gl_stride;
      int c_gl_wr_delta_i = 4 * (active_threads / 32);
      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4;
      c_gl_wr += (2 * thread_n_blocks) * slice_col;
      constexpr int c_sh_wr_delta = active_threads;
      int c_sh_wr = threadIdx.x;

      int row = (threadIdx.x % 32) / 4;

      if (!first) {
        // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns,
        // hence we also use async-copies even though these fetches are not actually asynchronous.
        #pragma unroll
        for (int i = 0; i < thread_m_blocks * 4; i++) {
          cp_async4_pred(
            &sh[c_sh_wr + c_sh_wr_delta * i],
            &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],
            i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m
          );
        }
        cp_async_fence();
        cp_async_wait<0>();
      }

      #pragma unroll
      for (int i = 0; i < thread_m_blocks * 4; i++) {
        if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
          if (!first) {
            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
            #pragma unroll
            for (int j = 0; j < 2 * 4; j++) {
              reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float(
                reinterpret_cast<__half*>(&c_red)[j]
              );
            }
          }
          if (!last) {
            int4 c;
            #pragma unroll
            for (int j = 0; j < 2 * 4; j++) {
              reinterpret_cast<__half*>(&c)[j] = __float2half(
                reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]
              );
            }
            C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c;
          }
        }
      }
    }
  };

  // Write out the reduce final result in the correct layout. We only actually reshuffle matrix fragments in this step,
  // the reduction above is performed in fragment layout.
  auto write_result = [&] () {
    int c_gl_stride = prob_n / 8;
    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
    constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks));

    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
    c_gl_wr += (2 * thread_n_blocks) * slice_col;
    int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
    c_sh_wr += 32 * (threadIdx.x / 32);
    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));

    int c_gl_wr_end = c_gl_stride * prob_m;

    // We first reorder in shared memory to guarantee the most efficient final global write patterns
    auto write = [&] (int idx, float c0, float c1, FragS& s) {
      half2 res = __halves2half2(__float2half(c0), __float2half(c1));
      if (group_blocks == -1) // for per-column quantization we finally apply the scale here
        res = __hmul2(res, s[0]);
      ((half2*) sh)[idx] = res;
    };
    if (threadIdx.x / 32 < thread_n_blocks / 4) {
      #pragma unroll
      for (int i = 0; i < thread_m_blocks; i++) {
        #pragma unroll
        for (int j = 0; j < 4; j++) {
          int wr = c_sh_wr + 8 * j;
          write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
          write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
          write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
          write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
        }
        c_sh_wr += 16 * (4 * c_sh_stride);
      }
    }
    __syncthreads();

    #pragma unroll
    for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {
      if (c_gl_wr < c_gl_wr_end) {
        C[c_gl_wr] = sh[c_sh_rd];
        c_gl_wr += c_gl_wr_delta;
        c_sh_rd += c_sh_rd_delta;
      }
    }
  };

  // Start global fetch and register load pipelines.
  auto start_pipes = [&] () {
    #pragma unroll
    for (int i = 0; i < stages - 1; i++)
      fetch_to_shared(i, i, i < slice_iters);
    zero_accums();
    wait_for_stage();
    fetch_to_registers(0, 0);
    a_gl_rd += a_gl_rd_delta_o * (stages - 1);
  };
  start_pipes();

  // Main loop.
  while (slice_iters) {
    // We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are
    // static. Note that both pipelines have even length meaning that the next iteration will always start at index 0.
    #pragma unroll
    for (int pipe = 0; pipe < stages;) {
      #pragma unroll
      for (int k = 0; k < b_sh_wr_iters; k++) {
        fetch_to_registers(k + 1, pipe % stages);
        if (k == b_sh_wr_iters - 2) {
          fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages);
          pipe++;
          wait_for_stage();
        }
        matmul(k);
      }
      slice_iters--;
      if (slice_iters == 0)
        break;
    }
    a_gl_rd += a_gl_rd_delta_o * stages;

    // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most
    // readable, other ways of writing the loop seemed to noticeably worse performance after compliation.
    if (slice_iters == 0) {
      cp_async_wait<0>();
      bool last = slice_idx == slice_count - 1;
      // For per-column scales, we only fetch them here in the final step before write-out
      if (group_blocks == -1 && last) {
        if (s_sh_wr_pred)
          cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
        cp_async_fence();
      }
      thread_block_reduce();
      if (group_blocks == -1 && last) {
        cp_async_wait<0>();
        __syncthreads();
        if (threadIdx.x / 32 < thread_n_blocks / 4) {
          reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
          reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
        }
      }
      if (slice_count > 1) { // only globally reduce if there is more than one block in a slice
        barrier_acquire(&locks[slice_col], slice_idx);
        global_reduce(slice_idx == 0, last);
        barrier_release(&locks[slice_col], last);
      }
      if (last) // only the last block in a slice actually writes the result
        write_result();
      slice_row = 0;
      slice_col_par++;
      slice_col++;
      init_slice();
      if (slice_iters) {
        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
        #pragma unroll
        for (int i = 0; i < b_sh_wr_iters; i++)
          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
        if (slice_col == 0) {
          #pragma unroll
          for (int i = 0; i < b_sh_wr_iters; i++)
            B_ptr[i] -= b_gl_stride;
        }
        s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
        start_pipes();
      }
    }
  }
}

// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per schedule allows some more
// latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles.
const int THREADS = 256;
const int STAGES = 4; // 4 pipeline stages fit into shared memory
const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)

#define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
  else if ( \
    thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \
    group_blocks == GROUP_BLOCKS \
  ) { \
    cudaFuncSetAttribute( \
      Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
      cudaFuncAttributeMaxDynamicSharedMemorySize, \
      SHARED_MEM \
    ); \
    Marlin< \
      THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \
    ><<<blocks, THREADS, SHARED_MEM, stream>>>( \
      A_ptr, B_ptr, C_ptr, s_ptr, \
      prob_m, prob_n, prob_k, \
      locks \
    ); \
  }

const int ERR_PROB_SHAPE = 1;
const int ERR_KERN_SHAPE = 2;

int marlin_cuda(
  const void* A,
  const void* B,
        void* C,
        void* s,
  int prob_m,
  int prob_n,
  int prob_k,
  void* workspace,
  int groupsize = -1,
  int dev = 0,
  cudaStream_t stream = 0,
  int thread_k = -1,
  int thread_n = -1,
  int sms = -1,
  int max_par = 16
) {
  int tot_m = prob_m;
  int tot_m_blocks = ceildiv(tot_m, 16);
  int pad = 16 * tot_m_blocks - tot_m;

  if (sms == -1)
    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  if (thread_k == -1 || thread_n == -1) {
    if (prob_m <= 16) {
      // For small batchizes, better partioning is slightly more important than better compute utilization
      thread_k = 128;
      thread_n = 128;
    } else {
      thread_k = 64;
      thread_n = 256;
    }
  }

  int thread_k_blocks = thread_k / 16;
  int thread_n_blocks = thread_n / 16;
  int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
  int blocks = sms;

  if (prob_n % thread_n != 0 || prob_k % thread_k != 0 || (group_blocks != -1 && prob_k % group_blocks != 0))
    return ERR_PROB_SHAPE;
  if (prob_m == 0 || prob_n == 0 || prob_k == 0)
    return 0;

  const int4* A_ptr = (const int4*) A;
  const int4* B_ptr = (const int4*) B;
  int4* C_ptr = (int4*) C;
  const int4* s_ptr = (const int4*) s;

  int cols = prob_n / thread_n;
  int* locks = (int*) workspace;

  int ret = 0;
  for (int i = 0; i < tot_m_blocks; i += 4) {
    int thread_m_blocks = tot_m_blocks - i;
    prob_m = tot_m - 16 * i;
    int par = 1;
    if (thread_m_blocks > 4) {
      // Note that parallel > 1 currently only works for inputs without any padding
      par = (16 * thread_m_blocks - pad) / 64;
      if (par > max_par)
        par = max_par;
      prob_m = 64 * par;
      i += 4 * (par - 1);
      thread_m_blocks = 4;
    }

    // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of performance)
    // in our testing, however many more are, in principle, possible.
    if (false) {}
    CALL_IF(1,  8, 8, -1)
    CALL_IF(1,  8, 8,  8)
    CALL_IF(1, 16, 4, -1)
    CALL_IF(1, 16, 4,  8)
    CALL_IF(2, 16, 4, -1)
    CALL_IF(2, 16, 4,  8)
    CALL_IF(3, 16, 4, -1)
    CALL_IF(3, 16, 4,  8)
    CALL_IF(4, 16, 4, -1)
    CALL_IF(4, 16, 4,  8)
    else
      ret = ERR_KERN_SHAPE;

    A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
    C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
  }

  return ret;
}

#endif
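For orientation, the stripe/tile bookkeeping at the top of Marlin() and the dispatch defaults in marlin_cuda() can be replayed on the CPU to see how a given problem size is spread over the SMs. The snippet below is a standalone, host-only sketch that mirrors that index math; the sample sizes (prob_m=16, prob_n=prob_k=4096, groupsize=128, 108 SMs) are illustrative assumptions, not values taken from this commit.

#include <cstdio>

static constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

int main()
{
    int prob_m = 16, prob_n = 4096, prob_k = 4096, groupsize = 128, sms = 108;

    // Defaults picked in marlin_cuda() for small batch sizes
    int thread_k = (prob_m <= 16) ? 128 : 64;
    int thread_n = (prob_m <= 16) ? 128 : 256;

    int thread_k_blocks = thread_k / 16;
    int thread_n_blocks = thread_n / 16;
    int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;

    // Same per-threadblock quantities the kernel derives before the main loop
    int k_tiles = prob_k / 16 / thread_k_blocks;
    int n_tiles = prob_n / 16 / thread_n_blocks;
    int iters = ceildiv(k_tiles * n_tiles, sms);
    if (group_blocks != -1)
        iters = (group_blocks / thread_k_blocks) * ceildiv(iters, group_blocks / thread_k_blocks);

    printf("k_tiles=%d n_tiles=%d iters_per_block=%d\n", k_tiles, n_tiles, iters);
    return 0;
}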
3rd_party/AutoGPTQ/autogptq_extension/marlin/marlin_cuda_kernel.cuh 0 → 100644

#include <cuda.h>
#include <cuda_runtime.h>

int marlin_cuda(
  const void* A,
  const void* B,
        void* C,
        void* s,
  int prob_m,
  int prob_n,
  int prob_k,
  void* workspace,
  int groupsize,
  int dev,
  cudaStream_t stream,
  int thread_k,
  int thread_n,
  int sms,
  int max_par
);
3rd_party/AutoGPTQ/autogptq_extension/marlin/marlin_repack.cu
0 → 100644
View file @
6a583c2f
#include <cuda_runtime.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "marlin_repack.cuh"

__global__ void gptq_repack_kernel(uint32_t* in, uint32_t* out, int m, int n) {
  uint32_t row = blockIdx.x * 2;
  uint32_t col = blockIdx.y * 64;
  uint32_t t = threadIdx.x;

  // marlin packs 4 16x16 blocks one time;
  const int pad_len = 18;
  __shared__ uint8_t block[4][16][pad_len];

  // unpack
  int block_idx = t / 8;
  int block_offset = t % 8;
  for (int offset = block_offset; offset < 16; offset += 8) {
    uint32_t v1 = in[row * n + col + block_idx * 16 + offset];
    uint32_t v2 = in[(row + 1) * n + col + block_idx * 16 + offset];
#pragma unroll
    for (int i = 0; i < 8; i += 1) {
      block[block_idx][i][offset] = v1 & 0xf;
      v1 >>= 4;
      block[block_idx][i + 8][offset] = v2 & 0xf;
      v2 >>= 4;
    }
  }

  // repack
  // ref: _get_perms @ https://github.com/IST-DASLab/marlin/blob/master/marlin/__init__.py
  uint32_t srow = (t % 4) * 2;
  uint32_t scol = t / 4;
  uint32_t idx[8][2];
  idx[0][0] = srow;     idx[0][1] = scol;
  idx[1][0] = srow + 8; idx[1][1] = scol;
  idx[2][0] = srow;     idx[2][1] = scol + 8;
  idx[3][0] = srow + 8; idx[3][1] = scol + 8;
  idx[4][0] = srow + 1; idx[4][1] = scol;
  idx[5][0] = srow + 9; idx[5][1] = scol;
  idx[6][0] = srow + 1; idx[6][1] = scol + 8;
  idx[7][0] = srow + 9; idx[7][1] = scol + 8;

#pragma unroll
  for (int i = 0; i < 4; i += 1) {
    uint32_t v[8];
#pragma unroll
    for (int j = 0; j < 8; ++j) {
      v[j] = block[i][idx[j][0]][idx[j][1]];
    }
    uint32_t pack = (v[7] << 28) | (v[6] << 24) | (v[5] << 20) | (v[4] << 16) |
                    (v[3] << 12) | (v[2] << 8) | (v[1] << 4) | v[0];
    out[blockIdx.x * n * 2 + blockIdx.y * 128 + t * 4 + i] = pack;
  }
}

torch::Tensor gptq_repack(torch::Tensor W) {
  int m = W.sizes()[0];
  int n = W.sizes()[1];
  assert(W.is_contiguous());
  assert(W.dtype() == at::kInt);
  assert(m % 2 == 0);
  assert(n % 64 == 0);
  auto result = at::empty(
      {m / 2, n * 2}, at::TensorOptions().dtype(at::kInt).device(W.device()));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(W));

  const dim3 threads(32);
  // marlin packs 16 x 64 block and gptq packs 8 x 1
  const dim3 blocks(m / 2, n / 64);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  gptq_repack_kernel<<<blocks, threads, 0, stream>>>(
      (uint32_t*)W.data_ptr(), (uint32_t*)result.data_ptr(), m, n);
  return result;
}
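
For orientation, below is a minimal sketch of how the gptq_repack entry point above might be called from Python once the extension is built. The module name autogptq_marlin_cuda and the wrapper itself are assumptions for illustration only; they are not defined by this commit.

import torch

def repack_gptq_to_marlin(qweight: torch.Tensor) -> torch.Tensor:
    # Hypothetical wrapper: assumes the compiled extension exposing gptq_repack
    # is importable as autogptq_marlin_cuda (the name is an assumption, not from this diff).
    import autogptq_marlin_cuda
    # gptq_repack expects a contiguous int32 GPTQ weight of shape (m, n) on the GPU,
    # with m % 2 == 0 and n % 64 == 0, and returns an int32 tensor of shape (m // 2, n * 2).
    assert qweight.is_cuda and qweight.dtype == torch.int32 and qweight.is_contiguous()
    assert qweight.shape[0] % 2 == 0 and qweight.shape[1] % 64 == 0
    return autogptq_marlin_cuda.gptq_repack(qweight)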

3rd_party/AutoGPTQ/autogptq_extension/marlin/marlin_repack.cuh (new file, mode 100644)

#include <torch/all.h>

__global__ void gptq_repack_kernel(uint32_t* in, uint32_t* out, int m, int n);

torch::Tensor gptq_repack(torch::Tensor W);

3rd_party/AutoGPTQ/autogptq_extension/qigen/generate.py (new file, mode 100644)

import argparse
import subprocess
import time

import numpy as np
import pandas as pd
import template
from gekko import GEKKO


def mem_model(N, M, T, mu, tu, bits, l1, p, gs, verbose=False):
    m = GEKKO()  # create GEKKO model
    # cinfergen if bits==3:
    # tu = tu*3
    B = m.Const(value=bits)
    TP = m.Const(value=T // p)
    k = m.Var(1, integer=True, lb=1)
    z = m.Var(1, integer=True, lb=1)
    w = m.Var(1, integer=True, lb=1)
    y = m.Var(1, integer=True, lb=1)
    v = m.Var(1, integer=True, lb=1)
    mb = m.Var(mu, integer=True, lb=1)
    if gs != -1:
        gg = m.Var(1, integer=True, lb=1)
    tb = m.Var(tu, integer=True, lb=1, ub=int(T / p))
    L = m.Var(integer=True, lb=0, ub=l1)
    m.Equation(L == 32 * mb * N + B * mb * tb + 32 * tb * N)
    m.Equation(mb * k == M)
    if gs != -1:
        m.Equation(gs * gg == mb)
    # m.Equation(tb * z == T)
    m.Equation(tb * z == TP)
    m.Equation(mu * w == mb)
    m.Equation(tu * y == tb)
    # m.Equation(tb * v == tt)
    m.Maximize(L)
    m.options.SOLVER = 1
    m.solver_options = [
        "minlp_maximum_iterations 1000",
        # minlp iterations with integer solution
        "minlp_max_iter_with_int_sol 10",
        # treat minlp as nlp
        "minlp_as_nlp 0",
        # nlp sub-problem max iterations
        "nlp_maximum_iterations 100",
        # 1 = depth first, 2 = breadth first
        "minlp_branch_method 2",
        # maximum deviation from whole number
        "minlp_integer_tol 0.00",
        # covergence tolerance
        "minlp_gap_tol 0.01",
    ]
    try:
        m.solve(disp=False)
    except:
        try:
            m.solver_options = [
                "minlp_maximum_iterations 1000",
                # minlp iterations with integer solution
                "minlp_max_iter_with_int_sol 10",
                # treat minlp as nlp
                "minlp_as_nlp 0",
                # nlp sub-problem max iterations
                "nlp_maximum_iterations 100",
                # 1 = depth first, 2 = breadth first
                "minlp_branch_method 1",
                # maximum deviation from whole number
                "minlp_integer_tol 0.00",
                # covergence tolerance
                "minlp_gap_tol 0.01",
            ]
            m.solve(disp=False)
        except:
            # mytb = T//p
            mytb = tu
            if gs != -1:
                mymb = gs
                while 32 * (mymb + gs) * N + bits * (mymb + gs) * mytb + 32 * mytb * N < l1:
                    mymb += gs
                while M % mymb != 0:
                    mymb -= gs
                if verbose:
                    print("Failed to solve, using heuristic. mb = ", mymb, "tb = ", mytb)
                return (int(mymb), int(mytb))
            else:
                mymb = mu
                while 32 * (mymb + mu) * N + bits * (mymb + mu) * mytb + 32 * mytb * N < l1:
                    mymb += mu
                while M % mymb != 0:
                    mymb -= mu
                if verbose:
                    print("Failed to solve, using heuristic. mb = ", mymb, "tb = ", mytb)
                return (int(mymb), int(mytb))
    if verbose:
        print("mb = ", int(mb.value[0]), "tb = ", int(tb.value[0]))
    return (int(mb.value[0]), int(tb.value[0]))
def macros():
    return "#include<omp.h>\n#include<cstdint>\n#include<immintrin.h>\n#include<fstream>\n\n#define mymin(a,b) ((a)<(b)?(a):(b))\n#define mymax(a,b) ((a)>(b)?(a):(b))\n"


def print_parameters(bits, n, m, t, nb, mb, tb, mu, nu, tu, unroll, p, gs=-1):
    res = ""
    res += "void print_parameters(){\n"
    res += f' std::cout << {bits} << "bits," << {n} << "," << {m} << "," << {t} << "," << {nb} << "," << {mb} << "," << {tb} << "," << {nu} << "," << {mu} << "," << {tu} << "," << {unroll} << "," << {p} << "," << {gs} << ",";\n'
    res += "}\n"
    return res
def print_parameters_module(bits, mu, nu, tu, unroll, p, gs=-1):
    res = ""
    res += "void print_parameters(){\n"
    res += "std::ofstream outfile;\n"
    res += 'outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app);\n'
    res += f'outfile << {bits} << "," << {nu} << "," << {mu} << "," << {tu} << "," << {unroll} << "," << {p} << "," << {gs} << ",";\n'
    res += "}\n"
    return res
def pack_in(n, m, nb, mb):
    res = ""
    res += "inline void pack_input(float* A, float* B){\n"
    res += " // copy the full matrix A in blocked format into B\n"
    res += " uint64_t idx = 0;\n"
    res += f" const int N = {n};\n"
    res += f" const int M = {m};\n"
    res += f" const int nb = {nb};\n"
    res += f" const int mb = {mb};\n"
    res += " for(int i = 0; i < N; i+=nb){\n\
for(int j = 0; j < M; j+=mb){\n\
for(int jj = j; jj < mymin(j+mb, M); jj++){\n\
for(int ii = i; ii < mymin(i+nb, N); ii++){\n\
B[idx] = A[ii*M+jj];\n\
idx++;\n\
}\n\
}\n\
}\n\
}\n\
}\n"
    return res


def pack_out(n, t, nb, tb):
    res = ""
    res += "inline void pack_output(float* A, float* B){\n"
    res += " // copy the full matrix A in blocked format into B\n"
    res += " uint64_t idx = 0;\n"
    res += f" const int N = {n};\n"
    res += f" const int M = {t};\n"
    res += f" const int nb = {nb};\n"
    res += f" const int mb = {tb};\n"
    res += " for(int i = 0; i < N; i+=nb){\n\
for(int j = 0; j < M; j+=mb){\n\
for(int ii = i; ii < mymin(i+nb, N); ii++){\n\
for(int jj = j; jj < mymin(j+mb, M); jj++){\n\
B[idx] = A[ii*M+jj];\n\
idx++;\n\
}\n\
}\n\
}\n\
}\n\
}\n"
    return res
def pack_qw(m, t, mb, tb, tb1, bits=4, cutoff=-1):
    packed = 32 // bits
    res = ""
    if cutoff == -1:
        cutoff = 65
    if bits == 3:
        res += "inline void pack_qw_inner(int* A, int* B, int cutoff){\n"
        res += " // copy the full matrix A in blocked format into B\n"
        res += " uint64_t idx = 0;\n"
        res += f" const int N = {m // 32 * 3};\n"
        res += f" const int M = {t};\n"
        res += f" const int nb = {mb // 32 * 3};\n"
        res += f"int mb = {int(tb)};\n"
        res += " for(int j = 0, tid = 0; j < M; j+=mb, tid++){\n"
        # res += "if(tid==cutoff){\n "
        # res += f" mb = {tb1};\n"
        # res += "}\n"
        res += " for(int i = 0; i < N; i+=nb){\n\
for(int ii = i; ii < mymin(i+nb, N); ii+=3){\n\
for(int jj = j; jj < mymin(j+mb, M); jj+=8){\n\
for(int iii = ii; iii < ii + 3; iii++){\n\
for(int jjj = jj; jjj < jj + 8; jjj++){\n\
B[idx] = A[iii*M+jjj];\n\
idx++;\n\
}\n\
}\n\
}\n\
}\n\
}\n\
}\n\
}\n"
        res += "inline void pack_qw(int* A, int* B){\n"
        res += f" pack_qw_inner(A, B, {cutoff});\n"
        res += "}\n"
        return res
    else:
        # in case i do this for python i can just add the n,m,nb,mb as function parameters
        res += "inline void pack_qw_inner(int* A, int* B, int cutoff){\n"
        res += " // copy the full matrix A in blocked format into B\n"
        res += " uint64_t idx = 0;\n"
        res += f" const int N = {m // packed};\n"
        res += f" const int M = {t};\n"
        res += f" const int nb = {mb // packed};\n"
        res += f"int mb = {int(tb)};\n"
        res += " for(int j = 0, tid = 0; j < M; j+=mb, tid++){\n"
        # res += "if(tid==cutoff){\n "
        # res += f" mb = {tb1};\n"
        # res += "}\n"
        res += " for(int i = 0; i < N; i+=nb){\n\
for(int ii = i; ii < mymin(i+nb, N); ii++){\n\
for(int jj = j; jj < mymin(j+mb, M); jj++){\n\
B[idx] = A[ii*M+jj];\n\
idx++;\n\
}\n\
}\n\
}\n"
        res += "}\n"
        res += "}\n"
        res += "inline void pack_qw(int* A, int* B){\n"
        res += f" pack_qw_inner(A, B, {cutoff});\n"
        res += "}\n"
        return res
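
To make the code-generation flow concrete, the helpers above can be composed by hand into a standalone packing header. This mirrors what gen_model and gen_and_compile do later in this file, except that the sketch below uses arbitrary example shapes and writes to a throwaway path; those values are assumptions for illustration only.

# Illustration only: arbitrary example shapes, throwaway output path.
example_code = macros()
example_code += pack_in(n=1, m=4096, nb=1, mb=1024)
example_code += pack_qw(m=4096, t=4096, mb=1024, tb=32, tb1=32, bits=4)
example_code += pack_out(n=1, t=4096, nb=1, tb=32)
with open("example_forward.h", "w") as f:
    f.write(example_code)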
def
block_gs
(
nu_iter
,
mu
,
tu
,
rho
,
packed
,
unroll
,
bits
):
res
=
""
i
=
0
# unroll = 4 # number of bcasts and unpacks
if
bits
==
3
:
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256i w0_
{
j
}
= _mm256_loadu_si256((__m256i*)&W[base_W + j*m/
{
packed
}
*3 + k*mb*tb/
{
packed
}
*3 + k3*tb/
{
packed
}
*3 + jw+
{
j
*
3
}
]);
\n
"
res
+=
f
"__m256i w1_
{
j
}
= _mm256_loadu_si256((__m256i*)&W[base_W + j*m/
{
packed
}
*3 + k*mb*tb/
{
packed
}
*3 + k3*tb/
{
packed
}
*3 + jw+
{
j
*
3
}
+8]);
\n
"
res
+=
f
"__m256i w2_
{
j
}
= _mm256_loadu_si256((__m256i*)&W[base_W + j*m/
{
packed
}
*3 + k*mb*tb/
{
packed
}
*3 + k3*tb/
{
packed
}
*3 + jw+
{
j
*
3
}
+16]);
\n
"
u
=
0
first_off
=
3
second_off
=
2
wid
=
0
shift
=
0
while
u
<
32
:
if
u
==
10
:
res
+=
f
"__m256 v
{
i
}
_
{
u
}
= _mm256_set1_ps(input[(i*om+k)*mb*nb + (k3+
{
u
}
)*nb + i1+
{
i
}
]);
\n
"
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256i ws
{
j
}
_10 = _mm256_srli_epi32(w0_
{
j
}
,
{
bits
*
10
}
);
\n
"
res
+=
f
"__m256i temp0_
{
j
}
= _mm256_slli_epi32(w1_
{
j
}
, 2);
\n
"
res
+=
f
"temp0_
{
j
}
= _mm256_and_si256(temp0_
{
j
}
, mask);
\n
"
res
+=
f
"ws
{
j
}
_10 = _mm256_or_si256(ws
{
j
}
_10, temp0_
{
j
}
);
\n
"
res
+=
f
"__m256i wsa
{
j
}
_
{
u
}
= _mm256_and_si256(ws
{
j
}
_
{
u
}
, mask);
\n
"
res
+=
f
"__m256 l
{
j
}
_
{
u
}
= _mm256_cvtepi32_ps(wsa
{
j
}
_
{
u
}
);
\n
"
res
+=
f
"acc
{
i
}
_
{
j
}
= _mm256_fmadd_ps(v
{
i
}
_
{
u
}
, l
{
j
}
_
{
u
}
, acc
{
i
}
_
{
j
}
);
\n
"
wid
=
wid
+
1
u
=
u
+
1
elif
u
==
21
:
res
+=
f
"__m256 v
{
i
}
_
{
u
}
= _mm256_set1_ps(input[(i*om+k)*mb*nb + (k3+
{
u
}
)*nb + i1+
{
i
}
]);
\n
"
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256i ws
{
j
}
_
{
u
}
= _mm256_srli_epi32(w1_
{
j
}
, 31);
\n
"
res
+=
f
"__m256i temp1_
{
j
}
= _mm256_slli_epi32(w2_
{
j
}
, 1);
\n
"
res
+=
f
"temp1_
{
j
}
= _mm256_and_si256(temp1_
{
j
}
, mask);
\n
"
res
+=
f
"ws
{
j
}
_
{
u
}
= _mm256_or_si256(ws
{
j
}
_
{
u
}
, temp1_
{
j
}
);
\n
"
res
+=
f
"__m256i wsa
{
j
}
_
{
u
}
= _mm256_and_si256(ws
{
j
}
_
{
u
}
, mask);
\n
"
res
+=
f
"__m256 l
{
j
}
_
{
u
}
= _mm256_cvtepi32_ps(wsa
{
j
}
_
{
u
}
);
\n
"
res
+=
f
"acc
{
i
}
_
{
j
}
= _mm256_fmadd_ps(v
{
i
}
_
{
u
}
, l
{
j
}
_
{
u
}
, acc
{
i
}
_
{
j
}
);
\n
"
wid
=
wid
+
1
u
=
u
+
1
for
k
in
range
(
u
,
u
+
second_off
):
res
+=
f
"__m256 v
{
i
}
_
{
k
}
= _mm256_set1_ps(input[(i*om+k)*mb*nb + (k3+
{
k
}
)*nb + i1+
{
i
}
]);
\n
"
for
k
in
range
(
u
,
u
+
second_off
):
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256i ws
{
j
}
_
{
k
}
= _mm256_srli_epi32(w
{
wid
}
_
{
j
}
,
{
bits
*
k
-
wid
*
32
-
shift
}
);
\n
"
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256i wsa
{
j
}
_
{
k
}
= _mm256_and_si256(ws
{
j
}
_
{
k
}
, mask);
\n
"
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 l
{
j
}
_
{
k
}
= _mm256_cvtepi32_ps(wsa
{
j
}
_
{
k
}
);
\n
"
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"acc
{
i
}
_
{
j
}
= _mm256_fmadd_ps(v
{
i
}
_
{
k
}
, l
{
j
}
_
{
k
}
, acc
{
i
}
_
{
j
}
);
\n
"
u
=
u
+
2
return
res
else
:
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256i w
{
j
}
= _mm256_loadu_si256((__m256i*)&W[base_W + j*m/
{
packed
}
+ k*mb*tb/
{
packed
}
+ k3*tb/
{
packed
}
+ j1+
{
j
}
]);
\n
"
for
u
in
range
(
packed
-
unroll
,
-
1
,
-
unroll
):
for
k
in
range
(
u
+
unroll
-
1
,
u
-
1
,
-
1
):
res
+=
f
"__m256 v
{
i
}
_
{
k
}
= _mm256_set1_ps(input[(i*om+k)*mb*nb + (k3+
{
k
}
)*nb + i1+
{
i
}
]);
\n
"
for
k
in
range
(
u
,
u
+
unroll
):
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256i ws
{
j
}
_
{
k
}
= _mm256_srli_epi32(w
{
j
}
,
{
bits
*
k
}
);
\n
"
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256i wsa
{
j
}
_
{
k
}
= _mm256_and_si256(ws
{
j
}
_
{
k
}
, mask);
\n
"
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 l
{
j
}
_
{
k
}
= _mm256_cvtepi32_ps(wsa
{
j
}
_
{
k
}
);
\n
"
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"acc
{
i
}
_
{
j
}
= _mm256_fmadd_ps(v
{
i
}
_
{
k
}
, l
{
j
}
_
{
k
}
, acc
{
i
}
_
{
j
}
);
\n
"
return
res
def
block
(
nu_iter
,
mu
,
tu
,
rho
,
packed
,
unroll
,
bits
):
res
=
""
i
=
0
# unroll = 4 # number of bcasts and unpacks
if
bits
==
3
:
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256i w0_
{
j
}
= _mm256_loadu_si256((__m256i*)&W[base_W + j*m/
{
packed
}
*3 + k*mb*tb/
{
packed
}
*3 + k2*tb/
{
packed
}
*3 + jw+
{
j
*
3
}
]);
\n
"
res
+=
f
"__m256i w1_
{
j
}
= _mm256_loadu_si256((__m256i*)&W[base_W + j*m/
{
packed
}
*3 + k*mb*tb/
{
packed
}
*3 + k2*tb/
{
packed
}
*3 + jw+
{
j
*
3
}
+8]);
\n
"
res
+=
f
"__m256i w2_
{
j
}
= _mm256_loadu_si256((__m256i*)&W[base_W + j*m/
{
packed
}
*3 + k*mb*tb/
{
packed
}
*3 + k2*tb/
{
packed
}
*3 + jw+
{
j
*
3
}
+16]);
\n
"
u
=
0
first_off
=
3
second_off
=
2
wid
=
0
shift
=
0
while
u
<
32
:
if
u
==
10
:
res
+=
f
"__m256 v
{
i
}
_
{
u
}
= _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+
{
u
}
)*nb + i1+
{
i
}
]);
\n
"
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256i ws
{
j
}
_10 = _mm256_srli_epi32(w0_
{
j
}
,
{
bits
*
10
}
);
\n
"
res
+=
f
"__m256i temp0_
{
j
}
= _mm256_slli_epi32(w1_
{
j
}
, 2);
\n
"
res
+=
f
"temp0_
{
j
}
= _mm256_and_si256(temp0_
{
j
}
, mask);
\n
"
res
+=
f
"ws
{
j
}
_10 = _mm256_or_si256(ws
{
j
}
_10, temp0_
{
j
}
);
\n
"
res
+=
f
"__m256i wsa
{
j
}
_
{
u
}
= _mm256_and_si256(ws
{
j
}
_
{
u
}
, mask);
\n
"
res
+=
f
"__m256 l
{
j
}
_
{
u
}
= _mm256_cvtepi32_ps(wsa
{
j
}
_
{
u
}
);
\n
"
res
+=
f
"acc
{
i
}
_
{
j
}
= _mm256_fmadd_ps(v
{
i
}
_
{
u
}
, l
{
j
}
_
{
u
}
, acc
{
i
}
_
{
j
}
);
\n
"
wid
=
wid
+
1
u
=
u
+
1
elif
u
==
21
:
res
+=
f
"__m256 v
{
i
}
_
{
u
}
= _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+
{
u
}
)*nb + i1+
{
i
}
]);
\n
"
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256i ws
{
j
}
_
{
u
}
= _mm256_srli_epi32(w1_
{
j
}
, 31);
\n
"
res
+=
f
"__m256i temp1_
{
j
}
= _mm256_slli_epi32(w2_
{
j
}
, 1);
\n
"
res
+=
f
"temp1_
{
j
}
= _mm256_and_si256(temp1_
{
j
}
, mask);
\n
"
res
+=
f
"ws
{
j
}
_
{
u
}
= _mm256_or_si256(ws
{
j
}
_
{
u
}
, temp1_
{
j
}
);
\n
"
res
+=
f
"__m256i wsa
{
j
}
_
{
u
}
= _mm256_and_si256(ws
{
j
}
_
{
u
}
, mask);
\n
"
res
+=
f
"__m256 l
{
j
}
_
{
u
}
= _mm256_cvtepi32_ps(wsa
{
j
}
_
{
u
}
);
\n
"
res
+=
f
"acc
{
i
}
_
{
j
}
= _mm256_fmadd_ps(v
{
i
}
_
{
u
}
, l
{
j
}
_
{
u
}
, acc
{
i
}
_
{
j
}
);
\n
"
wid
=
wid
+
1
u
=
u
+
1
for
k
in
range
(
u
,
u
+
second_off
):
res
+=
f
"__m256 v
{
i
}
_
{
k
}
= _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+
{
k
}
)*nb + i1+
{
i
}
]);
\n
"
for
k
in
range
(
u
,
u
+
second_off
):
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256i ws
{
j
}
_
{
k
}
= _mm256_srli_epi32(w
{
wid
}
_
{
j
}
,
{
bits
*
k
-
wid
*
32
-
shift
}
);
\n
"
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256i wsa
{
j
}
_
{
k
}
= _mm256_and_si256(ws
{
j
}
_
{
k
}
, mask);
\n
"
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 l
{
j
}
_
{
k
}
= _mm256_cvtepi32_ps(wsa
{
j
}
_
{
k
}
);
\n
"
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"acc
{
i
}
_
{
j
}
= _mm256_fmadd_ps(v
{
i
}
_
{
k
}
, l
{
j
}
_
{
k
}
, acc
{
i
}
_
{
j
}
);
\n
"
u
=
u
+
2
return
res
else
:
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256i w
{
j
}
= _mm256_loadu_si256((__m256i*)&W[base_W + j*m/
{
packed
}
+ k*mb*tb/
{
packed
}
+ k2*tb/
{
packed
}
+ j1+
{
j
}
]);
\n
"
for
u
in
range
(
packed
-
unroll
,
-
1
,
-
unroll
):
for
k
in
range
(
u
+
unroll
-
1
,
u
-
1
,
-
1
):
res
+=
f
"__m256 v
{
i
}
_
{
k
}
= _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+
{
k
}
)*nb + i1+
{
i
}
]);
\n
"
for
k
in
range
(
u
,
u
+
unroll
):
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256i ws
{
j
}
_
{
k
}
= _mm256_srli_epi32(w
{
j
}
,
{
bits
*
k
}
);
\n
"
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256i wsa
{
j
}
_
{
k
}
= _mm256_and_si256(ws
{
j
}
_
{
k
}
, mask);
\n
"
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 l
{
j
}
_
{
k
}
= _mm256_cvtepi32_ps(wsa
{
j
}
_
{
k
}
);
\n
"
for
j
in
range
(
0
,
tu
,
8
):
res
+=
f
"acc
{
i
}
_
{
j
}
= _mm256_fmadd_ps(v
{
i
}
_
{
k
}
, l
{
j
}
_
{
k
}
, acc
{
i
}
_
{
j
}
);
\n
"
return
res
def
accumulators_f
(
nu
,
tu
,
gs
=
False
):
accumulators
=
""
for
i
in
range
(
nu
):
for
j
in
range
(
0
,
tu
,
8
):
if
gs
:
accumulators
+=
f
"__m256 acc
{
i
}
_
{
j
}
= _mm256_setzero_ps();
\n
"
else
:
accumulators
+=
(
f
"__m256 acc
{
i
}
_
{
j
}
= _mm256_loadu_ps(&output[base_output + j + (i1+
{
i
}
)*t + j1+
{
j
}
]);
\n
"
)
return
accumulators
def
stores_f
(
nu
,
tu
,
gs
=
False
):
store
=
""
if
gs
:
for
i
in
range
(
nu
):
for
j
in
range
(
0
,
tu
,
8
):
store
+=
f
"__m256 o
{
i
}
_
{
j
}
= _mm256_loadu_ps(&output[base_output + j + (i1+
{
i
}
)*t + j1+
{
j
}
]);
\n
"
for
i
in
range
(
nu
):
for
j
in
range
(
0
,
tu
,
8
):
store
+=
f
"__m256 s
{
i
}
_
{
j
}
= _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+
{
j
}
]);
\n
"
for
i
in
range
(
nu
):
for
j
in
range
(
0
,
tu
,
8
):
store
+=
f
"__m256 f
{
i
}
_
{
j
}
= _mm256_fmadd_ps(acc
{
i
}
_
{
j
}
, s
{
i
}
_
{
j
}
, o
{
i
}
_
{
j
}
);
\n
"
for
i
in
range
(
nu
):
for
j
in
range
(
0
,
tu
,
8
):
store
+=
f
"_mm256_storeu_ps(&output[base_output + j + (i1+
{
i
}
)*t + j1+
{
j
}
], f
{
i
}
_
{
j
}
);
\n
"
else
:
for
i
in
range
(
nu
):
for
j
in
range
(
0
,
tu
,
8
):
store
+=
f
"_mm256_storeu_ps(&output[base_output + j + (i1+
{
i
}
)*t + j1+
{
j
}
], acc
{
i
}
_
{
j
}
);
\n
"
return
store
def
qforward
(
nu
,
mu
,
tu
,
p
,
unroll
,
bits
,
n
=
0
,
m
=
0
,
t
=
0
,
nb
=
0
,
mb
=
0
,
tb
=
0
,
tt
=
0
,
cutoff
=-
1
,
gs
=
False
,
gs_val
=-
1
,
module
=
True
,
):
assert
module
or
(
gs
and
gs_val
!=
-
1
)
or
(
not
gs
and
gs_val
==
-
1
)
if
cutoff
==
-
1
:
cutoff
=
p
+
1
# packed = 32 // bits
if
bits
==
3
:
packed
=
32
loopguard
=
packed
else
:
packed
=
32
//
bits
loopguard
=
packed
# compute the parameters from the model
accumulators
=
accumulators_f
(
nu
,
tu
,
gs
)
store
=
stores_f
(
nu
,
tu
,
gs
)
ugemm
=
""
if
gs
:
ugemm
+=
"int j1 = 0;
\n
"
if
bits
==
3
:
ugemm
+=
"int jw = 0;
\n
"
ugemm
+=
f
"for(; j1 < tb-tu+1; j1+=tu, jw+=
{
tu
*
3
}
)"
ugemm
+=
"{
\n
"
else
:
ugemm
+=
"for(; j1 < tb-tu+1; j1+=tu) {
\n
"
ugemm
+=
"for(int k1 = 0; k1 < mb; k1+=gs) {
\n
"
ugemm
+=
accumulators
ugemm
+=
f
"for(int k2 = k1; k2 < k1+gs; k2+=
{
loopguard
}
)
\n
"
ugemm
+=
"{
\n
"
ugemm
+=
block
(
nu
,
mu
,
tu
,
16
,
packed
,
unroll
,
bits
)
ugemm
+=
"}
\n
"
ugemm
+=
store
ugemm
+=
"}
\n
"
ugemm
+=
"}
\n
"
else
:
ugemm
+=
"int j1 = 0;
\n
"
if
bits
==
3
:
ugemm
+=
"int jw = 0;
\n
"
ugemm
+=
f
"for(; j1 < tb-tu+1; j1+=tu, jw+=
{
tu
*
3
}
)"
ugemm
+=
"{
\n
"
else
:
ugemm
+=
"for(; j1 < tb-tu+1; j1+=tu) {
\n
"
ugemm
+=
accumulators
ugemm
+=
"for(int k1 = 0; k1 < mb; k1+=mu) {
\n
"
ugemm
+=
f
"for(int k2 = k1; k2 < k1+mu; k2+=
{
loopguard
}
)"
ugemm
+=
"{
\n
"
ugemm
+=
block
(
nu
,
mu
,
tu
,
16
,
packed
,
unroll
,
bits
)
ugemm
+=
"}
\n
"
ugemm
+=
"}
\n
"
ugemm
+=
store
ugemm
+=
"}
\n
"
res
=
""
res
+=
"inline
\n
"
if
gs
:
res
+=
f
"void q
{
bits
}
gemm_gs(const float* __restrict__ input,
\n
"
else
:
res
+=
f
"void q
{
bits
}
gemm(const float* __restrict__ input,
\n
"
res
+=
"const int* __restrict__ W,
\n
"
res
+=
"const float* __restrict__ scales,
\n
"
res
+=
"const float* __restrict__ zeros,
\n
"
res
+=
"const float* __restrict__ bias,
\n
"
res
+=
"const float* __restrict__ sums,
\n
"
res
+=
"float* __restrict__ output,
\n\
const int n,
\n\
const int m,
\n\
const int t,
\n\
const int nb,
\n\
const int mb,
\n\
const int tb,
\n\
int ogtt,
\n
"
if
gs
:
res
+=
"const int gs,
\n
"
res
+=
"const int cutoff){
\n
"
res
+=
f
"#pragma omp parallel num_threads(
{
p
}
)
\n
"
res
+=
"{
\n
"
res
+=
"int tid;
\n
"
res
+=
f
"const int mu =
{
mu
}
;
\n
"
res
+=
f
"const int nu =
{
nu
}
;
\n
"
res
+=
f
"const int tu =
{
tu
}
;
\n
"
res
+=
"const int on = n / nb;
\n
"
res
+=
"const int om = m / mb;
\n
"
mask
=
(
2
**
bits
)
-
1
res
+=
f
"const __m256i mask = _mm256_set1_epi32(
{
mask
}
);
\n
"
if
bits
==
3
:
res
+=
"const __m256i mask4 = _mm256_set1_epi32(4);
\n
"
res
+=
"const __m256i mask6 = _mm256_set1_epi32(6);
\n
"
res
+=
"tid = omp_get_thread_num();
\n
"
res
+=
"int tt = ogtt;
\n
"
res
+=
"if(tid >= cutoff){
\n
"
res
+=
"tt -= tb;
\n
"
res
+=
"}
\n
"
res
+=
"const int base_output = tid >= cutoff ?
\n
\
(tid-cutoff)*tt + (tt+tb)*cutoff:
\n
\
tid*tt;
\n
"
# is this >= cutoff or > cutoff?
if
bits
!=
3
:
res
+=
f
"const int base_W = tid >= cutoff ?
\n
\
((tid-cutoff)*tt + (tt+tb)*cutoff)*m/
{
packed
}
:
\n
\
tid*tt*m/
{
packed
}
;
\n
"
else
:
res
+=
f
"const int base_W = tid >= cutoff ?
\n
\
((tid-cutoff)*tt + (tt+tb)*cutoff)*m/
{
packed
}
*3:
\n
\
tid*tt*m/
{
packed
}
*3;
\n
"
res
+=
"for(int j = 0; j < tt; j+=tb){
\n
"
res
+=
"for(int i = 0; i < on; i++) {
\n
"
res
+=
"for(int k = 0; k < om; k++) {
\n
"
res
+=
"for(int i1 = 0; i1 < nb; i1+=nu) {
\n
"
res
+=
ugemm
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"#pragma omp barrier
\n
"
# res += "#pragma omp for\n"
if
gs
:
res
+=
"const int ngs = m/gs;
\n
"
res
+=
"for (int i = 0; i < n; i++) {
\n
"
res
+=
f
"for (int j = 0; j < tt; j+=
{
tu
}
)"
res
+=
"{
\n
"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 acc
{
i
}
= _mm256_setzero_ps();
\n
"
res
+=
"for (int i1 = 0; i1 < ngs; i1++){
\n
"
res
+=
"__m256 r = _mm256_set1_ps(sums[i*ngs + i1]);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 z
{
i
}
= _mm256_loadu_ps(&zeros[base_output + i1* t + j +
{
i
}
]);
\n
"
# if not module:
if
bits
!=
3
or
not
module
:
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 s
{
i
}
= _mm256_loadu_ps(&scales[base_output + i1 * t + j +
{
i
}
]);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 zs
{
i
}
= _mm256_mul_ps(z
{
i
}
, s
{
i
}
);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
# if module:
if
bits
==
3
and
module
:
res
+=
f
"acc
{
i
}
= _mm256_fmadd_ps(z
{
i
}
, r, acc
{
i
}
);
\n
"
else
:
res
+=
f
"acc
{
i
}
= _mm256_fmadd_ps(zs
{
i
}
, r, acc
{
i
}
);
\n
"
res
+=
"}
\n
"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 o
{
i
}
= _mm256_loadu_ps(&output[i*t + base_output + j +
{
i
}
]);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 b
{
i
}
= _mm256_loadu_ps(&bias[base_output + j +
{
i
}
]);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
if
module
:
res
+=
f
"__m256 o1
{
i
}
= _mm256_sub_ps(o
{
i
}
, acc
{
i
}
);
\n
"
else
:
res
+=
f
"__m256 o1
{
i
}
= _mm256_add_ps(o
{
i
}
, acc
{
i
}
);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 o2
{
i
}
= _mm256_add_ps(o1
{
i
}
, b
{
i
}
);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"_mm256_storeu_ps(&output[i*t + base_output + j +
{
i
}
], o2
{
i
}
);
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
else
:
res
+=
"for (int i = 0; i < n; i++) {
\n
"
res
+=
"__m256 r = _mm256_set1_ps(sums[i]);
\n
"
res
+=
f
"for (int j = 0; j < tt; j+=
{
tu
}
)"
res
+=
"{
\n
"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 o
{
i
}
= _mm256_loadu_ps(&output[i*t + base_output + j +
{
i
}
]);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 z
{
i
}
= _mm256_loadu_ps(&zeros[base_output + j +
{
i
}
]);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 b
{
i
}
= _mm256_loadu_ps(&bias[base_output + j +
{
i
}
]);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 s
{
i
}
= _mm256_loadu_ps(&scales[base_output + j +
{
i
}
]);
\n
"
if
bits
==
3
and
module
:
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 os
{
i
}
= _mm256_mul_ps(o
{
i
}
, s
{
i
}
);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
if
module
:
if
bits
==
3
:
res
+=
f
"__m256 zr
{
i
}
= _mm256_fnmadd_ps(z
{
i
}
, r, os
{
i
}
);
\n
"
else
:
res
+=
f
"__m256 zr
{
i
}
= _mm256_fnmadd_ps(z
{
i
}
, r, o
{
i
}
);
\n
"
else
:
res
+=
f
"__m256 zr
{
i
}
= _mm256_fmadd_ps(z
{
i
}
, r, o
{
i
}
);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
# j res += f"__m256 o2{i} = _mm256_mul_ps(zr{i}, s{i});\n"
if
bits
==
3
and
module
:
res
+=
f
"__m256 o2
{
i
}
= _mm256_add_ps(zr
{
i
}
, b
{
i
}
);
\n
"
else
:
res
+=
f
"__m256 o2
{
i
}
= _mm256_fmadd_ps(zr
{
i
}
, s
{
i
}
, b
{
i
}
);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"_mm256_storeu_ps(&output[i*t + base_output + j +
{
i
}
], o2
{
i
}
);
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
# wrapper for qgemm if we call from cpp
if
module
:
if
gs
:
res
+=
f
"inline void forward
{
bits
}
_gs_cpu(
\n
"
else
:
res
+=
f
"inline void forward
{
bits
}
_cpu(
\n
"
res
+=
"torch::Tensor in, torch::Tensor weight, torch::Tensor out,
\n
"
res
+=
"torch::Tensor bias, torch::Tensor scales, torch::Tensor zeros, torch::Tensor sums,
\n
"
if
gs
:
res
+=
"int N, int M, int T, int nb, int mb, int tb, int tt, int groupsize, int cutoff){
\n
"
else
:
res
+=
"int N, int M, int T, int nb, int mb, int tb, int tt, int cutoff){
\n
"
res
+=
"int* W = weight.data_ptr<int>();
\n
"
res
+=
"float* input = in.data_ptr<float>();
\n
"
res
+=
"float* b = bias.data_ptr<float>();
\n
"
res
+=
"float* s = scales.data_ptr<float>();
\n
"
res
+=
"float* z = zeros.data_ptr<float>();
\n
"
res
+=
"float* r = sums.data_ptr<float>();
\n
"
res
+=
"float* O = out.data_ptr<float>();
\n
"
res
+=
"
\n
"
if
gs
:
res
+=
f
"q
{
bits
}
gemm_gs(input, W, s, z, b, r, O, N, M, T, nb, mb, tb, tt, groupsize, cutoff);
\n
"
else
:
res
+=
f
"q
{
bits
}
gemm(input, W, s, z, b, r, O, N, M, T, nb, mb, tb, tt, cutoff);
\n
"
res
+=
"}
\n
"
else
:
res
+=
"inline void qforward(const float* __restrict__ input,
\n
\
const int* __restrict__ W,
\n\
const float* __restrict__ scales,
\n\
const float* __restrict__ zeros,
\n\
const float* __restrict__ bias,
\n\
const float* __restrict__ sums,
\n\
float* __restrict__ output,
\n\
int n,
\n
\
int m,
\n
\
int t) {
\n
"
if
gs
:
res
+=
f
"q
{
bits
}
gemm_gs(input, W, scales, zeros, bias, sums, output, n, m, t,
{
nb
}
,
{
mb
}
,
{
tb
}
,
{
tt
}
,
{
gs_val
}
,
{
cutoff
}
);
\n
"
else
:
res
+=
f
"q
{
bits
}
gemm(input, W, scales, zeros, bias, sums, output, n, m, t,
{
nb
}
,
{
mb
}
,
{
tb
}
,
{
tt
}
,
{
cutoff
}
);
\n
"
res
+=
"}
\n
"
return
res
def
gen_model
(
n
,
m
,
t
,
bits
,
p
,
gs
):
# get parameters
if
bits
==
3
:
packed
=
32
unroll
=
3
nu
=
1
# args.n
mu
=
32
tu
=
32
else
:
packed
=
32
//
bits
unroll
=
2
nu
=
1
# args.n
mu
=
16
tu
=
32
# compute the parameters from the model
nb
=
n
# it's always small for transformers
mb
,
tb
=
mem_model
(
n
,
m
,
t
,
mu
,
tu
,
bits
,
l1
,
p
,
gs
)
split
=
np
.
ones
(
p
)
split
=
split
*
tb
while
np
.
sum
(
split
)
<
t
:
split
=
split
+
tb
idx
=
p
-
1
while
np
.
sum
(
split
)
>
t
:
split
[
idx
]
=
split
[
idx
]
-
tb
idx
=
idx
-
1
assert
np
.
sum
(
split
)
==
t
split
=
split
.
astype
(
int
)
tt
=
int
(
split
[
0
])
if
split
[
0
]
==
split
[
-
1
]:
cutoff
=
int
(
p
+
1
)
else
:
cutoff
=
int
(
idx
+
1
)
if
gs
==
-
1
:
code
=
qforward
(
nu
,
mu
,
tu
,
p
,
unroll
,
n
=
n
,
m
=
m
,
t
=
t
,
nb
=
nb
,
mb
=
mb
,
tb
=
tb
,
tt
=
tt
,
bits
=
bits
,
cutoff
=
cutoff
,
module
=
False
,
)
else
:
code
=
qforward
(
nu
,
mu
,
tu
,
p
,
unroll
,
n
=
n
,
m
=
m
,
t
=
t
,
nb
=
nb
,
mb
=
mb
,
tb
=
tb
,
tt
=
tt
,
bits
=
bits
,
cutoff
=
cutoff
,
gs
=
True
,
gs_val
=
gs
,
module
=
False
,
)
code
+=
pack_in
(
n
,
m
,
nb
,
mb
)
# code += pack_qw(m, t, mb, tb, tb, bits=bits)#, cutoff=cutoff)
code
+=
pack_qw
(
m
,
t
,
mb
,
tb
,
tu
,
bits
=
bits
)
code
+=
pack_out
(
n
,
t
,
nb
,
tb
)
code
+=
print_parameters
(
bits
,
n
,
m
,
t
,
nb
,
mb
,
tb
,
mu
,
nu
,
tu
,
unroll
,
p
)
with
open
(
"./autogptq_extension/qigen/forward.h"
,
"w"
)
as
f
:
f
.
write
(
macros
())
f
.
write
(
code
)
def
gen_and_compile
(
n
,
m
,
t
,
nb
,
mb
,
tb
,
nu
,
mu
,
tu
,
p
,
unroll
,
bits
=
4
,
gs
=-
1
,
module
=
False
):
split
=
np
.
ones
(
p
)
split
=
split
*
tb
while
np
.
sum
(
split
)
<
t
:
split
=
split
+
tb
idx
=
p
-
1
while
np
.
sum
(
split
)
>
t
:
split
[
idx
]
=
split
[
idx
]
-
tb
idx
=
idx
-
1
assert
np
.
sum
(
split
)
==
t
split
=
split
.
astype
(
int
)
tt
=
int
(
split
[
0
])
if
split
[
0
]
==
split
[
-
1
]:
cutoff
=
int
(
p
+
1
)
else
:
cutoff
=
int
(
idx
+
1
)
if
gs
==
-
1
:
code
=
qforward
(
nu
,
mu
,
tu
,
p
,
unroll
,
n
=
n
,
m
=
m
,
t
=
t
,
nb
=
nb
,
mb
=
mb
,
tb
=
tb
,
tt
=
tt
,
bits
=
bits
,
cutoff
=
cutoff
,
module
=
False
,
)
else
:
code
=
qforward
(
nu
,
mu
,
tu
,
p
,
unroll
,
n
=
n
,
m
=
m
,
t
=
t
,
nb
=
nb
,
mb
=
mb
,
tb
=
tb
,
tt
=
tt
,
bits
=
bits
,
cutoff
=
cutoff
,
gs
=
True
,
gs_val
=
gs
,
module
=
False
,
)
code
+=
pack_in
(
n
,
m
,
nb
,
mb
)
code
+=
pack_qw
(
m
,
t
,
mb
,
tb
,
tu
,
bits
=
bits
)
code
+=
pack_out
(
n
,
t
,
nb
,
tb
)
if
module
:
code
+=
print_parameters_module
(
bits
,
mu
,
nu
,
tu
,
unroll
,
p
,
gs
=
gs
)
else
:
code
+=
print_parameters
(
bits
,
n
,
m
,
t
,
nb
,
mb
,
tb
,
mu
,
nu
,
tu
,
unroll
,
p
,
gs
=
gs
)
# write the code to a file called forward.h
with
open
(
"./autogptq_extension/qigen/forward.h"
,
"w"
)
as
f
:
f
.
write
(
macros
())
f
.
write
(
code
)
# g++ mmm_test.cpp -O3 -ftree-vectorize -mfma -mavx -mavx2 -fno-signaling-nans -fno-trapping-math -fopenmp -o mmm_test
start
=
time
.
time
()
if
not
module
:
subprocess
.
check_output
(
[
"g++"
,
"-O3"
,
"-o"
,
"./autogptq_extension/qigen/mmm_test"
,
"./autogptq_extension/qigen/mmm_test.cpp"
,
"-mavx"
,
"-mfma"
,
"-mavx2"
,
"-ftree-vectorize"
,
"-fno-signaling-nans"
,
"-fno-trapping-math"
,
"-march=native"
,
"-fopenmp"
,
]
)
subprocess
.
check_output
(
[
"./autogptq_extension/qigen/mmm_test"
,
f
"
{
n
}
"
,
f
"
{
m
}
"
,
f
"
{
t
}
"
,
f
"
{
bits
}
"
,
f
"
{
gs
}
"
,
]
)
else
:
subprocess
.
check_output
(
[
"g++"
,
"-O3"
,
"-o"
,
"./autogptq_extension/qigen/mmm"
,
"./autogptq_extension/qigen/mmm.cpp"
,
"-mavx"
,
"-mfma"
,
"-mavx2"
,
"-ftree-vectorize"
,
"-fno-signaling-nans"
,
"-fno-trapping-math"
,
"-march=native"
,
"-fopenmp"
,
]
)
subprocess
.
check_output
(
[
"./autogptq_extension/qigen/mmm"
,
f
"
{
n
}
"
,
f
"
{
m
}
"
,
f
"
{
t
}
"
,
f
"
{
bits
}
"
,
f
"
{
gs
}
"
,
]
)
end
=
time
.
time
()
-
start
return
end
def
grid
():
tt
=
64
for
p
in
[
32
]:
# for n in [1, 10]:
for
n
in
[
1
]:
for
m
in
[
4096
]:
for
t
in
[
4096
]:
# for mb in range(1,m):
# for mb in range(32,512,32):
# for mb in [64, 128, 256, 512, 1024, 2048]:
for
mb
in
[
512
,
1024
,
2048
]:
if
m
%
mb
==
0
:
# for tb in range(8,t,8):
# for tb in range(32,512,32):
# for tb in [16, 32, 64]:#, 128, 192, 256]:
# for tb in [32]:#, 128, 192, 256]:
for
tb
in
[
128
,
256
]:
if
t
%
tb
==
0
:
# for mu in range(32,mb,32):
for
mu
in
[
16
,
32
]:
if
mb
%
mu
==
0
:
# for tu in range(8,tb,8):
# for tu in [16, 32]:
for
tu
in
[
16
,
32
,
64
,
128
]:
if
tb
%
tu
==
0
:
for
gs
in
[
-
1
,
128
,
64
,
32
,
16
]:
# for bits in [2, 3, 4]:
for
bits
in
[
4
,
3
,
2
]:
if
bits
==
3
:
for
u
in
[
5
]:
gen_and_compile
(
n
,
m
,
t
,
n
,
mb
,
tb
,
1
,
mu
,
tu
,
p
,
u
,
bits
=
bits
,
gs
=
gs
,
)
else
:
for
u
in
[
1
,
2
,
4
,
8
]:
gen_and_compile
(
n
,
m
,
t
,
n
,
mb
,
tb
,
1
,
mu
,
tu
,
p
,
u
,
bits
=
bits
,
gs
=
gs
,
)
def
forward_module_gs
(
nu
,
mu
,
tu
,
p
,
unroll
,
bits
):
# packed = 32 // bits
if
bits
==
3
:
packed
=
32
loopguard
=
packed
else
:
packed
=
32
//
bits
loopguard
=
packed
# compute the parameters from the model
accumulators
=
""
for
i
in
range
(
nu
):
for
j
in
range
(
0
,
tu
,
8
):
accumulators
+=
f
"__m256 acc
{
i
}
_
{
j
}
= _mm256_setzero_ps();
\n
"
store
=
""
for
i
in
range
(
nu
):
for
j
in
range
(
0
,
tu
,
8
):
store
+=
f
"__m256 o
{
i
}
_
{
j
}
= _mm256_loadu_ps(&output[base_output + j + (i1+
{
i
}
)*t + j1+
{
j
}
]);
\n
"
for
i
in
range
(
nu
):
for
j
in
range
(
0
,
tu
,
8
):
store
+=
f
"__m256 s
{
i
}
_
{
j
}
= _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+
{
j
}
]);
\n
"
for
i
in
range
(
nu
):
for
j
in
range
(
0
,
tu
,
8
):
store
+=
f
"__m256 f
{
i
}
_
{
j
}
= _mm256_fmadd_ps(acc
{
i
}
_
{
j
}
, s
{
i
}
_
{
j
}
, o
{
i
}
_
{
j
}
);
\n
"
for
i
in
range
(
nu
):
for
j
in
range
(
0
,
tu
,
8
):
store
+=
f
"_mm256_storeu_ps(&output[base_output + j + (i1+
{
i
}
)*t + j1+
{
j
}
], f
{
i
}
_
{
j
}
);
\n
"
ugemm
=
""
if
bits
==
3
:
ugemm
+=
"int j1 = 0;
\n
"
ugemm
+=
"int jw = 0;
\n
"
ugemm
+=
f
"for(; j1 < tb-tu+1; j1+=tu, jw+=
{
tu
*
3
}
)"
ugemm
+=
"{
\n
"
else
:
ugemm
+=
"int j1 = 0;
\n
"
ugemm
+=
"for(; j1 < tb-tu+1; j1+=tu) {
\n
"
ugemm
+=
"for(int k1 = 0; k1 < mb; k1+=gs) {
\n
"
ugemm
+=
accumulators
ugemm
+=
f
"for(int k2 = k1; k2 < k1+gs; k2+=
{
loopguard
}
)
\n
"
ugemm
+=
"{
\n
"
ugemm
+=
block
(
nu
,
mu
,
tu
,
16
,
packed
,
unroll
,
bits
)
ugemm
+=
"}
\n
"
ugemm
+=
store
ugemm
+=
"}
\n
"
ugemm
+=
"}
\n
"
res
=
""
res
+=
"inline
\n
"
res
+=
f
"void q
{
bits
}
gemm_gs(const float* __restrict__ input,
\n
"
res
+=
" const int* __restrict__ W,
\n
\
const float* __restrict__ scales,
\n
"
res
+=
"const float* __restrict__ zeros,
\n
"
res
+=
" const float* __restrict__ bias,
\n
"
res
+=
" const float* __restrict__ sums,
\n
"
res
+=
" float* __restrict__ output,
\n
\
const int n,
\n
\
const int m,
\n
\
const int t,
\n
\
const int nb,
\n
\
const int mb,
\n
\
const int tb,
\n
\
int ogtt,
\n
\
const int gs,
\n\
const int cutoff){
\n
"
res
+=
f
"#pragma omp parallel num_threads(
{
p
}
)
\n
"
res
+=
"{
\n
"
res
+=
" int tid;
\n
"
res
+=
f
" const int mu =
{
mu
}
;
\n
"
res
+=
f
" const int nu =
{
nu
}
;
\n
"
res
+=
f
" const int tu =
{
tu
}
;
\n
"
res
+=
" const int on = n / nb;
\n
"
res
+=
" const int om = m / mb;
\n
"
mask
=
(
2
**
bits
)
-
1
res
+=
f
"const __m256i mask = _mm256_set1_epi32(
{
mask
}
);
\n
"
if
bits
==
3
:
res
+=
"const __m256i mask4 = _mm256_set1_epi32(4);
\n
"
res
+=
"const __m256i mask6 = _mm256_set1_epi32(6);
\n
"
res
+=
"tid = omp_get_thread_num();
\n
"
res
+=
"int tt = ogtt;
\n
"
res
+=
"if(tid >= cutoff){
\n
"
res
+=
"tt -= tb;
\n
"
res
+=
"}
\n
"
res
+=
"const int base_output = tid >= cutoff ?
\n
\
(tid-cutoff)*tt + (tt+tb)*cutoff:
\n
\
tid*tt;
\n
"
# is this >= cutoff or > cutoff?
if
bits
!=
3
:
res
+=
f
"const int base_W = tid >= cutoff ?
\n
\
((tid-cutoff)*tt + (tt+tb)*cutoff)*m/
{
packed
}
:
\n
\
tid*tt*m/
{
packed
}
;
\n
"
else
:
res
+=
f
"const int base_W = tid >= cutoff ?
\n
\
((tid-cutoff)*tt + (tt+tb)*cutoff)*m/
{
packed
}
*3:
\n
\
tid*tt*m/
{
packed
}
*3;
\n
"
res
+=
"for(int j = 0; j < tt; j+=tb){
\n
"
res
+=
"for(int i = 0; i < on; i++) {
\n
"
res
+=
"for(int k = 0; k < om; k++) {
\n
"
res
+=
"for(int i1 = 0; i1 < nb; i1+=nu) {
\n
"
res
+=
ugemm
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"const int ngs = m/gs;
\n
"
res
+=
"#pragma omp barrier
\n
"
# res += "#pragma omp for collapse(2)\n"
res
+=
"for (int i = 0; i < n; i++) {
\n
"
# res += f" for (int j = 0; j < t; j+={tu})"
res
+=
f
"for (int j = 0; j < tt; j+=
{
tu
}
)"
res
+=
"{
\n
"
# for i in range(0,tu,8):
# res += f"__m256 o{i} = _mm256_loadu_ps(&output[i*t + j + {i}]);\n"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 acc
{
i
}
= _mm256_setzero_ps();
\n
"
res
+=
"for (int i1 = 0; i1 < ngs; i1++){
\n
"
res
+=
"__m256 r = _mm256_set1_ps(sums[i*ngs + i1]);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
# res += f"__m256 z{i} = _mm256_loadu_ps(&zeros[i1 * t + j + {i}]);\n"
res
+=
f
"__m256 z
{
i
}
= _mm256_loadu_ps(&zeros[base_output + i1* t + j +
{
i
}
]);
\n
"
# for i in range(0,tu,8):
# res += f"__m256 s{i} = _mm256_loadu_ps(&scales[i1 * t + j + {i}]);\n"
# for i in range(0,tu,8):
# res += f"__m256 zr{i} = _mm256_mul_ps(z{i}, r);\n"
# for i in range(0,tu,8):
# res += f"acc{i} = _mm256_fmadd_ps(zr{i}, s{i}, acc{i});\n"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"acc
{
i
}
= _mm256_fmadd_ps(z
{
i
}
, r, acc
{
i
}
);
\n
"
# for i in range(0,tu,8):
# res += f"__m256 zr{i} = _mm256_mul_ps(z{i}, r);\n"
# for i in range(0,tu,8):
# res += f"o{i} = _mm256_fnmadd_ps(zr{i}, s{i}, o{i});\n"
res
+=
"}
\n
"
for
i
in
range
(
0
,
tu
,
8
):
# res += f"__m256 o{i} = _mm256_loadu_ps(&output[i*t + j + {i}]);\n"
res
+=
f
"__m256 o
{
i
}
= _mm256_loadu_ps(&output[i*t + base_output + j +
{
i
}
]);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 o1
{
i
}
= _mm256_sub_ps(o
{
i
}
, acc
{
i
}
);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
# res += f"_mm256_storeu_ps(&output[i*t + j + {i}], o1{i});\n"
res
+=
f
"_mm256_storeu_ps(&output[i*t + base_output + j +
{
i
}
], o1
{
i
}
);
\n
"
res
+=
" }
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
# wrapper for qgemm if we call from cpp
res
+=
f
"inline void forward
{
bits
}
_gs_cpu(
\n
"
res
+=
"torch::Tensor in, torch::Tensor weight, torch::Tensor out,
\n
"
res
+=
"torch::Tensor bias, torch::Tensor scales, torch::Tensor zeros, torch::Tensor sums,
\n
"
res
+=
"int N, int M, int T, int nb, int mb, int tb, int tt, int groupsize, int cutoff){
\n
"
res
+=
"int* W = weight.data_ptr<int>();
\n
"
res
+=
"float* input = in.data_ptr<float>();
\n
"
res
+=
"float* b = bias.data_ptr<float>();
\n
"
res
+=
"float* s = scales.data_ptr<float>();
\n
"
# res += "int* z = zeros.data_ptr<int>();\n"
res
+=
"float* z = zeros.data_ptr<float>();
\n
"
res
+=
"float* r = sums.data_ptr<float>();
\n
"
res
+=
"float* O = out.data_ptr<float>();
\n
"
res
+=
"
\n
"
res
+=
f
"q
{
bits
}
gemm_gs(input, W, s, z, b, r, O, N, M, T, nb, mb, tb, tt, groupsize, cutoff);
\n
"
res
+=
"}
\n
"
return
res
def
forward_module
(
nu
,
mu
,
tu
,
p
,
unroll
,
bits
):
# packed = 32 // bits
if
bits
==
3
:
packed
=
32
loopguard
=
packed
else
:
packed
=
32
//
bits
loopguard
=
packed
# compute the parameters from the model
accumulators
=
""
for
i
in
range
(
nu
):
for
j
in
range
(
0
,
tu
,
8
):
accumulators
+=
f
"__m256 acc
{
i
}
_
{
j
}
= _mm256_loadu_ps(&output[base_output + j + (i1+
{
i
}
)*t + j1+
{
j
}
]);
\n
"
store
=
""
for
i
in
range
(
nu
):
for
j
in
range
(
0
,
tu
,
8
):
store
+=
f
"_mm256_storeu_ps(&output[base_output + j + (i1+
{
i
}
)*t + j1+
{
j
}
], acc
{
i
}
_
{
j
}
);
\n
"
ugemm
=
""
if
bits
==
3
:
ugemm
+=
"int jw = 0;
\n
"
ugemm
+=
f
"for(; j1 < tb-tu+1; j1+=tu, jw+=
{
tu
*
3
}
)"
ugemm
+=
"{
\n
"
else
:
ugemm
+=
"for(; j1 < tb-tu+1; j1+=tu) {
\n
"
ugemm
+=
accumulators
ugemm
+=
"for(int k1 = 0; k1 < mb; k1+=mu) {
\n
"
ugemm
+=
f
"for(int k2 = k1; k2 < k1+mu; k2+=
{
loopguard
}
)"
ugemm
+=
"{
\n
"
ugemm
+=
block
(
nu
,
mu
,
tu
,
16
,
packed
,
unroll
,
bits
)
ugemm
+=
"}
\n
"
ugemm
+=
"}
\n
"
ugemm
+=
store
ugemm
+=
"}
\n
"
res
=
""
res
+=
"inline
\n
"
res
+=
f
"void q
{
bits
}
gemm(const float* __restrict__ input,
\n
"
res
+=
"const int* __restrict__ W,
\n
"
res
+=
"const float* __restrict__ scales,
\n
"
# res += "const int* __restrict__ zeros, \n"
res
+=
"const float* __restrict__ zeros,
\n
"
res
+=
"const float* __restrict__ bias,
\n
"
res
+=
"const float* __restrict__ sums,"
res
+=
"float* __restrict__ output,
\n
\
const int n,
\n
\
const int m,
\n
\
const int t,
\n
\
const int nb,
\n
\
const int mb,
\n
\
const int tb,
\n
\
int ogtt,
\n
\
const int cutoff){
\n
"
res
+=
f
"#pragma omp parallel num_threads(
{
p
}
)
\n
"
res
+=
"{
\n
"
res
+=
"int tid, nthreads;
\n
"
res
+=
f
"const int mu =
{
mu
}
;
\n
"
res
+=
f
"const int nu =
{
nu
}
;
\n
"
res
+=
f
"const int tu =
{
tu
}
;
\n
"
res
+=
"const int on = n / nb;
\n
"
res
+=
"const int om = m / mb;
\n
"
mask
=
(
2
**
bits
)
-
1
res
+=
f
"const __m256i mask = _mm256_set1_epi32(
{
mask
}
);
\n
"
if
bits
==
3
:
res
+=
"const __m256i mask4 = _mm256_set1_epi32(4);
\n
"
res
+=
"const __m256i mask6 = _mm256_set1_epi32(6);
\n
"
res
+=
"tid = omp_get_thread_num();
\n
"
# res += " std::cout << \"thread \" << tid << \" started\" << std::endl;\n"
res
+=
"nthreads = omp_get_num_threads();
\n
"
res
+=
"int tt = ogtt;
\n
"
res
+=
"if(tid >= cutoff){
\n
"
res
+=
"tt -= tb;
\n
"
res
+=
"}
\n
"
res
+=
"const int base_output = tid >= cutoff ?
\n
\
(tid-cutoff)*tt + (tt+tb)*cutoff:
\n
\
tid*tt;
\n
"
# is this >= cutoff or > cutoff?
if
bits
!=
3
:
res
+=
f
"const int base_W = tid >= cutoff ?
\n
\
((tid-cutoff)*tt + (tt+tb)*cutoff)*m/
{
packed
}
:
\n
\
tid*tt*m/
{
packed
}
;
\n
"
else
:
res
+=
f
"const int base_W = tid >= cutoff ?
\n
\
((tid-cutoff)*tt + (tt+tb)*cutoff)*m/
{
packed
}
*3:
\n
\
tid*tt*m/
{
packed
}
*3;
\n
"
res
+=
"for(int j = 0; j < tt; j+=tb){
\n
"
res
+=
"for(int i = 0; i < on; i++) {
\n
"
res
+=
"for(int k = 0; k < om; k++) {
\n
"
res
+=
"for(int i1 = 0; i1 < nb; i1+=nu) {
\n
"
res
+=
"int j1 = 0;
\n
"
res
+=
ugemm
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
# res += "#pragma omp barrier\n"
# res += "#pragma omp for\n"
res
+=
"for (int i = 0; i < n; i++) {
\n
"
res
+=
"__m256 r = _mm256_set1_ps(sums[i]);
\n
"
# res += f"for (int j = 0; j < t; j+={tu})"
res
+=
f
"for (int j = 0; j < tt; j+=
{
tu
}
)"
res
+=
"{
\n
"
for
i
in
range
(
0
,
tu
,
8
):
# res += f"__m256 o{i} = _mm256_loadu_ps(&output[i*t + j + {i}]);\n"
res
+=
f
"__m256 o
{
i
}
= _mm256_loadu_ps(&output[i*t + base_output + j +
{
i
}
]);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 z
{
i
}
= _mm256_loadu_ps(&zeros[base_output + j +
{
i
}
]);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 s
{
i
}
= _mm256_loadu_ps(&scales[base_output + j +
{
i
}
]);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 zr
{
i
}
= _mm256_fnmadd_ps(z
{
i
}
, r, o
{
i
}
);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"__m256 o2
{
i
}
= _mm256_mul_ps(zr
{
i
}
, s
{
i
}
);
\n
"
for
i
in
range
(
0
,
tu
,
8
):
res
+=
f
"_mm256_storeu_ps(&output[i*t + base_output + j +
{
i
}
], o2
{
i
}
);
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
# wrapper for qgemm if we call from cpp
res
+=
f
"inline void forward
{
bits
}
_cpu(
\n
"
res
+=
"torch::Tensor in, torch::Tensor weight, torch::Tensor out,
\n
"
res
+=
"torch::Tensor bias, torch::Tensor scales, torch::Tensor zeros, torch::Tensor sums,
\n
"
res
+=
"int N, int M, int T, int nb, int mb, int tb, int tt, int cutoff){
\n
"
res
+=
"int* W = weight.data_ptr<int>();
\n
"
res
+=
"float* input = in.data_ptr<float>();
\n
"
res
+=
"float* b = bias.data_ptr<float>();
\n
"
res
+=
"float* s = scales.data_ptr<float>();
\n
"
# res += "int* z = zeros.data_ptr<int>();\n"
res
+=
"float* z = zeros.data_ptr<float>();
\n
"
res
+=
"float* r = sums.data_ptr<float>();
\n
"
res
+=
"float* O = out.data_ptr<float>();
\n
"
res
+=
"
\n
"
res
+=
f
"q
{
bits
}
gemm(input, W, s, z, b, r, O, N, M, T, nb, mb, tb, tt, cutoff);
\n
"
res
+=
"}
\n
"
return
res
def
unpack_zeros
(
bits
):
res
=
""
res
+=
f
"void unpack_zeros
{
bits
}
_cpu(const int* zv, float* ov, int n, int m)"
packed
=
32
//
bits
mask
=
(
2
**
bits
)
-
1
res
+=
"{
\n
const __m256i ones = _mm256_set1_epi32(1);
\n
"
res
+=
f
"const __m256i mask = _mm256_set1_epi32(
{
mask
}
);
\n
"
if
bits
==
4
:
res
+=
"const __m256i shift = _mm256_set_epi32(28,24,20,16,12,8,4,0);
\n
"
elif
bits
==
3
:
pass
elif
bits
==
2
:
res
+=
"const __m256i shift0 = _mm256_set_epi32(30,28,26,24,22,20,18,16);
\n
"
res
+=
"const __m256i shift1 = _mm256_set_epi32(14,12,10,8,6,4,2,0);
\n
"
else
:
print
(
"ERROR"
)
res
+=
"for(int i = 0; i < n; i++){
\n
"
if
bits
==
4
:
res
+=
"for(int j = 0; j < m; j+=8){
\n
"
res
+=
"__m256i z = _mm256_set1_epi32(zv[i*m/8 + j/8]);
\n
"
res
+=
"__m256i z0 = _mm256_srlv_epi32(z, shift);
\n
"
res
+=
"__m256i z1 = _mm256_and_si256(z0, mask);
\n
"
res
+=
"__m256i z2 = _mm256_add_epi32(z1, ones);
\n
"
res
+=
"__m256 z3 = _mm256_cvtepi32_ps(z2);
\n
"
res
+=
"_mm256_storeu_ps(&ov[i*m +j], z3);
\n
"
elif
bits
==
2
:
res
+=
f
"for (int j = 0; j < m; j+=
{
packed
}
)"
res
+=
"{
\n
"
res
+=
f
"for (int k = 0; k <
{
packed
}
; k++)"
res
+=
"{
\n
"
res
+=
f
"ov[i*m + j+k] = (((zv[j/
{
packed
}
] >> (
{
bits
}
*k)) &
{
mask
}
)+1);
\n
"
res
+=
"}
\n
"
# res += "for(int j = 0; j < m; j+=16){\n"
# res += "__m256i z = _mm256_set1_epi32(zv[i*m/16 + j/16]);\n"
# res += "__m256i z00 = _mm256_srlv_epi32(z, shift0);\n"
# res += "__m256i z01 = _mm256_srlv_epi32(z, shift1);\n"
# res += "__m256i z10 = _mm256_and_si256(z00, mask);\n"
# res += "__m256i z11 = _mm256_and_si256(z01, mask);\n"
# res += "__m256i z20 = _mm256_add_epi32(z10, ones);\n"
# res += "__m256i z21 = _mm256_add_epi32(z11, ones);\n"
# res += "__m256 z30 = _mm256_cvtepi32_ps(z20);\n"
# res += "__m256 z31 = _mm256_cvtepi32_ps(z21);\n"
# res += "_mm256_storeu_ps(&ov[i*m +j], z30);\n"
# res += "_mm256_storeu_ps(&ov[i*m +j+8], z31);\n"
elif
bits
==
3
:
# pass
res
+=
"for(int j = 0; j < m; j+=32){
\n
"
res
+=
'std::cout<<"not yet implemented"<<std::endl;
\n
'
# res += "unsigned int z0 = zv[i*m+j/32*3];\n"
# res += "unsigned int z1 = zv[i*m+j/32*3+1];\n"
# res += "unsigned int z2 = zv[i*m+j/32*3+2];\n"
# for i in range(10):
# res += f"unsigned int z0{i} = ((z0 >> {29 - i*3}) & 7) + 1;\n"
# for i in range(10):
# res += f"ov[i*m + j + {i}] = z0{i} * sv[i*m + j + {i}];\n"
# res += "unsigned int t0 = ((z0<<1 & 6) | (z1>>31)) + 1;\n"
# res += "ov[i*m + j + 10] = t0 * sv[i*m + j + 10];\n"
# for i in range(10):
# res += f"unsigned int z1{i} = ((z1 >> {28 - i*3}) & 7) + 1;\n"
# for i in range(10):
# res += f"ov[i*m + j + {11 + i}] = z1{i} * sv[i*m + j + {11 + i}];\n"
# res += "unsigned int t1 = ((z1<<2 & 6) | (z2>>30)) + 1;\n"
# res += "ov[i*m + j + 21] = t1 * sv[i*m + j + 21];\n"
# for i in range(10):
# res += f"unsigned int z2{i} = ((z2 >> {27 - i*3}) & 7) + 1;\n"
# for i in range(10):
# res += f"ov[i*m + j + {22 + i}] = z2{i} * sv[i*m + j + {22 + i}];\n"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
# write the pybind interface
res
+=
f
"void unpack_zeros
{
bits
}
(torch::Tensor zeros, torch::Tensor out, int N, int M)"
res
+=
"{
\n
int* Z = zeros.data_ptr<int>();
\n
"
res
+=
"float* O = out.data_ptr<float>();
\n
"
res
+=
f
"unpack_zeros
{
bits
}
_cpu(Z, O, N, M);
\n
"
res
+=
"}
\n
"
return
res
def
gen_module
(
r
,
p
,
bits_list
=
[
2
,
3
,
4
]):
code
=
""
for
bits
in
bits_list
:
if
bits
==
3
:
unroll
=
3
nu
=
1
# args.n
mu
=
32
tu
=
32
else
:
unroll
=
2
nu
=
1
# args.n
mu
=
16
# mu = 32
tu
=
32
code
+=
qforward
(
nu
,
mu
,
tu
,
p
,
unroll
,
bits
=
bits
,
module
=
True
,
gs
=
False
)
code
+=
qforward
(
nu
,
mu
,
tu
,
p
,
unroll
,
bits
=
bits
,
module
=
True
,
gs
=
True
)
code
+=
pack_qw_module
(
bits
)
code
+=
unpack_zeros
(
bits
)
with
open
(
"./autogptq_extension/qigen/backend.cpp"
,
"w"
)
as
f
:
f
.
write
(
template
.
includes
())
f
.
write
(
template
.
quant_scalar
())
f
.
write
(
compute_reduction
(
p
))
f
.
write
(
unquantize_sim
(
p
))
f
.
write
(
code
)
f
.
write
(
template
.
module
(
bits_list
))
def
compute_reduction
(
p
):
res
=
""
res
+=
"void compute_reduction_cpu(const float* in, float* out, int n, int m, int gs){
\n
"
res
+=
f
"#pragma omp parallel num_threads(
{
p
}
)
\n
"
res
+=
"{
\n
"
res
+=
"#pragma omp for collapse(2)
\n
"
res
+=
"for(int i = 0; i < n; i++){
\n
"
res
+=
"for(int j0 = 0; j0 < m; j0+=gs){
\n
"
res
+=
"__m256 acc = _mm256_setzero_ps();
\n
"
res
+=
"for(int j1 = j0; j1 < j0+gs; j1+=8){
\n
"
res
+=
"__m256 x = _mm256_loadu_ps(&in[i*m + j1]);
\n
"
res
+=
"acc = _mm256_add_ps(acc, x);
\n
"
res
+=
"}
\n
"
# compute simd add reduction
res
+=
"const __m128 hiQuad = _mm256_extractf128_ps(acc, 1);
\n
"
res
+=
"const __m128 loQuad = _mm256_castps256_ps128(acc);
\n
"
res
+=
"const __m128 sumQuad = _mm_add_ps(loQuad, hiQuad);
\n
"
res
+=
"const __m128 hiDual = _mm_movehl_ps(sumQuad, sumQuad);
\n
"
res
+=
"const __m128 sumDual = _mm_add_ps(sumQuad, hiDual);
\n
"
res
+=
"const __m128 hi = _mm_shuffle_ps(sumDual, sumDual, 0x1);
\n
"
res
+=
"const __m128 sum = _mm_add_ss(hi, sumDual);
\n
"
res
+=
"out[(i*m + j0)/gs] = _mm_cvtss_f32(sum);
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
# write the pybind interface
res
+=
"void compute_reduction(torch::Tensor in, torch::Tensor out, int N, int M, int gs)"
res
+=
"{
\n
float* I = in.data_ptr<float>();
\n
"
res
+=
"float* O = out.data_ptr<float>();
\n
"
res
+=
"compute_reduction_cpu(I, O, N, M, gs);
\n
"
res
+=
"}
\n
"
return
res
def
unquantize_sim
(
p
):
res
=
""
res
+=
"void unquantize_sim_cpu(const int* in, float* out, float* s, float* z, int n, int m, int bits, int gs){
\n
"
res
+=
f
"#pragma omp parallel num_threads(
{
p
}
)
\n
"
res
+=
"{
\n
"
res
+=
"int packed = 32/bits;
\n
"
res
+=
"int mask = (1<<bits) - 1;
\n
"
res
+=
"#pragma omp for
\n
"
res
+=
"for(int i0 = 0; i0 < n; i0+=gs){
\n
"
res
+=
"int row = i0 / gs;
\n
"
res
+=
"for(int i1 = i0; i1 < i0+gs; i1+=packed){
\n
"
res
+=
"for(int j0 = 0; j0 < m; j0++){
\n
"
res
+=
"for(int k = 0; k < packed; k++){
\n
"
res
+=
"out[(i1+k)*m + j0] = ((float)((in[i1*m/packed + j0] >> (bits*k)) & mask) - z[(row)*m + j0]) * s[(row)*m + j0];
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
# write the pybind interface
res
+=
"void unquantize_sim(torch::Tensor in, torch::Tensor out, torch::Tensor s, torch::Tensor z, int N, int M, int bits, int gs)"
res
+=
"{
\n
int* I = in.data_ptr<int>();
\n
"
res
+=
"float* O = out.data_ptr<float>();
\n
"
res
+=
"float* S = s.data_ptr<float>();
\n
"
res
+=
"float* Z = z.data_ptr<float>();
\n
"
res
+=
"unquantize_sim_cpu(I, O, S, Z, N, M, bits, gs);
\n
"
res
+=
"}
\n
"
return
res
def
pack_qw_module
(
bits
):
packed
=
32
//
bits
res
=
""
if
bits
==
3
:
res
+=
f
"inline void pack
{
bits
}
_qw_inner(int* A, int* B, const int N, const int M, const int nb, const int mb, int cutoff)"
res
+=
"{
\n
"
res
+=
"// copy the full matrix A in blocked format into B
\n
"
res
+=
"uint64_t idx = 0;
\n
"
# res += f" const {int(tb)};\n"
res
+=
"for(int j = 0, tid = 0; j < M; j+=mb, tid++){
\n
"
res
+=
"for(int i = 0; i < N; i+=nb){
\n
\
for(int ii = i; ii < mymin(i+nb, N); ii+=3){
\n
\
for(int jj = j; jj < mymin(j+mb, M); jj+=8){
\n
\
for(int iii = ii; iii < ii + 3; iii++){
\n
\
for(int jjj = jj; jjj < jj + 8; jjj++){
\n
\
B[idx] = A[iii*M+jjj];
\n
\
idx++;
\n
\
}
\n
\
}
\n
\
}
\n
\
}
\n
\
}
\n
\
}
\n
\
}
\n
"
res
+=
f
"inline void pack
{
bits
}
_w_cpu(
\n
"
res
+=
"torch::Tensor in, torch::Tensor out,
\n
"
res
+=
"int N, int M, int nb, int mb, int cutoff){
\n
"
res
+=
"int* input = in.data_ptr<int>();
\n
"
res
+=
"int* O = out.data_ptr<int>();
\n
"
res
+=
f
"pack
{
bits
}
_qw_inner(input, O, N, M, nb, mb, cutoff);
\n
"
res
+=
"}
\n
"
return
res
else
:
# in case i do this for python i can just add the n,m,nb,mb as function parameters
res
+=
f
"inline void pack
{
bits
}
_qw_inner(int* A, int* B, const int N, const int M, const int nb, int mb, int cutoff)"
res
+=
"{
\n
"
res
+=
"// copy the full matrix A in blocked format into B
\n
"
res
+=
"uint64_t idx = 0;
\n
"
res
+=
"for(int j = 0, tid = 0; j < M; j+=mb, tid++){
\n
"
res
+=
"for(int i = 0; i < N; i+=nb){
\n
\
for(int ii = i; ii < mymin(i+nb, N); ii++){
\n
\
for(int jj = j; jj < mymin(j+mb, M); jj++){
\n
\
B[idx] = A[ii*M+jj];
\n
\
idx++;
\n
\
}
\n
\
}
\n
\
}
\n
"
res
+=
"}
\n
"
res
+=
"}
\n
"
res
+=
f
"inline void pack
{
bits
}
_w_cpu(
\n
"
res
+=
"torch::Tensor in, torch::Tensor out,
\n
"
res
+=
"int N, int M, int nb, int mb, int cutoff){
\n
"
res
+=
"int* input = in.data_ptr<int>();
\n
"
res
+=
"int* O = out.data_ptr<int>();
\n
"
res
+=
f
" pack
{
bits
}
_qw_inner(input, O, N, M, nb, mb, cutoff);
\n
"
res
+=
"}
\n
"
return
res
def
gen_module_search
(
r
,
p
,
bits_list
=
[
2
,
3
,
4
]):
# print measurements to a tmp file and read back best micro parameters
code
=
""
# Opening in 'w' mode overwrites tmp.csv.
with
open
(
"./autogptq_extension/qigen/tmp.csv"
,
"w"
)
as
f
:
f
.
write
(
"bits,nu,mu,tu,unroll,p,gs,time
\n
"
)
n
,
m
,
t
,
nb
,
mb
,
tb
=
1
,
4096
,
4096
,
1
,
1024
,
32
for
mu
in
[
16
]:
for
tu
in
[
16
,
32
,
64
]:
if
tb
%
tu
==
0
:
for
gs
in
[
-
1
,
64
]:
for
bits
in
[
4
,
3
,
2
]:
if
bits
==
3
:
for
u
in
[
5
]:
print
(
n
,
m
,
t
,
n
,
mb
,
tb
,
1
,
mu
,
tu
,
p
,
u
,
bits
,
gs
,
end
=
"
\r
"
,
flush
=
True
,
)
gen_and_compile
(
n
,
m
,
t
,
n
,
mb
,
tb
,
1
,
mu
,
tu
,
p
,
u
,
bits
=
bits
,
gs
=
gs
,
module
=
True
,
)
else
:
for
u
in
[
1
,
2
,
4
,
8
]:
print
(
n
,
m
,
t
,
n
,
mb
,
tb
,
1
,
mu
,
tu
,
p
,
u
,
bits
,
gs
,
end
=
"
\r
"
,
flush
=
True
,
)
gen_and_compile
(
n
,
m
,
t
,
n
,
mb
,
tb
,
1
,
mu
,
tu
,
p
,
u
,
bits
=
bits
,
gs
=
gs
,
module
=
True
,
)
df
=
pd
.
read_csv
(
"./autogptq_extension/qigen/tmp.csv"
)
for
bits
in
bits_list
:
bits_df
=
df
[
df
[
"bits"
]
==
bits
]
bits_nogs
=
bits_df
[
bits_df
[
"gs"
]
==
-
1
]
best
=
bits_nogs
[
bits_nogs
[
"time"
]
==
bits_nogs
[
"time"
].
min
()]
nu
=
int
(
best
[
"nu"
].
values
[
0
])
mu
=
int
(
best
[
"mu"
].
values
[
0
])
tu
=
int
(
best
[
"tu"
].
values
[
0
])
unroll
=
int
(
best
[
"unroll"
].
values
[
0
])
code
+=
qforward
(
nu
,
mu
,
tu
,
p
,
unroll
,
bits
=
bits
,
module
=
True
,
gs
=
False
)
bits_gs
=
bits_df
[
bits_df
[
"gs"
]
!=
-
1
]
best
=
bits_gs
[
bits_gs
[
"time"
]
==
bits_gs
[
"time"
].
min
()]
nu_gs
=
int
(
best
[
"nu"
].
values
[
0
])
mu_gs
=
int
(
best
[
"mu"
].
values
[
0
])
tu_gs
=
int
(
best
[
"tu"
].
values
[
0
])
unroll_gs
=
int
(
best
[
"unroll"
].
values
[
0
])
code
+=
qforward
(
nu_gs
,
mu_gs
,
tu_gs
,
p
,
unroll_gs
,
bits
=
bits
,
module
=
True
,
gs
=
True
)
code
+=
pack_qw_module
(
bits
)
code
+=
unpack_zeros
(
bits
)
with
open
(
"./autogptq_extension/qigen/backend.cpp"
,
"w"
)
as
f
:
f
.
write
(
template
.
includes
())
f
.
write
(
template
.
quant_scalar
())
f
.
write
(
compute_reduction
(
p
))
f
.
write
(
unquantize_sim
(
p
))
f
.
write
(
code
)
f
.
write
(
template
.
module
(
bits_list
))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--n", type=int, default=1024)
    parser.add_argument("--m", type=int, default=1024)
    parser.add_argument("--t", type=int, default=1024)
    parser.add_argument("--nb", type=int, default=128)
    parser.add_argument("--mb", type=int, default=128)
    parser.add_argument("--tb", type=int, default=128)
    parser.add_argument("--mu", type=int, default=4)
    parser.add_argument("--nu", type=int, default=4)
    parser.add_argument("--tu", type=int, default=8)
    parser.add_argument("--bits", type=int, default=4)
    parser.add_argument("--module", action="store_true")
    parser.add_argument("--search", action="store_true")
    parser.add_argument("--model", action="store_true")
    parser.add_argument("--r", type=int, default=16)
    parser.add_argument("--p", type=int, default=8)
    parser.add_argument("--gs", type=int, default=-1)
    args = parser.parse_args()

    if args.module and args.search:
        gen_module_search(args.r, args.p, [2, 3, 4])
    if args.module and not args.search:
        gen_module(args.r, args.p, [2, 3, 4])
    if args.search and not args.module:
        grid()
    if args.model:
        gen_model(args.n, args.m, args.t, args.bits, args.p, args.gs)