vllm, commit 01a5d18a (unverified)
Authored Feb 29, 2024 by CHU Tianxiang; committed by GitHub, Feb 28, 2024
Parent: 929b4f29

Add Support for 2/3/8-bit GPTQ Quantization Models (#2330)

Showing 8 changed files with 1736 additions and 229 deletions (+1736, -229)
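As an aside (not part of the diff): once a build includes this commit, a 2-, 3-, or 8-bit GPTQ checkpoint should load through the same quantization="gptq" path already used for 4-bit models. A minimal sketch; the checkpoint name is hypothetical:

    # Sketch only: assumes a vLLM build containing this commit and a GPTQ
    # checkpoint quantized to 2, 3, or 8 bits (model name is hypothetical).
    from vllm import LLM, SamplingParams

    llm = LLM(model="some-org/llama-7b-gptq-3bit", quantization="gptq")
    params = SamplingParams(temperature=0.0, max_tokens=64)
    out = llm.generate(["GPTQ 3-bit smoke test"], params)
    print(out[0].outputs[0].text)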
csrc/ops.h                                         +4     -2
csrc/quantization/gptq/matrix_view.cuh             +123   -0
csrc/quantization/gptq/q_gemm.cu                   +1326  -126
csrc/quantization/gptq/qdq_2.cuh                   +87    -0
csrc/quantization/gptq/qdq_3.cuh                   +141   -0
csrc/quantization/gptq/qdq_4.cuh                   +6     -94
csrc/quantization/gptq/qdq_8.cuh                   +40    -0
vllm/model_executor/layers/quantization/gptq.py    +9     -7
csrc/ops.h

@@ -98,11 +98,13 @@ torch::Tensor gptq_gemm(
   torch::Tensor b_gptq_qzeros,
   torch::Tensor b_gptq_scales,
   torch::Tensor b_g_idx,
-  bool use_exllama);
+  bool use_exllama,
+  int bit);

 void gptq_shuffle(
   torch::Tensor q_weight,
-  torch::Tensor q_perm);
+  torch::Tensor q_perm,
+  int bit);

 void moe_align_block_size(
   torch::Tensor topk_ids,
 ...
csrc/quantization/gptq/matrix_view.cuh

@@ -146,6 +146,129 @@ public:
     __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
 };

+class MatrixView_q2_row
+{
+public:
+    const uint32_t* data;
+    const int height;
+    const int width;
+
+    __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
+        : data(data), height(height), width(width)
+    { }
+
+    __device__ __forceinline__ int item(int row, int column) const
+    {
+        int shift = (column & 0x0f) * 2;
+        return (data[row * width / 16 + column / 16] >> shift) & 0x03;
+    }
+
+    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
+    {
+        int shift = (column & 0x0f) * 2;
+        uint32_t d = data[row * width / 16 + column / 16] >> shift;
+        items[0] = d & 0x03;
+        items[1] = (d >> 2) & 0x03;
+    }
+
+    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
+    {
+        int shift = (column & 0x0f) * 2;
+        uint32_t d = data[row * width / 16 + column / 16] >> shift;
+        items[0] = d & 0x03;
+        items[1] = (d >> 2) & 0x03;
+        items[2] = (d >> 4) & 0x03;
+        items[3] = (d >> 6) & 0x03;
+    }
+};
+
+class MatrixView_q3_row
+{
+public:
+    const uint32_t* data;
+    const int height;
+    const int width;
+
+    __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
+        : data(data), height(height), width(width)
+    { }
+
+    __device__ __forceinline__ int item(int row, int column) const
+    {
+        int z_w = column * 3 / 32;
+        int z_mod = column & 0x1f;
+
+        if (z_mod == 10) {
+            return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
+        } else if (z_mod == 21) {
+            return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
+        } else if (z_mod < 10) {
+            return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
+        } else if (z_mod < 21) {
+            return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07;
+        } else {
+            return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07;
+        }
+    }
+
+    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
+    {
+        int shift = (column & 0x1f);
+        uint32_t d;
+        if (shift <= 4) {
+            d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
+        } else if (shift == 8) {
+            d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
+        } else if (shift <= 16) {
+            d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
+        } else if (shift == 20) {
+            d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
+        } else {
+            d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
+        }
+        items[0] = d & 0x07;
+        items[1] = (d >> 3) & 0x07;
+        items[2] = (d >> 6) & 0x07;
+        items[3] = (d >> 9) & 0x07;
+    }
+};
+
+class MatrixView_q8_row
+{
+public:
+    const uint32_t* data;
+    const int height;
+    const int width;
+
+    __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
+        : data(data), height(height), width(width)
+    { }
+
+    __device__ __forceinline__ int item(int row, int column) const
+    {
+        int shift = (column & 0x03) * 8;
+        return (data[row * width / 4 + column / 4] >> shift) & 0xff;
+    }
+
+    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
+    {
+        int shift = (column & 0x03) * 8;
+        uint32_t d = data[row * width / 4 + column / 4] >> shift;
+        items[0] = d & 0xff;
+        items[1] = (d >> 8) & 0xff;
+    }
+
+    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
+    {
+        int shift = (column & 0x03) * 2;
+        uint32_t d = data[row * width / 4 + column / 4] >> shift;
+        items[0] = d & 0xff;
+        items[1] = (d >> 8) & 0xff;
+        items[2] = (d >> 16) & 0xff;
+        items[3] = (d >> 24) & 0xff;
+    }
+};

 }  // namespace gptq
 }  // namespace vllm

 #endif
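To make the new row views concrete: 2-bit GPTQ packs 16 weights per 32-bit word, value j in bits [2*j, 2*j+2), which is exactly what MatrixView_q2_row::item unpacks. A small reference sketch (illustration only; the pack/unpack helpers below are hypothetical, not part of the commit):

    # Hypothetical host-side mirror of the 2-bit layout read by MatrixView_q2_row.
    def pack_2bit(values):            # 16 values, each in 0..3, into one uint32
        word = 0
        for j, v in enumerate(values):
            word |= (v & 0x3) << (2 * j)
        return word

    def unpack_2bit(word, column):    # mirrors item(): shift = (column & 0x0f) * 2
        return (word >> ((column & 0x0f) * 2)) & 0x3

    vals = [c % 4 for c in range(16)]
    assert all(unpack_2bit(pack_2bit(vals), c) == vals[c] for c in range(16))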
csrc/quantization/gptq/q_gemm.cu

@@ -13,7 +13,10 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopq
 #include "compat.cuh"
 #include "matrix_view.cuh"
+#include "qdq_2.cuh"
+#include "qdq_3.cuh"
 #include "qdq_4.cuh"
+#include "qdq_8.cuh"

 namespace vllm {
 namespace gptq {

@@ -22,6 +25,7 @@ namespace gptq {
 #define BLOCK_M_SIZE_MAX 8
 #define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
 #define MAX_Q_GEMM_ROWS 50
+#define MAX_Q_GEMM_ROWS_8BIT 24
 #define MAX_ALT_GEMM_ROWS 8
 #define THREADS_X 32
 #define THREADS_Y 32
@@ -75,6 +79,106 @@ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
     return __half2float(__low2half(result)) + __half2float(__high2half(result));
 }

+__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
+{
+    half2 result = {};
+    const half2* a2_ptr = (const half2*)a_ptr;
+    #pragma unroll
+    for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+    return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
+}
+
+__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
+{
+    half2 result = {};
+    const half2* a2_ptr = (const half2*)a_ptr;
+    #pragma unroll
+    for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+    return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
+}
+
+__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
+{
+    half2 result = {};
+    const half2* a2_ptr = (const half2*)a_ptr;
+    #pragma unroll
+    for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
+    return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
+}
+
+__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
+{
+    half2 result = {};
+    const half2* a2_ptr = (const half2*)a_ptr;
+    #pragma unroll
+    for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+    float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
+    return fma(result_f, qs_f, g_result);
+}
+
+__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
+{
+    half2 result = {};
+    const half2* a2_ptr = (const half2*)a_ptr;
+    #pragma unroll
+    for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+    float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
+    return fma(result_f, qs_f, g_result);
+}
+
+__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
+{
+    half2 result = {};
+    const half2* a2_ptr = (const half2*)a_ptr;
+    #pragma unroll
+    for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
+    float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
+    return fma(result_f, qs_f, g_result);
+}
+
+__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h)
+{
+    // Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127
+    float result = {};
+    #pragma unroll
+    for (int i = 0; i < 4; i++)
+    {
+        half2 w01 = dq[i];
+        float w0 = __low2float(w01);
+        float w1 = __high2float(w01);
+        float x0 = __half2float(*a_ptr++);
+        float x1 = __half2float(*a_ptr++);
+        result = fma(w0, x0, result);
+        result = fma(w1, x1, result);
+    }
+    float qs = __half2float(qs_h);
+    result *= qs;
+    half result_h = __float2half_rn(result);
+    return __hadd(result_h, g_result);
+}
+
+__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h)
+{
+    half2 result = {};
+    const half2* a2_ptr = (const half2*)a_ptr;
+    #pragma unroll
+    for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+    half result_h = __hadd(__low2half(result), __high2half(result));
+    return __hfma(result_h, qs_h, g_result);
+}
+
+__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h)
+{
+    half2 result = {};
+    const half2* a2_ptr = (const half2*)a_ptr;
+    #pragma unroll
+    for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
+    half result_h = __hadd(__low2half(result), __high2half(result));
+    return __hfma(result_h, qs_h, g_result);
+}

 typedef void (*fp_gemm_half_q_half_gptq_kernel)
 (
     const half*,
 ...

@@ -89,8 +193,9 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)
     const int*
 );

 template <bool first_block, int m_count>
-__global__ void gemm_half_q_half_gptq_kernel
+__global__ void gemm_half_q_half_gptq_4bit_kernel
 (
     const half* __restrict__ a,
     const uint32_t* __restrict__ b_q_weight,
 ...
@@ -231,80 +336,794 @@ __global__ void gemm_half_q_half_gptq_kernel
     }
 }

-fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
-{
-    #if BLOCK_M_SIZE_MAX >= 1
-    if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
-    #endif
-    #if BLOCK_M_SIZE_MAX >= 2
-    if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
-    #endif
-    #if BLOCK_M_SIZE_MAX >= 3
-    if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
-    #endif
-    #if BLOCK_M_SIZE_MAX >= 4
-    if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
-    #endif
-    #if BLOCK_M_SIZE_MAX >= 5
-    if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
-    #endif
-    #if BLOCK_M_SIZE_MAX >= 6
-    if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
-    #endif
-    #if BLOCK_M_SIZE_MAX >= 7
-    if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
-    #endif
-    #if BLOCK_M_SIZE_MAX >= 8
-    if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
-    #endif
-    return NULL;
-}
-
-void gemm_half_q_half_cuda_part
-(
-    const half* a,
-    const uint32_t* b_q_weight,
-    const uint32_t* b_gptq_qzeros,
-    const half* b_gptq_scales,
-    const int* b_q_perm,
-    half* c,
-    int size_m,
-    int size_n,
-    int size_k,
-    int m_count,
-    int groups
-)
-{
-    dim3 blockDim, gridDim;
-    blockDim.x = BLOCK_KN_SIZE;
-    blockDim.y = 1;
-    blockDim.z = 1;
-    gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
-    gridDim.y = DIVIDE(size_m, m_count);
-    gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
-    fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
-    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-    kernel<<<gridDim, blockDim, 0, stream>>>
-    (
-        a, b_q_weight, b_gptq_qzeros, b_gptq_scales, c,
-        size_m, size_n, size_k, groups, b_q_perm
-    );
-}

+template <bool first_block, int m_count>
+__global__ void gemm_half_q_half_gptq_2bit_kernel
+(
+    const half* __restrict__ a,
+    const uint32_t* __restrict__ b_q_weight,
+    const uint32_t* __restrict__ b_gptq_qzeros,
+    const half* __restrict__ b_gptq_scales,
+    half* __restrict__ c,
+    const int size_m,
+    const int size_n,
+    const int size_k,
+    const int groups,
+    const int* __restrict__ b_q_perm
+)
+{
+    MatrixView_half a_(a, size_m, size_k);
+    MatrixView_half_rw c_(c, size_m, size_n);
+    MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
+    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
+
+    int t = threadIdx.x;
+
+    // Block
+    int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+    int offset_m = blockIdx.y * m_count;
+    int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+
+    int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
+    int end_m = min(offset_m + m_count, size_m);
+    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
+
+    int n = offset_n + t * 4;
+
+    // Preload block_a
+    __shared__ half block_a[m_count][BLOCK_KN_SIZE];
+
+    if (offset_k + t < end_k)
+    {
+        for (int m = 0; m < m_count; ++m)
+        {
+            const half* a_ptr = a_.item_ptr(offset_m + m, 0);
+            half* block_a_ptr = block_a[m];
+
+            half a0;
+            if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
+            else a0 = a_ptr[offset_k + t];
+            block_a_ptr[t] = a0;
+        }
+    }
+
+    // Zero output
+    if (n >= size_n) return;
+
+    if (blockIdx.z == 0)
+    {
+        for (int m = 0; m < m_count; m++)
+            *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
+    }
+
+    __syncthreads();
+
+    // Find initial group
+    int groupsize = size_k / groups;
+    int group = offset_k / groupsize;
+    int nextgroup = offset_k + groupsize;
+
+    // a, b offset
+    int qk = offset_k / (32 / 2);
+
+    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
+    const half* a_ptr = &block_a[0][0];
+    int a_stride = BLOCK_KN_SIZE;
+
+    // Initial group
+    int zeros[4];
+    half scales[4];
+    b_gptq_qzeros_.item4(zeros, group, n);
+    b_gptq_scales_.item4(scales, group, n);
+
+    // Column result
+    half block_c[m_count][4] = {};
+
+    // Dequantize and multiply
+    int k = offset_k;
+    while (k < end_k)
+    {
+        if (k == nextgroup)
+        {
+            group++;
+            nextgroup += groupsize;
+            b_gptq_qzeros_.item4(zeros, group, n);
+            b_gptq_scales_.item4(scales, group, n);
+        }
+
+        #pragma unroll
+        for (int j = 0; j < 1; j++)
+        {
+            const int4* b_ptr4 = (int4*) b_ptr;
+            int4 load_int4 = *b_ptr4;
+
+            half2 dq[4][8];
+            dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
+            dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
+            dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
+            dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);
+
+            #pragma unroll
+            for (int m = 0; m < m_count; m++)
+            {
+                block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
+                block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
+                block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
+                block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
+            }
+
+            b_ptr += size_n;
+            a_ptr += 16;
+        }
+        k += 16;
+    }
+
+    for (int m = 0; m < m_count; m++)
+    {
+        half2* out = (half2*) c_.item_ptr(offset_m + m, n);
+        half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
+        half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
+        atomicAdd(out, result01);
+        atomicAdd(out + 1, result23);
+    }
+}
+
+template <bool first_block, int m_count>
+__global__ void gemm_half_q_half_gptq_3bit_kernel
+(
+    const half* __restrict__ a,
+    const uint32_t* __restrict__ b_q_weight,
+    const uint32_t* __restrict__ b_gptq_qzeros,
+    const half* __restrict__ b_gptq_scales,
+    half* __restrict__ c,
+    const int size_m,
+    const int size_n,
+    const int size_k,
+    const int groups,
+    const int* __restrict__ b_q_perm
+)
+{
+    MatrixView_half a_(a, size_m, size_k);
+    MatrixView_half_rw c_(c, size_m, size_n);
+    MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
+    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
+
+    int t = threadIdx.x;
+
+    // Block
+    int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+    int offset_m = blockIdx.y * m_count;
+    int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+
+    int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
+    int end_m = min(offset_m + m_count, size_m);
+    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
+
+    int n = offset_n + t * 4;
+
+    // Preload block_a
+    __shared__ half block_a[m_count][BLOCK_KN_SIZE];
+
+    if (offset_k + t < end_k)
+    {
+        for (int m = 0; m < m_count; ++m)
+        {
+            const half* a_ptr = a_.item_ptr(offset_m + m, 0);
+            half* block_a_ptr = block_a[m];
+
+            half a0;
+            if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
+            else a0 = a_ptr[offset_k + t];
+            block_a_ptr[t] = a0;
+        }
+    }
+
+    // Zero output
+    if (n >= size_n) return;
+
+    if (blockIdx.z == 0)
+    {
+        for (int m = 0; m < m_count; m++)
+            *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
+    }
+
+    __syncthreads();
+
+    // Find initial group
+    int groupsize = size_k / groups;
+    int group = offset_k / groupsize;
+    int nextgroup = offset_k + groupsize;
+
+    // a, b offset
+    int qk = offset_k / 32 * 3;
+
+    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
+    const half* a_ptr = &block_a[0][0];
+    int a_stride = BLOCK_KN_SIZE;
+
+    // Initial group
+    int zeros[4];
+    half scales[4];
+    b_gptq_qzeros_.item4(zeros, group, n);
+    b_gptq_scales_.item4(scales, group, n);
+
+    // Column result
+    half block_c[m_count][4] = {};
+
+    // Dequantize and multiply
+    int k = offset_k;
+    while (k < end_k)
+    {
+        if (k == nextgroup)
+        {
+            group++;
+            nextgroup += groupsize;
+            b_gptq_qzeros_.item4(zeros, group, n);
+            b_gptq_scales_.item4(scales, group, n);
+        }
+
+        #pragma unroll
+        for (int j = 0; j < 1; j++)
+        {
+            int4 load_int4[3];
+            load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
+            load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
+            load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
+
+            half2 dq[4][16];
+            dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1);
+            dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1);
+            dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1);
+            dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1);
+
+            #pragma unroll
+            for (int m = 0; m < m_count; m++)
+            {
+                block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
+                block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
+                block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
+                block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
+            }
+            a_ptr += 32;
+        }
+        k += 32;
+    }
+
+    for (int m = 0; m < m_count; m++)
+    {
+        half2* out = (half2*) c_.item_ptr(offset_m + m, n);
+        half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
+        half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
+        atomicAdd(out, result01);
+        atomicAdd(out + 1, result23);
+    }
+}
+
+template <bool first_block, int m_count>
+__global__ void gemm_half_q_half_gptq_8bit_kernel
+(
+    const half* __restrict__ a,
+    const uint32_t* __restrict__ b_q_weight,
+    const uint32_t* __restrict__ b_gptq_qzeros,
+    const half* __restrict__ b_gptq_scales,
+    half* __restrict__ c,
+    const int size_m,
+    const int size_n,
+    const int size_k,
+    const int groups,
+    const int* __restrict__ b_q_perm
+)
+{
+    MatrixView_half a_(a, size_m, size_k);
+    MatrixView_half_rw c_(c, size_m, size_n);
+    MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
+    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
+
+    int t = threadIdx.x;
+
+    // Block
+    int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+    int offset_m = blockIdx.y * m_count;
+    int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+
+    int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
+    int end_m = min(offset_m + m_count, size_m);
+    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
+
+    int n = offset_n + t * 4;
+
+    // Preload block_a
+    __shared__ half block_a[m_count][BLOCK_KN_SIZE];
+
+    if (offset_k + t < end_k)
+    {
+        for (int m = 0; m < m_count; ++m)
+        {
+            const half* a_ptr = a_.item_ptr(offset_m + m, 0);
+            half* block_a_ptr = block_a[m];
+
+            half a0;
+            if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
+            else a0 = a_ptr[offset_k + t];
+            block_a_ptr[t] = a0;
+        }
+    }
+
+    // Zero output
+    if (n >= size_n) return;
+
+    if (blockIdx.z == 0)
+    {
+        for (int m = 0; m < m_count; m++)
+            *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
+    }
+
+    __syncthreads();
+
+    // Find initial group
+    int groupsize = size_k / groups;
+    int group = offset_k / groupsize;
+    int nextgroup = offset_k + groupsize;
+
+    // a, b offset
+    int qk = offset_k / (32 / 8);
+
+    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
+    const half* a_ptr = &block_a[0][0];
+    int a_stride = BLOCK_KN_SIZE;
+
+    // Initial group
+    int zeros[4];
+    half scales[4];
+    b_gptq_qzeros_.item4(zeros, group, n);
+    b_gptq_scales_.item4(scales, group, n);
+
+    // Column result
+    half block_c[m_count][4] = {};
+
+    // Dequantize and multiply
+    int k = offset_k;
+    while (k < end_k)
+    {
+        if (k == nextgroup)
+        {
+            group++;
+            nextgroup += groupsize;
+            b_gptq_qzeros_.item4(zeros, group, n);
+            b_gptq_scales_.item4(scales, group, n);
+        }
+
+        #pragma unroll
+        for (int j = 0; j < 4; j++)
+        {
+            int4 load_int4[2];
+            load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
+            load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
+
+            half2 dq[4][4];
+            dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1);
+            dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1);
+            dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1);
+            dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1);
+
+            for (int m = 0; m < m_count; m++)
+            {
+                block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
+                block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
+                block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
+                block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
+            }
+            a_ptr += 8;
+        }
+        k += 32;
+    }
+
+    for (int m = 0; m < m_count; m++)
+    {
+        half2* out = (half2*) c_.item_ptr(offset_m + m, n);
+        half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
+        half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
+        atomicAdd(out, result01);
+        atomicAdd(out + 1, result23);
+    }
+}
+
+fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count, const int bit)
+{
+#define SELECT_KERNEL(M_COUNT) \
+    if (m_count == M_COUNT) { \
+        if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel<true, M_COUNT>; \
+        if (bit == 3) return gemm_half_q_half_gptq_3bit_kernel<true, M_COUNT>; \
+        if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel<true, M_COUNT>; \
+        if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel<true, M_COUNT>; \
+    }
+    #if BLOCK_M_SIZE_MAX >= 1
+    SELECT_KERNEL(1);
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 2
+    SELECT_KERNEL(2);
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 3
+    SELECT_KERNEL(3);
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 4
+    SELECT_KERNEL(4);
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 5
+    SELECT_KERNEL(5);
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 6
+    SELECT_KERNEL(6);
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 7
+    SELECT_KERNEL(7);
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 8
+    SELECT_KERNEL(8);
+    #endif
+    return NULL;
+}
+
+void gemm_half_q_half_cuda_part
+(
+    const half* a,
+    const uint32_t* b_q_weight,
+    const uint32_t* b_gptq_qzeros,
+    const half* b_gptq_scales,
+    const int* b_q_perm,
+    half* c,
+    int size_m,
+    int size_n,
+    int size_k,
+    int m_count,
+    int groups,
+    int bit
+)
+{
+    dim3 blockDim, gridDim;
+    blockDim.x = BLOCK_KN_SIZE;
+    blockDim.y = 1;
+    blockDim.z = 1;
+    gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
+    gridDim.y = DIVIDE(size_m, m_count);
+    gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
+    fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count, bit);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    kernel<<<gridDim, blockDim, 0, stream>>>
+    (
+        a, b_q_weight, b_gptq_qzeros, b_gptq_scales, c,
+        size_m, size_n, size_k, groups, b_q_perm
+    );
+}
+__global__ void reconstruct_exllama_8bit_kernel
+(
+    const uint32_t* __restrict__ b_q_weight,
+    const int* __restrict__ b_q_perm,
+    const uint32_t* __restrict__ b_gptq_qzeros,
+    const half* __restrict__ b_gptq_scales,
+    const int size_k,
+    const int size_n,
+    const int groups,
+    half* __restrict__ b
+)
+{
+    MatrixView_half_rw b_(b, size_k, size_n);
+    MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
+    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
+
+    int offset_k = BLOCK_KN_SIZE * blockIdx.y;
+    int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
+
+    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
+
+    // Preload remapping table
+    __shared__ int perm[BLOCK_KN_SIZE];
+    int t = threadIdx.x;
+
+    if (b_q_perm)
+    {
+        if (offset_k + t < size_k)
+            perm[t] = b_q_perm[offset_k + t];
+    }
+
+    // Column
+    int n = offset_n + t * 4;
+    if (n >= size_n) return;
+
+    // Find initial group
+    int groupsize = size_k / groups;
+    int group = offset_k / groupsize;
+    int nextgroup = offset_k + groupsize;
+
+    // b offset
+    int qk = offset_k / (32 / 8);
+    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
+
+    // Initial zeros/scale
+    int zeros[4];
+    half2 scales[4];
+    b_gptq_qzeros_.item4(zeros, group, n);
+    b_gptq_scales_.item4_h2(scales, group, n);
+
+    __syncthreads();
+
+    int k = offset_k;
+    int lk = 0;
+
+    while (k < end_k)
+    {
+        if (k == nextgroup)
+        {
+            group++;
+            nextgroup += groupsize;
+            b_gptq_qzeros_.item4(zeros, group, n);
+            b_gptq_scales_.item4_h2(scales, group, n);
+        }
+
+        for (int p = 0; p < 4; p++)
+        {
+            int4 load_int4[2];
+            load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
+            load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
+
+            half2 dq[4][4];
+            dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1);
+            dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1);
+            dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1);
+            dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1);
+
+            //half* dqh = (half*)dq;
+            if (b_q_perm)
+            {
+                for (int j = 0; j < 4; j++)
+                {
+                    for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
+                    b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
+                    b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
+                }
+            }
+            else
+            {
+                for (int j = 0; j < 4; j++)
+                {
+                    for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
+                    b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
+                    b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
+                }
+            }
+        }
+        k += 32;
+    }
+}
+
+__global__ void reconstruct_exllama_4bit_kernel
+(
+    const uint32_t* __restrict__ b_q_weight,
+    const int* __restrict__ b_q_perm,
+    const uint32_t* __restrict__ b_gptq_qzeros,
+    const half* __restrict__ b_gptq_scales,
+    const int size_k,
+    const int size_n,
+    const int groups,
+    half* __restrict__ b
+)
+{
+    MatrixView_half_rw b_(b, size_k, size_n);
+    MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
+    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
+
+    int offset_k = BLOCK_KN_SIZE * blockIdx.y;
+    int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
+
+    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
+
+    // Preload remapping table
+    __shared__ int perm[BLOCK_KN_SIZE];
+    int t = threadIdx.x;
+
+    if (b_q_perm)
+    {
+        if (offset_k + t < size_k)
+            perm[t] = b_q_perm[offset_k + t];
+    }
+
+    // Column
+    int n = offset_n + t * 4;
+    if (n >= size_n) return;
+
+    // Find initial group
+    int groupsize = size_k / groups;
+    int group = offset_k / groupsize;
+    int nextgroup = offset_k + groupsize;
+
+    // b offset
+    int qk = offset_k / (32 / 4);
+    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
+
+    // Initial zeros/scale
+    int zeros[4];
+    half2 scales[4];
+    half2 z1z16[4][2];
+    half2 y1y16[4][2];
+    b_gptq_qzeros_.item4(zeros, group, n);
+    b_gptq_scales_.item4_h2(scales, group, n);
+    dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
+    dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
+    dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
+    dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
+
+    __syncthreads();
+
+    int k = offset_k;
+    int lk = 0;
+
+    while (k < end_k)
+    {
+        if (k == nextgroup)
+        {
+            group++;
+            nextgroup += groupsize;
+            b_gptq_qzeros_.item4(zeros, group, n);
+            b_gptq_scales_.item4_h2(scales, group, n);
+            dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
+            dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
+            dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
+            dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
+        }
+
+        for (int p = 0; p < 4; p++)
+        {
+            half2 dq[4][4];
+            const int4* b_ptr4 = (int4*) b_ptr;
+            int4 load_int4 = *b_ptr4;
+
+            dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
+            dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
+            dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
+            dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
+
+            b_ptr += size_n;
+            //half* dqh = (half*)dq;
+            if (b_q_perm)
+            {
+                for (int j = 0; j < 4; j++)
+                {
+                    for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
+                    b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
+                    b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
+                }
+            }
+            else
+            {
+                for (int j = 0; j < 4; j++)
+                {
+                    for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
+                    b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
+                    b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
+                }
+            }
+        }
+        k += 32;
+    }
+}
+
+__global__ void reconstruct_exllama_3bit_kernel
+(
+    const uint32_t* __restrict__ b_q_weight,
+    const int* __restrict__ b_q_perm,
+    const uint32_t* __restrict__ b_gptq_qzeros,
+    const half* __restrict__ b_gptq_scales,
+    const int size_k,
+    const int size_n,
+    const int groups,
+    half* __restrict__ b
+)
+{
+    MatrixView_half_rw b_(b, size_k, size_n);
+    MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
+    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
+
+    int offset_k = BLOCK_KN_SIZE * blockIdx.y;
+    int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
+
+    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
+
+    // Preload remapping table
+    __shared__ int perm[BLOCK_KN_SIZE];
+    int t = threadIdx.x;
+
+    if (b_q_perm)
+    {
+        if (offset_k + t < size_k)
+            perm[t] = b_q_perm[offset_k + t];
+    }
+
+    // Column
+    int n = offset_n + t * 4;
+    if (n >= size_n) return;
+
+    // Find initial group
+    int groupsize = size_k / groups;
+    int group = offset_k / groupsize;
+    int nextgroup = offset_k + groupsize;
+
+    // b offset
+    int qk = offset_k / 32 * 3;
+    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
+
+    // Initial zeros/scale
+    int zeros[4];
+    half2 scales[4];
+    b_gptq_qzeros_.item4(zeros, group, n);
+    b_gptq_scales_.item4_h2(scales, group, n);
+
+    __syncthreads();
+
+    int k = offset_k;
+    int lk = 0;
+
+    while (k < end_k)
+    {
+        if (k == nextgroup)
+        {
+            group++;
+            nextgroup += groupsize;
+            b_gptq_qzeros_.item4(zeros, group, n);
+            b_gptq_scales_.item4_h2(scales, group, n);
+        }
+
+        for (int p = 0; p < 1; p++)
+        {
+            int4 load_int4[3];
+            load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
+            load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
+            load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
+
+            half2 dq[4][16];
+            dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1);
+            dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1);
+            dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1);
+            dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1);
+
+            if (b_q_perm)
+            {
+                for (int j = 0; j < 16; j++)
+                {
+                    for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
+                    b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
+                    b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
+                }
+            }
+            else
+            {
+                for (int j = 0; j < 16; j++)
+                {
+                    for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
+                    b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
+                    b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
+                }
+            }
+        }
+        k += 32;
+    }
+}
-__global__ void reconstruct_exllama_kernel
+__global__ void reconstruct_exllama_2bit_kernel
 (
     const uint32_t* __restrict__ b_q_weight,
     const int* __restrict__ b_q_perm,
 ...

@@ -317,7 +1136,7 @@ __global__ void reconstruct_exllama_kernel
 )
 {
     MatrixView_half_rw b_(b, size_k, size_n);
-    MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
+    MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
     MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

     int offset_k = BLOCK_KN_SIZE * blockIdx.y;
 ...

@@ -345,21 +1164,15 @@ __global__ void reconstruct_exllama_kernel
     int nextgroup = offset_k + groupsize;

     // b offset
-    int qk = offset_k / (32 / 4);
+    int qk = offset_k / (32 / 2);
     const uint32_t* b_ptr = b_q_weight + qk * size_n + n;

     // Initial zeros/scale
     int zeros[4];
     half2 scales[4];
-    half2 z1z16[4][2];
-    half2 y1y16[4][2];
     b_gptq_qzeros_.item4(zeros, group, n);
     b_gptq_scales_.item4_h2(scales, group, n);
-    dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
-    dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
-    dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
-    dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);

     __syncthreads();
 ...

@@ -374,28 +1187,24 @@ __global__ void reconstruct_exllama_kernel
             nextgroup += groupsize;
             b_gptq_qzeros_.item4(zeros, group, n);
             b_gptq_scales_.item4_h2(scales, group, n);
-            dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
-            dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
-            dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
-            dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
         }

-        for (int p = 0; p < 4; p++)
+        for (int p = 0; p < 2; p++)
         {
-            half2 dq[4][4];
             const int4* b_ptr4 = (int4*) b_ptr;
             int4 load_int4 = *b_ptr4;

-            dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
-            dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
-            dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
-            dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
+            half2 dq[4][8];
+            dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
+            dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
+            dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
+            dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);

             b_ptr += size_n;
             //half* dqh = (half*)dq;
             if (b_q_perm)
             {
-                for (int j = 0; j < 4; j++)
+                for (int j = 0; j < 8; j++)
                 {
                     for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
                     b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
 ...

@@ -404,7 +1213,7 @@ __global__ void reconstruct_exllama_kernel
             }
             else
             {
-                for (int j = 0; j < 4; j++)
+                for (int j = 0; j < 8; j++)
                 {
                     for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
                     b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
 ...

@@ -416,7 +1225,6 @@ __global__ void reconstruct_exllama_kernel
     }
 }

 void reconstruct_exllama
 (
     const uint32_t* b_q_weight,
 ...

@@ -426,7 +1234,8 @@ void reconstruct_exllama
     half* out,
     int height,
     int width,
-    int groups
+    int groups,
+    int bit
 )
 {
     dim3 blockDim, gridDim;
 ...

@@ -435,6 +1244,15 @@ void reconstruct_exllama
     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);

+    auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel;
+    if (bit == 2) {
+        reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel;
+    } else if (bit == 3) {
+        reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel;
+    } else if (bit == 8) {
+        reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel;
+    }
+
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>
     (
 ...

@@ -450,7 +1268,7 @@ void reconstruct_exllama
 }

-__global__ void gemm_half_q_half_alt_kernel(
+__global__ void gemm_half_q_half_alt_4bit_kernel(
     const half2* __restrict__ vec,
     const uint32_t* __restrict__ mat,
     half* __restrict__ mul,
 ...

@@ -548,6 +1366,95 @@ __global__ void gemm_half_q_half_alt_kernel(
 }

+__global__ void gemm_half_q_half_alt_8bit_kernel(
+    const half2* __restrict__ vec,
+    const uint32_t* __restrict__ mat,
+    half* __restrict__ mul,
+    const half* __restrict__ scales,
+    const uint32_t* __restrict__ zeros,
+    const int* __restrict__ g_idx,
+    int batch,
+    int height,
+    int width
+)
+{
+    int zero_width = width / 4;
+    int vec_height = height * 2;
+    const int blockwidth2 = BLOCK_KN_SIZE / 2;
+    int b = blockIdx.y * BLOCK_M_SIZE_MAX;
+    int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
+    int h = BLOCK_KN_SIZE * blockIdx.z / 4;
+    int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2;
+    int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
+
+    __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
+    if (threadIdx.x < h_end) {
+        for (int m = 0; m < b_end; ++m) {
+            blockvec[m][threadIdx.x] =
+                vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + threadIdx.x];
+        }
+    }
+
+    if (blockIdx.z == 0)
+    {
+        for (int m = 0; m < b_end; m++)
+            mul[(b + m) * width + w] = __int2half_rn(0);
+    }
+    __syncthreads();
+
+    int i = width * h + w;
+    int g_h = h * 4;
+    int k = 0;
+    int z_w = w / 4;
+    int z_mod = (w % 4) * 8;
+    half2 res2;
+    half res[BLOCK_M_SIZE_MAX] = {};
+
+    unsigned int tmp;
+    while (k < h_end) {
+        tmp = mat[i];
+        half2 scales_tmp[2];
+        half2 zeros_tmp[2];
+        for (int tmp_k = 0; tmp_k < 2; tmp_k++) {
+            int g = g_idx[g_h + (k + tmp_k) * 2];
+            int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1];
+            half scale_f = scales[g * width + w];
+            half scale_f2 = scales[g2 * width + w];
+            half2 scale = __halves2half2(scale_f, scale_f2);
+            half2 zero = __halves2half2(
+                __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)),
+                __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1))
+            );
+            scales_tmp[tmp_k] = scale;
+            zeros_tmp[tmp_k] = zero;
+        }
+        for (int m = 0; m < b_end; m++) {
+#ifndef USE_ROCM
+            res2 = {};
+#else
+            res2.x = __half_as_ushort(__float2half(0));
+            res2.y = __half_as_ushort(__float2half(0));
+#endif
+            half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF), __int2half_rn((tmp >> 8) & 0xFF));
+            res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
+            half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF), __int2half_rn((tmp >> 24) & 0xFF));
+            res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
+#ifndef USE_ROCM
+            res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
+#else
+            res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
+#endif
+        }
+        i += width;
+        k += 2;
+    }
+    for (int m = 0; m < b_end; m++) {
+        atomicAdd(&mul[(b + m) * width + w], res[m]);
+    }
+}

 void gemm_half_q_half_alt
 (
     const half* a,
 ...

@@ -558,7 +1465,8 @@ void gemm_half_q_half_alt
     half* c,
     int size_m,
     int size_n,
-    int size_k
+    int size_k,
+    int bit
 )
 {
     dim3 blockDim, gridDim;
 ...

@@ -569,8 +1477,13 @@ void gemm_half_q_half_alt
     gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
     gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);

+    auto kernel = gemm_half_q_half_alt_4bit_kernel;
+    if (bit == 8) {
+        kernel = gemm_half_q_half_alt_8bit_kernel;
+    }
+
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-    gemm_half_q_half_alt_kernel<<<gridDim, blockDim, 0, stream>>>
+    kernel<<<gridDim, blockDim, 0, stream>>>
     (
         (const half2*) a,
         b_q_weight,
 ...

@@ -579,12 +1492,12 @@ void gemm_half_q_half_alt
         b_gptq_qzeros,
         b_g_idx,
         size_m,
-        size_k / 8,
+        size_k / 32 * bit,
         size_n
     );
 }
+template <class T, int bit>
 __global__ void reconstruct_gptq_kernel
 (
     const uint32_t* __restrict__ w,
 ...

@@ -600,30 +1513,79 @@ __global__ void reconstruct_gptq_kernel
     // Start of block
     int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
-    int row = blockIdx.y * 8;
+    int row = blockIdx.y * 32 / bit;
     if (column >= width) return;

     // Views
-    MatrixView_q4_column w_(w, height, width);
     MatrixView_half_rw out_(out, height, width);
     MatrixView_half w_scales_(w_scales, group, width);
-    MatrixView_q4_row w_zeros_(w_zeros, group, width);
+    T w_zeros_(w_zeros, group, width);

-    uint32_t w_read = w_.item_uint32_t(row, column);
+    uint32_t w_read = w[blockIdx.y * width + column];
     half* out_ptr = out_.item_ptr(row, column);

     #pragma unroll
-    for (int s = 0; s < 32; s += 4)
+    for (int s = 0; s < 32; s += bit)
     {
-        int group = g_idx[row + s / 4];
+        int group = g_idx[row + s / bit];
         half w_scale = w_scales_.item(group, column);
         uint32_t w_zero = w_zeros_.item(group, column) + 1;
-        half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
+        half w_item = __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), w_scale);
         *out_ptr = w_item; out_ptr += out_.width;
     }
 }

+__global__ void reconstruct_gptq_3bit_kernel
+(
+    const uint32_t* __restrict__ w,
+    const half* __restrict__ w_scales,
+    const uint32_t* __restrict__ w_zeros,
+    const int* __restrict__ g_idx,
+    const int height,
+    const int width,
+    const int group,
+    half* __restrict__ out
+)
+{
+    // Start of block
+    int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
+    int row = blockIdx.y * 32;
+    if (column >= width) return;
+
+    // Views
+    MatrixView_half_rw out_(out, height, width);
+    MatrixView_half w_scales_(w_scales, group, width);
+    MatrixView_q3_row w_zeros_(w_zeros, group, width);
+
+    uint32_t w1 = w[(blockIdx.y * 3) * width + column];
+    uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column];
+    uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column];
+    half* out_ptr = out_.item_ptr(row, column);
+
+    #pragma unroll
+    for (int i = 0; i < 32; i += 1)
+    {
+        int group = g_idx[row + i];
+        half w_scale = w_scales_.item(group, column);
+        uint32_t w_zero = w_zeros_.item(group, column) + 1;
+        int w_item;
+        if (i == 10) {
+            w_item = (w1 >> 30) | ((w2 << 2) & 0x4);
+        } else if (i == 21) {
+            w_item = (w2 >> 31) | ((w3 << 1) & 0x6);
+        } else if (i < 10) {
+            w_item = ((w1 >> (i * 3)) & 0x7);
+        } else if (i < 21) {
+            w_item = ((w2 >> (i * 3 - 32)) & 0x7);
+        } else {
+            w_item = ((w3 >> (i * 3 - 64)) & 0x7);
+        }
+        *out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale);
+        out_ptr += out_.width;
+    }
+}

 void reconstruct_gptq
 (
 ...

@@ -634,16 +1596,28 @@ void reconstruct_gptq
     half* out,
     int height,
     int width,
-    int groups
+    int groups,
+    int bit
 )
 {
     dim3 blockDim, gridDim;
     blockDim.x = BLOCK_KN_SIZE;
     blockDim.y = 1;
-    gridDim.y = DIVIDE(height, 8);
+    gridDim.y = DIVIDE(height, 32 / bit);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
+
+    auto kernel = reconstruct_gptq_kernel<MatrixView_q4_row, 4>;
+    if (bit == 2) {
+        kernel = reconstruct_gptq_kernel<MatrixView_q2_row, 2>;
+    } else if (bit == 8) {
+        kernel = reconstruct_gptq_kernel<MatrixView_q8_row, 8>;
+    } else if (bit == 3) {
+        kernel = reconstruct_gptq_3bit_kernel;
+        gridDim.y = DIVIDE(height, 32);
+    }
+
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-    reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
+    kernel<<<gridDim, blockDim, 0, stream>>>
     (
         b_q_weight,
         b_gptq_scales,
 ...
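The 3-bit path above is the awkward one: 32 weights span three 32-bit words, so weights 10 and 21 straddle word boundaries, which is why MatrixView_q3_row::item and reconstruct_gptq_3bit_kernel special-case those indices. A sketch of that extraction (illustration only; the helper is hypothetical):

    # Hypothetical mirror of the GPU-side 3-bit extraction: 32 weights packed
    # into words w1, w2, w3; indices 10 and 21 cross a word boundary.
    def extract_3bit(w1, w2, w3, i):
        if i == 10:
            return ((w1 >> 30) | ((w2 << 2) & 0x4)) & 0x7
        if i == 21:
            return ((w2 >> 31) | ((w3 << 1) & 0x6)) & 0x7
        if i < 10:
            return (w1 >> (i * 3)) & 0x7
        if i < 21:
            return (w2 >> (i * 3 - 32)) & 0x7
        return (w3 >> (i * 3 - 64)) & 0x7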
@@ -671,19 +1645,27 @@ void gemm_half_q_half_cuda
     int size_n,
     int size_k,
     int groups,
-    bool use_exllama
+    bool use_exllama,
+    int bit
 )
 {
-    if ((use_exllama && size_m > MAX_Q_GEMM_ROWS) || (!use_exllama && size_m > MAX_ALT_GEMM_ROWS)) {
+    bool use_reconstruct;
+    if (use_exllama) {
+        use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || (bit != 8 && size_m > MAX_Q_GEMM_ROWS));
+    } else {
+        // The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so we disabled them for now.
+        use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS);
+    }
+    if (use_reconstruct) {
         // Reconstruct FP16 matrix, then cuBLAS
         if (use_exllama) {
             reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq,
-                                size_k, size_n, groups);
+                                size_k, size_n, groups, bit);
         }
         else
         {
             reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
-                             temp_dq, size_k, size_n, groups);
+                             temp_dq, size_k, size_n, groups, bit);
         }

         const half alpha = __float2half(1.0f);
 ...

@@ -707,7 +1689,7 @@ void gemm_half_q_half_cuda
     {
         gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
                                    c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX,
-                                   groups);
+                                   groups, bit);
     }

     if (last_chunk_size)
 ...

@@ -715,18 +1697,17 @@ void gemm_half_q_half_cuda
         gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros,
                                    b_gptq_scales, b_g_idx, c + last_chunk * size_n,
                                    last_chunk_size, size_n, size_k, last_chunk_size,
-                                   groups);
+                                   groups, bit);
        }
    }
    else
    {
        gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
-                            c, size_m, size_n, size_k);
+                            c, size_m, size_n, size_k, bit);
    }
 }
__global__
void
shuffle_4bit_kernel
__global__
void
shuffle_kernel
(
(
uint32_t
*
__restrict__
b_q_weight
,
uint32_t
*
__restrict__
b_q_weight
,
const
int
size_k
,
const
int
size_k
,
...
@@ -740,13 +1721,53 @@ __global__ void shuffle_kernel
...
@@ -740,13 +1721,53 @@ __global__ void shuffle_kernel
while
(
k
<
size_k
)
{
shuffle_4bit_8
(
b_ptr
,
size_n
);
b_ptr
+=
1
*
size_n
;
k
+=
8
;
}
while
(
k
<
size_k
)
{
shuffle_4bit_8
(
b_ptr
,
size_n
);
b_ptr
+=
1
*
size_n
;
k
+=
8
;
}
}
}
__global__
void
shuffle_8bit_kernel
(
uint32_t
*
__restrict__
b_q_weight
,
const
int
size_k
,
const
int
size_n
)
{
int
n
=
blockIdx
.
x
*
THREADS_X
+
threadIdx
.
x
;
if
(
n
>=
size_n
)
return
;
int
k
=
0
;
uint32_t
*
b_ptr
=
b_q_weight
+
n
;
while
(
k
<
size_k
)
{
shuffle_8bit_4
(
b_ptr
,
size_n
);
b_ptr
+=
1
*
size_n
;
k
+=
4
;
}
}
__global__
void
shuffle_2bit_kernel
(
uint32_t
*
__restrict__
b_q_weight
,
const
int
size_k
,
const
int
size_n
)
{
int
n
=
blockIdx
.
x
*
THREADS_X
+
threadIdx
.
x
;
if
(
n
>=
size_n
)
return
;
int
k
=
0
;
uint32_t
*
b_ptr
=
b_q_weight
+
n
;
while
(
k
<
size_k
)
{
shuffle_2bit_16
(
b_ptr
,
size_n
);
b_ptr
+=
1
*
size_n
;
k
+=
16
;
}
}
__global__
void
shuffle_3bit_kernel
(
uint32_t
*
__restrict__
b_q_weight
,
const
int
size_k
,
const
int
size_n
)
{
int
n
=
blockIdx
.
x
*
THREADS_X
+
threadIdx
.
x
;
if
(
n
>=
size_n
)
return
;
int
k
=
0
;
uint32_t
*
b_ptr
=
b_q_weight
+
n
;
while
(
k
<
size_k
)
{
shuffle_3bit_32
(
b_ptr
,
size_n
);
b_ptr
+=
3
*
size_n
;
k
+=
32
;
}
}
__global__ void make_sequential_4bit_kernel
(
    const uint32_t* __restrict__ w,
    uint32_t* __restrict__ w_new,
    const int* __restrict__ q_perm,
    const int w_width
)
{
...
@@ -778,37 +1799,204 @@ __global__ void make_sequential_kernel
    w_new2[w_new2_row * w2_stride + w2_column] = dst;
}

__global__ void make_sequential_2bit_kernel
(
    const uint32_t* __restrict__ w,
    uint32_t* __restrict__ w_new,
    const int* __restrict__ q_perm,
    const int w_width
)
{
    const uint64_t* w2 = (uint64_t*) w;
    uint64_t* w_new2 = (uint64_t*) w_new;
    int w2_stride = w_width >> 1;
    int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
    if (w2_column >= w2_stride) return;
    int w_new2_row = blockIdx.y;
    int q_perm_idx = w_new2_row << 4;
    uint64_t dst = 0;

    #pragma unroll
    for (int i = 0; i < 16; i++)
    {
        int source_row = q_perm[q_perm_idx++];

        int w2_row = source_row >> 4;
        int w2_subrow = source_row & 0x0f;
        int w2_row_shift = w2_subrow << 1;
        int wnew2_row_shift = i << 1;

        uint64_t src = w2[w2_row * w2_stride + w2_column];
        src >>= w2_row_shift;
        src &= 0x0000000300000003;
        src <<= wnew2_row_shift;
        dst |= src;
    }
    w_new2[w_new2_row * w2_stride + w2_column] = dst;
}

__global__ void make_sequential_3bit_kernel
(
    const uint32_t* __restrict__ w,
    uint32_t* __restrict__ w_new,
    const int* __restrict__ q_perm,
    const int w_width
)
{
    int w_column = THREADS_X * blockIdx.x + threadIdx.x;
    if (w_column >= w_width) return;
    int w_new_row = blockIdx.y * 3;
    int q_perm_idx = blockIdx.y << 5;
    uint32_t dst[3] = {0, 0, 0};

    #pragma unroll
    for (int i = 0; i < 32; i++)
    {
        int source_row = q_perm[q_perm_idx++];
        int z_w = (source_row / 32) * 3;
        int z_mod = source_row % 32;
        int z_bit;

        if (z_mod != 10) {
            if (z_mod != 21) {
                z_bit = z_mod;
                if (z_bit > 21) {
                    z_bit *= 3;
                    z_bit -= 64;
                    z_w += 2;
                } else if (z_bit > 10) {
                    z_bit *= 3;
                    z_bit -= 32;
                    z_w += 1;
                } else {
                    z_bit *= 3;
                }
            } else {
                z_w += 1;
            }
        }

        uint64_t src;
        if (z_mod == 10) {
            src = (w[z_w * w_width + w_column] >> 30) | ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4);
        } else if (z_mod == 21) {
            src = (w[z_w * w_width + w_column] >> 31) | ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6);
        } else {
            src = w[z_w * w_width + w_column];
            src >>= z_bit;
            src &= 0x07;
        }

        z_w = 0;
        if (i != 10) {
            if (i != 21) {
                z_bit = i;
                if (z_bit > 21) {
                    z_bit *= 3;
                    z_bit -= 64;
                    z_w += 2;
                } else if (z_bit > 10) {
                    z_bit *= 3;
                    z_bit -= 32;
                    z_w += 1;
                } else {
                    z_bit *= 3;
                }
            } else {
                z_w += 1;
            }
        }
        if (i == 10) {
            dst[z_w] |= (src & 0x03) << 30;
            dst[z_w + 1] |= ((src & 0x4) >> 2);
        } else if (i == 21) {
            dst[z_w] |= (src & 0x01) << 31;
            dst[z_w + 1] |= ((src & 0x6) >> 1);
        } else {
            dst[z_w] |= (src << z_bit);
        }
    }
    w_new[w_new_row * w_width + w_column] = dst[0];
    w_new[(w_new_row + 1) * w_width + w_column] = dst[1];
    w_new[(w_new_row + 2) * w_width + w_column] = dst[2];
}

__global__ void make_sequential_8bit_kernel
(
    const uint32_t* __restrict__ w,
    uint32_t* __restrict__ w_new,
    const int* __restrict__ q_perm,
    const int w_width
)
{
    const uint64_t* w2 = (uint64_t*) w;
    uint64_t* w_new2 = (uint64_t*) w_new;
    int w2_stride = w_width >> 1;
    int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
    if (w2_column >= w2_stride) return;
    int w_new2_row = blockIdx.y;
    int q_perm_idx = w_new2_row << 2;
    uint64_t dst = 0;

    #pragma unroll
    for (int i = 0; i < 4; i++)
    {
        int source_row = q_perm[q_perm_idx++];

        int w2_row = source_row >> 2;
        int w2_subrow = source_row & 0x03;
        int w2_row_shift = w2_subrow << 3;
        int wnew2_row_shift = i << 3;

        uint64_t src = w2[w2_row * w2_stride + w2_column];
        src >>= w2_row_shift;
        src &= 0x000000ff000000ff;
        src <<= wnew2_row_shift;
        dst |= src;
    }
    w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
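For intuition only (not part of the commit): each make_sequential_*bit_kernel re-gathers the packed k-rows of the quantized weight into q_perm (act-order) order so the GEMM can walk quantization groups sequentially. A host-side NumPy sketch of the 8-bit case, assuming w_packed is the (k / 4, n) uint32 qweight and q_perm the length-k permutation passed to the kernel:

import numpy as np

def make_sequential_8bit_ref(w_packed: np.ndarray, q_perm: np.ndarray) -> np.ndarray:
    # w_packed: (k // 4, n) uint32, four 8-bit rows per word (low byte = lowest row).
    # q_perm:   (k,) integer act-order permutation.
    rows_packed, cols = w_packed.shape
    shifts = np.arange(4, dtype=np.uint32) * 8
    # Unpack to (k, n) bytes, gather rows into q_perm order, then repack 4-per-word.
    unpacked = ((w_packed[:, None, :] >> shifts[None, :, None]) & 0xff).reshape(-1, cols)
    gathered = unpacked[q_perm].reshape(rows_packed, 4, cols)
    return np.bitwise_or.reduce(gathered << shifts[None, :, None], axis=1)

The 2-bit and 4-bit kernels do the same gather with 16 or 8 sub-rows per word; the 3-bit kernel has the extra bookkeeping above because 32 values straddle three 32-bit words.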
void shuffle_exllama_weight
(
    uint32_t* q_weight,
    int* q_perm,
    int height,
    int width,
    int bit
)
{
    if (q_perm)
    {
        uint32_t* new_qweight = NULL;
        cudaMalloc(&new_qweight, height / 32 * bit * width * sizeof(uint32_t));

        dim3 blockDim, gridDim;
        blockDim.x = THREADS_X;
        blockDim.y = 1;
        gridDim.x = DIVIDE(width, THREADS_X);
        gridDim.y = height / 32 * bit;
        auto kernel = make_sequential_4bit_kernel;
        if (bit == 2) {
            kernel = make_sequential_2bit_kernel;
        } else if (bit == 3) {
            kernel = make_sequential_3bit_kernel;
            gridDim.y = height / 32;
        } else if (bit == 8) {
            kernel = make_sequential_8bit_kernel;
        }
        const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        kernel<<<gridDim, blockDim, 0, stream>>>
        (
            q_weight,
            new_qweight,
            q_perm,
            width
        );
        // Replace qweights
        cudaMemcpyAsync(q_weight, new_qweight, height / 32 * bit * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
        // Cleanup
        cudaDeviceSynchronize();
        cudaFree(new_qweight);
...
@@ -818,6 +2006,14 @@ void shuffle_exllama_weight
    blockDim.y = 1;
    gridDim.x = DIVIDE(width, THREADS_X);
    gridDim.y = 1;
    auto shuffle_kernel = shuffle_4bit_kernel;
    if (bit == 2) {
        shuffle_kernel = shuffle_2bit_kernel;
    } else if (bit == 3) {
        shuffle_kernel = shuffle_3bit_kernel;
    } else if (bit == 8) {
        shuffle_kernel = shuffle_8bit_kernel;
    }
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
}
...
@@ -832,13 +2028,14 @@ torch::Tensor gptq_gemm
    torch::Tensor b_gptq_qzeros,
    torch::Tensor b_gptq_scales,
    torch::Tensor b_g_idx,
    bool use_exllama,
    int bit
)
{
    const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
    auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
    at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
    at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options);

    vllm::gptq::gemm_half_q_half_cuda
    (
...
@@ -854,7 +2051,8 @@ torch::Tensor gptq_gemm
        c.size(1),              // n
        a.size(1),              // k
        b_gptq_qzeros.size(0),  // group number
        use_exllama,
        bit
    );
    return c;
}
...
@@ -862,14 +2060,16 @@ torch::Tensor gptq_gemm
void gptq_shuffle
(
    torch::Tensor q_weight,
    torch::Tensor q_perm,
    int bit
)
{
    const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
    vllm::gptq::shuffle_exllama_weight(
        (uint32_t*) q_weight.data_ptr(),
        q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
        q_weight.size(0) * 32 / bit,
        q_weight.size(1),
        bit
    );
}
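The bindings above derive the unpacked row count from the packed shape: a GPTQ qweight stores 32 / bit quantized k-rows per 32-bit word, so both the temp_dq allocation in gptq_gemm and the height passed to shuffle_exllama_weight use size(0) * 32 / bit. A small Python illustration (not from the commit; the 96-row shape is made up):

def unpacked_height(rows_packed: int, bit: int) -> int:
    # Mirrors q_weight.size(0) * 32 / bit as used by gptq_shuffle / gptq_gemm above.
    return rows_packed * 32 // bit

for bit in (2, 3, 4, 8):
    print(bit, unpacked_height(96, bit))   # -> 1536, 1024, 768, 384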
csrc/quantization/gptq/qdq_2.cuh 0 → 100644
View file @ 01a5d18a

/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _qdq_2_cuh
#define _qdq_2_cuh

#include "qdq_util.cuh"

namespace vllm {
namespace gptq {

// Permutation:
//
// ffddbb99 77553311  eeccaa88 66442200

__forceinline__ __device__ void shuffle_2bit_16
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0];
    uint32_t qb = 0;

    #pragma unroll
    for (int i = 0; i < 8; i++)
    {
        uint32_t qa0 = qa & 0x03;
        uint32_t qa1 = (qa & 0x0c) >> 2;
        qa >>= 4;
        qb |= (qa1 << (i * 2 + 16));
        qb |= (qa0 << (i * 2));
    }
    q[0] = qb;
}

__forceinline__ __device__ void dequant_2bit_16
(
    const uint32_t q_0,
    half2 (&dq)[8],
    int stride,
    const uint32_t zero
)
{
    const uint32_t c0 = 0x64006400;
    const half y4_  = __float2half_rn(1.0f /  4.0f);
    const half y16_ = __float2half_rn(1.0f / 16.0f);
    const half y64_ = __float2half_rn(1.0f / 64.0f);
    const half2 y4  = __halves2half2(y4_,  y4_);
    const half2 y16 = __halves2half2(y16_, y16_);
    const half2 y64 = __halves2half2(y64_, y64_);

    const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
    const half z4_  = __hsub(__int2half_rn(-256), __int2half_rn(zero));
    const half z16_ = __hsub(__int2half_rn(-64),  __int2half_rn(zero));
    const half z64_ = __hsub(__int2half_rn(-16),  __int2half_rn(zero));
    const half2 z1  = __half2half2(z1_.as_half);
    const half2 z4  = __half2half2(z4_);
    const half2 z16 = __half2half2(z16_);
    const half2 z64 = __half2half2(z64_);

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) *  4 + 1024
    half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
    half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
    qa >>= 8;
    half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8])      + 1024
    half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) *  4 + 1024
    half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
    half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024

    dq[0] = __hadd2(q0.as_half2, z1);
    dq[1] = __hfma2(q1.as_half2, y4,  z4);
    dq[2] = __hfma2(q2.as_half2, y16, z16);
    dq[3] = __hfma2(q3.as_half2, y64, z64);
    dq[4] = __hadd2(q4.as_half2, z1);
    dq[5] = __hfma2(q5.as_half2, y4,  z4);
    dq[6] = __hfma2(q6.as_half2, y16, z16);
    dq[7] = __hfma2(q7.as_half2, y64, z64);
}

}  // namespace gptq
}  // namespace vllm

#endif
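As a reading aid (not part of the commit): shuffle_2bit_16 only reorders the sixteen 2-bit fields inside the word so that dequant_2bit_16 can peel them off in half2 pairs with the 0x6400 / 0xe400 fused-magic trick; ignoring that in-register permutation, the net arithmetic is simply q_i - zero in fp16, with the per-group scale applied later by the GEMM. A plain-Python sketch of that net effect:

def dequant_2bit_16_ref(q_word: int, zero: int) -> list[float]:
    # q_word: one 32-bit word holding 16 consecutive 2-bit weights (pre-shuffle order).
    # Returns the dequantized values before the per-group scale is applied.
    return [float(((q_word >> (2 * i)) & 0x3) - zero) for i in range(16)]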
csrc/quantization/gptq/qdq_3.cuh 0 → 100644
View file @ 01a5d18a

#ifndef _qdq_3_cuh
#define _qdq_3_cuh

#include "qdq_util.cuh"

namespace vllm {
namespace gptq {

// Permutation:
//
// v9997775 55333111 u8886664 44222000  (u, v lsb)
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk

__forceinline__ __device__ void shuffle_3bit_32
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0 * stride];
    uint32_t qb = q[1 * stride];
    uint32_t qc = q[2 * stride];

    // qa: aa999888 77766655 54443332 22111000
    // qb: lkkkjjji iihhhggg fffeeedd dcccbbba
    // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll

    uint32_t qd = qc >> 26;
    qc <<= 4;
    qc |= qb >> 28;
    qb <<= 2;
    qb |= qa >> 30;

    // qa: ..999888 77766655 54443332 22111000
    // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
    // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
    // qd:                               vvvuuu

    uint32_t za = 0;
    uint32_t zb = 0;
    uint32_t zc = 0;

    for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
    for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
    for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }

    // za:  9997775 55333111  8886664 44222000
    // zb:  jjjhhhf ffdddbbb  iiiggge eecccaaa
    // zc:  tttrrrp ppnnnlll  sssqqqo oommmkkk
    // qd:                               vvvuuu

    za |= ((qd & 0x01) >> 0) << 15;
    zb |= ((qd & 0x02) >> 1) << 15;
    zc |= ((qd & 0x04) >> 2) << 15;
    za |= ((qd & 0x08) >> 3) << 31;
    zb |= ((qd & 0x10) >> 4) << 31;
    zc |= ((qd & 0x20) >> 5) << 31;

    // za: v9997775 55333111 u8886664 44222000  (u, v lsb)
    // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
    // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk

    q[0 * stride] = za;
    q[1 * stride] = zb;
    q[2 * stride] = zc;
}

__forceinline__ __device__ void dequant_3bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    half2 (&dq)[16],
    int stride,
    const uint32_t zero
)
{
    const uint32_t c0 = 0x64006400;
    const half y8_  = __float2half_rn(1.0f /  8.0f);
    const half y64_ = __float2half_rn(1.0f / 64.0f);
    const half2 y8  = __halves2half2(y8_,  y8_);
    const half2 y64 = __halves2half2(y64_, y64_);
    const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
    const half z8_  = __hsub(__int2half_rn(-128), __int2half_rn(zero));
    const half z64_ = __hsub(__int2half_rn(-16),  __int2half_rn(zero));
    const half2 z1  = __halves2half2(z1_.as_half, z1_.as_half);
    const half2 z8  = __halves2half2(z8_,  z8_);
    const half2 z64 = __halves2half2(z64_, z64_);

    uint32_t qa = q_0;
    uint32_t qb = q_1;
    uint32_t qc = q_2;

    half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) *  8 + 1024
    qa >>= 6;
    half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5])      + 1024
    half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) *  8 + 1024
    half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
    qa >>= 9;
    qa &= 0x00010001;
    half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11])      + 1024
    half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) *  8 + 1024
    qb >>= 6;
    half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15])      + 1024
    half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) *  8 + 1024
    half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
    qb >>= 8;
    qb &= 0x00020002;
    half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21])      + 1024
    half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) *  8 + 1024
    qc >>= 6;
    half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25])      + 1024
    half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) *  8 + 1024
    half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
    qc >>= 7;
    qc &= 0x00040004;
    half2_uint32 q15((qa | qb | qc) | c0);

    dq[ 0] = __hadd2( q0.as_half2, z1);
    dq[ 1] = __hfma2( q1.as_half2, y8,  z8);
    dq[ 2] = __hadd2( q2.as_half2, z1);
    dq[ 3] = __hfma2( q3.as_half2, y8,  z8);
    dq[ 4] = __hfma2( q4.as_half2, y64, z64);
    dq[ 5] = __hadd2( q5.as_half2, z1);
    dq[ 6] = __hfma2( q6.as_half2, y8,  z8);
    dq[ 7] = __hadd2( q7.as_half2, z1);
    dq[ 8] = __hfma2( q8.as_half2, y8,  z8);
    dq[ 9] = __hfma2( q9.as_half2, y64, z64);
    dq[10] = __hadd2(q10.as_half2, z1);
    dq[11] = __hfma2(q11.as_half2, y8,  z8);
    dq[12] = __hadd2(q12.as_half2, z1);
    dq[13] = __hfma2(q13.as_half2, y8,  z8);
    dq[14] = __hfma2(q14.as_half2, y64, z64);
    dq[15] = __hadd2(q15.as_half2, z1);
}

}  // namespace gptq
}  // namespace vllm

#endif
csrc/quantization/gptq/qdq_4.cuh
View file @ 01a5d18a
...
@@ -38,16 +38,17 @@ __forceinline__ __device__ void dequant_4bit_8
(
    const uint32_t q_0,
    half2 (&dq)[4],
    int stride,
    const uint32_t zero
)
{
    const uint32_t c0 = 0x64006400;
    const half y16_ = __float2half_rn(1.0f / 16.0f);
    const half2 y16 = __halves2half2(y16_, y16_);
    const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
    const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
    const half2 z1  = __half2half2(z1_.as_half);
    const half2 z16 = __half2half2(z16_);

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
...
@@ -143,93 +144,4 @@ __forceinline__ __device__ void dequant_4bit_8_gptq
}  // namespace gptq
}  // namespace vllm

#else

namespace vllm {
namespace gptq {

__forceinline__ __device__ void shuffle_4bit_8
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_4bit_8
(
    const uint32_t q_0,
    half2 (&dq)[4],
    int stride
)
{
    half dqh[8];
    for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);

    for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
    const uint32_t zero,
    const half scale,
    half2 (&z1)[2],
    half2 (&y1)[2]
)
{
    half z = __int2half_rn(-((int)zero));
    z = __hmul(z, scale);
    z1[0] = __half2half2(z);
    y1[0] = __half2half2(scale);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
    const uint32_t zero,
    half2 (&z1)[2],
    half2 (&y1)[2]
)
{
    half z = __int2half_rn(-((int)zero));
    z1[0] = __half2half2(z);
}

__forceinline__ __device__ void dequant_4bit_8_gptq
(
    const uint32_t q_0,
    half2 (&dq)[4],
    half2 (&z1)[2],
    half2 (&y1)[2],
    int stride,
    bool scaled
)
{
    half2 dqh2[8];

    uint32_t qa = q_0;
    for (int i = 0; i < 4; i++)
    {
        half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
        half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
        dqh2[i] = __halves2half2(d0, d1);
    }

    if (scaled)
    {
        dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
        dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
        dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
        dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
    }
    else
    {
        dq[0] = __hadd2(dqh2[0], z1[0]);
        dq[1] = __hadd2(dqh2[1], z1[0]);
        dq[2] = __hadd2(dqh2[2], z1[0]);
        dq[3] = __hadd2(dqh2[3], z1[0]);
    }
}

}  // namespace gptq
}  // namespace vllm

#endif

#endif
csrc/quantization/gptq/qdq_8.cuh 0 → 100644
View file @ 01a5d18a

/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _qdq_8_cuh
#define _qdq_8_cuh

#include "qdq_util.cuh"

namespace vllm {
namespace gptq {

__forceinline__ __device__ void shuffle_8bit_4
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_8bit_8
(
    const uint32_t q_0,
    const uint32_t q_1,
    half2 (&dq)[4],
    int stride,
    const uint32_t zero
)
{
    half dqh[8];
    for (int i = 0; i < 4; i++) dqh[i    ] = dq_ns(exb(q_0, i * 8, 0xff), zero);
    for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);

    for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

}  // namespace gptq
}  // namespace vllm

#endif
vllm/model_executor/layers/quantization/gptq.py
View file @ 01a5d18a

import enum
from enum import Enum
from typing import Any, Dict, List, Optional
from fractions import Fraction

import torch
from torch.nn.parameter import Parameter
...
@@ -27,11 +28,10 @@ class GPTQConfig(QuantizationConfig):
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.pack_factor = Fraction(32, self.weight_bits)
        if self.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is supported for "
                f"GPTQ, but got {self.weight_bits} bits.")

    def __repr__(self) -> str:
...
@@ -101,7 +101,7 @@ class GPTQLinearMethod(LinearMethodBase):
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
...
@@ -201,11 +201,13 @@ class GPTQLinearMethod(LinearMethodBase):
            else:
                weights["g_idx"] = torch.empty((1, 1), device="meta")
            weights["exllama_state"] = ExllamaState.READY
            ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
                             self.quant_config.weight_bits)
        output = ops.gptq_gemm(reshaped_x, weights["qweight"],
                               weights["qzeros"], weights["scales"],
                               weights["g_idx"],
                               weights["exllama_state"] == ExllamaState.READY,
                               self.quant_config.weight_bits)
        if bias is not None:
            output = output + bias
        return output.reshape(out_shape)
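One note on the Python side (illustration, not from the commit): pack_factor becomes a Fraction because 32 is not evenly divisible by 3, and the output-size alignment check therefore uses pack_factor.numerator:

from fractions import Fraction

for bits in (2, 3, 4, 8):
    pf = Fraction(32, bits)
    print(bits, pf, pf.numerator)
# 2 16   16
# 3 32/3 32
# 4 8    8
# 8 4    4

Checking output_size_per_partition against the numerator (32 in the 3-bit case) guarantees that output_size_per_partition / pack_factor, the packed dimension of qzeros and qweight, stays integral for every supported bit width.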