OpenDAS / text-generation-inference / Commits

Commit 564199ba (unverified), authored Dec 21, 2023 by OlivierDehaene; committed via GitHub on Dec 21, 2023.
Parent: 987c959f

feat: update exllamav2 kernels (#1370)

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Showing 17 changed files with 525 additions and 255 deletions (+525 / -255).
server/exllamav2_kernels/exllamav2_kernels/config.h                                  +2    -0
server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu                            +53   -45
server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh                           +4    -1
server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh                    +205  -112
server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh               +98   -44
server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu                          +47   -23
server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh                         +3    -1
server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh                   +2    -0
server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh                             +12   -0
server/exllamav2_kernels/exllamav2_kernels/ext.cpp                                   +5    -0
server/tests/utils/test_hub.py                                                       +4    -4
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py   +1    -1
server/text_generation_server/utils/gptq/exllamav2.py                                +33   -0
server/text_generation_server/utils/hub.py                                           +21   -11
server/text_generation_server/utils/layers.py                                        +4    -2
server/text_generation_server/utils/log.py                                           +6    -0
server/text_generation_server/utils/weights.py                                       +25   -11
server/exllamav2_kernels/exllamav2_kernels/config.h

@@ -2,6 +2,7 @@
 #define _config_h
 #define MAX_Q_GEMM_ROWS 50
+#define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS
 #define QMODE_2BIT 1
 #define QMODE_3BIT 1
@@ -10,4 +11,5 @@
 #define QMODE_6BIT 0
 #define QMODE_8BIT 0
 #endif
server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu

@@ -10,16 +10,19 @@
 #include "quant/qdq_6.cuh"
 #include "quant/qdq_8.cuh"
-#define BLOCK_KN_SIZE 128
-#define BLOCK_M_SIZE_MAX 8
-#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
+#define GPTQ_BLOCK_KN_SIZE 128
+#define GPTQ_BLOCK_M_SIZE_MAX 8
+#define GPTQ_MAX_GROUPS_IN_BLOCK (GPTQ_BLOCK_KN_SIZE / 32)
+#define EXL2_BLOCK_KN_SIZE 64
+#define EXL2_BLOCK_M_SIZE_MAX 8
+#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32)
 #define CLEAR_N_SIZE 256
 #include "q_gemm_kernel.cuh"
 #include "q_gemm_kernel_gptq.cuh"
-#include "compat_gemm.cuh"
 void gemm_half_q_half_cuda_part
 (
     const half* a,
@@ -29,20 +32,23 @@ void gemm_half_q_half_cuda_part
     int size_n,
     int size_k,
     int m_count,
-    bool clear
+    bool clear,
+    const half* r_weights,
+    int r_weights_stride,
+    bool mul_r_weights
 )
 {
     if (!b->is_gptq)
     {
         dim3 blockDim, gridDim;
-        blockDim.x = BLOCK_KN_SIZE;
+        blockDim.x = EXL2_BLOCK_KN_SIZE;
         blockDim.y = 1;
         blockDim.z = 1;
-        gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
+        gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4);
         gridDim.y = DIVIDE(size_m, m_count);
-        gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
+        gridDim.z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE);
-        fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
+        fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);
         kernel<<<gridDim, blockDim>>>
         (
@@ -55,7 +61,7 @@ void gemm_half_q_half_cuda_part
             size_n,
             size_k,
             b->groups,
-            b->groupsize,
+            b->cuda_q_group_map,
             b->cuda_q_perm,
             b->rows_8,
             b->rows_6,
@@ -63,24 +69,27 @@ void gemm_half_q_half_cuda_part
             b->rows_4,
             b->rows_3,
             b->rows_2,
-            clear
+            clear,
+            r_weights,
+            r_weights_stride
         );
     }
     else
     {
         dim3 blockDim, gridDim;
-        blockDim.x = BLOCK_KN_SIZE;
+        blockDim.x = GPTQ_BLOCK_KN_SIZE;
         blockDim.y = 1;
         blockDim.z = 1;
-        gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
+        gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4);
         gridDim.y = DIVIDE(size_m, m_count);
-        gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
+        gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE);
-        fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
+        fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(m_count, r_weights != NULL, mul_r_weights);
-        // DBGX((uint64_t) b->cuda_q_perm);
-        // DBGI(b->rows_4);
-        // DBGI(b->height);
+        // DBGX((uint64_t) r_weights);
+        // if (r_weights)
+        //     print_global_mem(r_weights, 1, 1, 1);
+        // DBGI(r_weights_stride);
         kernel<<<gridDim, blockDim>>>
         (
@@ -93,10 +102,12 @@ void gemm_half_q_half_cuda_part
             size_n,
             size_k,
             b->groups,
-            b->groupsize,
+            b->gptq_groupsize,
             b->cuda_q_perm,
             b->rows_4,
-            clear
+            clear,
+            r_weights,
+            r_weights_stride
         );
     }
 }
@@ -112,13 +123,14 @@ void gemm_half_q_half_cuda
     int size_k,
     bool clear,
     half* temp_dq,
-    bool force_cuda
+    bool force_cuda,
+    const half* r_weights,
+    const int r_weights_stride,
+    bool mul_r_weights
 )
 {
     if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
     {
-        //printf("cublas\n");
         // Reconstruct FP16 matrix, then cuBLAS
         if (!temp_dq) temp_dq = b->temp_dq;
@@ -139,12 +151,12 @@ void gemm_half_q_half_cuda
         //const float alpha = 1.0f;
         //const float beta = clear ? 0.0f : 1.0f;
         //cublasSgemmEx(cublas_handle,
         //              CUBLAS_OP_N,
         //              CUBLAS_OP_N,
         //              size_n, size_m, size_k,
         //              &alpha, temp_dq, CUDA_R_16F, size_n,
         //              a, CUDA_R_16F, size_k,
         //              &beta, c, CUDA_R_16F, size_n);
         //const float alpha = 1.0f;
         //const float beta = clear ? 0.0f : 1.0f;
@@ -158,24 +170,21 @@ void gemm_half_q_half_cuda
     }
     else
     {
-        //printf("cuda\n");
         // Quantized matmul
-        //if (clear) clear_tensor_cuda(c, size_m, size_n);
-        int max_chunks = size_m / BLOCK_M_SIZE_MAX;
-        int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
+        int block_m_size_max = b->is_gptq ? GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX;
+        int max_chunks = size_m / block_m_size_max;
+        int last_chunk = max_chunks * block_m_size_max;
         int last_chunk_size = size_m - last_chunk;
         if (max_chunks)
         {
-            gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
+            gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights);
         }
         if (last_chunk_size)
         {
-            gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
+            gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights);
         }
     }
 }
@@ -201,11 +210,10 @@ void clear_tensor_cuda
     int size_n
 )
 {
+    return;
-    dim3 blockDim, gridDim;
-    blockDim.x = CLEAR_N_SIZE;
-    blockDim.y = 1;
-    gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
-    gridDim.y = size_m;
-    clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
+    // dim3 blockDim, gridDim;
+    // blockDim.x = CLEAR_N_SIZE;
+    // blockDim.y = 1;
+    // gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
+    // gridDim.y = size_m;
+    // clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
 }
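The host-side change above splits size_m into full chunks of block_m_size_max rows plus one remainder launch, with the chunk size now picked per format (GPTQ vs. EXL2). Below is a minimal host-side sketch of that chunking arithmetic only; the launch() stub and the sample sizes are illustrative stand-ins for gemm_half_q_half_cuda_part, not code from the commit.

#include <cstdio>

// Hypothetical stand-in for gemm_half_q_half_cuda_part: only the row bookkeeping matters here.
static void launch(int row_offset, int rows, int m_count)
{
    std::printf("launch rows [%d, %d) with m_count=%d\n", row_offset, row_offset + rows, m_count);
}

int main()
{
    const int size_m = 19;            // total rows of activations (sample value)
    const int block_m_size_max = 8;   // plays the role of GPTQ_/EXL2_BLOCK_M_SIZE_MAX
    int max_chunks = size_m / block_m_size_max;       // 2 full chunks
    int last_chunk = max_chunks * block_m_size_max;   // 16 rows covered by full chunks
    int last_chunk_size = size_m - last_chunk;        // 3 remaining rows
    if (max_chunks) launch(0, last_chunk, block_m_size_max);
    if (last_chunk_size) launch(last_chunk, last_chunk_size, last_chunk_size);
    return 0;
}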
server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh

@@ -20,7 +20,10 @@ void gemm_half_q_half_cuda
     int size_k,
     bool clear = false,
     half* reconstruct = NULL,
-    bool force_cuda = false
+    bool force_cuda = false,
+    const half* r_weights = NULL,
+    const int r_weights_stride = 0,
+    bool mul_r_weights = false
 );

 void clear_tensor_cuda
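Because the three new parameters are declared with defaults, existing call sites keep compiling unchanged; only callers that want per-row weighting pass the extra arguments. The following is a simplified host-only sketch of that compatibility point; the stubbed body and the half_t stand-in for CUDA's half are assumptions for illustration, and these are not the real declaration or TGI call sites.

#include <cstddef>

typedef unsigned short half_t;  // host-side stand-in for CUDA's half in this sketch only

// Mirrors the shape of the updated declaration: new parameters are defaulted.
void gemm_half_q_half_cuda_sketch(const half_t* a, half_t* c, int size_m, int size_n, int size_k,
                                  bool clear = false, half_t* reconstruct = NULL, bool force_cuda = false,
                                  const half_t* r_weights = NULL, const int r_weights_stride = 0,
                                  bool mul_r_weights = false)
{
    (void)a; (void)c; (void)size_m; (void)size_n; (void)size_k;
    (void)clear; (void)reconstruct; (void)force_cuda;
    (void)r_weights; (void)r_weights_stride; (void)mul_r_weights;
}

int main()
{
    half_t a[8] = {}, c[8] = {}, w[2] = {};
    gemm_half_q_half_cuda_sketch(a, c, 2, 2, 2);                                   // old-style call still compiles
    gemm_half_q_half_cuda_sketch(a, c, 2, 2, 2, true, NULL, false, w, 1, true);    // opting in to row weights
    return 0;
}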
server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh

 #include "compat.cuh"
-#include <cuda_runtime.h>
-#include <cuda_fp16.h>
 __forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
 {
     half2 result = {};
@@ -60,6 +57,47 @@ __forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, c
     return fma(result_f, qs_f, g_result);
 }
+__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h)
+{
+    // Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127
+    float result = {};
+    #pragma unroll
+    for (int i = 0; i < 4; i++)
+    {
+        half2 w01 = dq[i];
+        float w0 = __low2float(w01);
+        float w1 = __high2float(w01);
+        float x0 = __half2float(*a_ptr++);
+        float x1 = __half2float(*a_ptr++);
+        result = fma(w0, x0, result);
+        result = fma(w1, x1, result);
+    }
+    float qs = __half2float(qs_h);
+    result *= qs;
+    half result_h = __float2half_rn(result);
+    return __hadd(result_h, g_result);
+}
+__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h)
+{
+    half2 result = {};
+    const half2* a2_ptr = (const half2*) a_ptr;
+    #pragma unroll
+    for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+    half result_h = __hadd(__low2half(result), __high2half(result));
+    return __hfma(result_h, qs_h, g_result);
+}
+__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h)
+{
+    half2 result = {};
+    const half2* a2_ptr = (const half2*) a_ptr;
+    #pragma unroll
+    for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
+    half result_h = __hadd(__low2half(result), __high2half(result));
+    return __hfma(result_h, qs_h, g_result);
+}
 typedef void (*fp_gemm_half_q_half_kernel)
@@ -73,7 +111,7 @@ typedef void (*fp_gemm_half_q_half_kernel)
     const int,
     const int,
     const int,
-    const int,
+    const uint16_t*,
     const uint16_t*,
     const int,
     const int,
@@ -81,10 +119,12 @@ typedef void (*fp_gemm_half_q_half_kernel)
     const int,
     const int,
     const int,
-    const bool
+    const bool,
+    const half*,
+    const int
 );
-template <bool first_block, int m_count>
+template <int m_count, bool use_r_weights, bool mul_r_weights>
 __global__ void gemm_half_q_half_kernel
 (
     const half* __restrict__ a,
@@ -96,7 +136,7 @@ __global__ void gemm_half_q_half_kernel
     const int size_n,
     const int size_k,
     const int groups,
-    const int groupsize,
+    const uint16_t* __restrict__ b_q_group_map,
     const uint16_t* __restrict__ b_q_perm,
     const int rows_8,
     const int rows_6,
@@ -104,7 +144,9 @@ __global__ void gemm_half_q_half_kernel
     const int rows_4,
     const int rows_3,
     const int rows_2,
-    const bool clear
+    const bool clear,
+    const half* r_weights,
+    const int r_weights_stride
 )
 {
     MatrixView_half a_(a, size_m, size_k);
@@ -115,18 +157,34 @@ __global__ void gemm_half_q_half_kernel
     // Block
-    int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+    int offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4;
     int offset_m = blockIdx.y * m_count;
-    int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+    int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE;
-    int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
+    int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n);
     int end_m = min(offset_m + m_count, size_m);
-    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
+    int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, size_k);
     int n = offset_n + t * 4;
+    // Read weights
+    half_uint16 weights[MAX_Q_GEMM_WEIGHTS];
+    if constexpr (use_r_weights)
+    {
+        uint16_t any_w = 0;
+        const half* w_ptr = r_weights;
+        for (int m = 0; m < m_count; ++m)
+        {
+            weights[m].as_half = *w_ptr;
+            w_ptr += r_weights_stride;
+            any_w |= weights[m].as_uint16;
+        }
+        if (!any_w) return;  // Early exit if all weights are zero -- does not zero output (!!!)
+    }
     // Preload block_a
-    __shared__ half block_a[m_count][BLOCK_KN_SIZE];
+    __shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE];
     if (offset_k + t < end_k)
     {
@@ -135,6 +193,7 @@ __global__ void gemm_half_q_half_kernel
             const half* a_ptr = a_.item_ptr(offset_m + m, 0);
             half* block_a_ptr = block_a[m];
             half a0 = a_ptr[b_q_perm[offset_k + t]];
+            // half a0 = a_ptr[offset_k + t];
             block_a_ptr[t] = a0;
         }
     }
@@ -153,14 +212,19 @@ __global__ void gemm_half_q_half_kernel
     // Find initial group
-    int group = offset_k / groupsize;
+    //int group = offset_k / groupsize;
+    int group = b_q_group_map[offset_k * 2];
+    // if (offset_m == 0 && t == 0)
+    //     DBGI2(offset_k, group);
     // Preload scales
-    float scales[MAX_GROUPS_IN_BLOCK][4];
+    half scales[EXL2_MAX_GROUPS_IN_BLOCK][4];
-    int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
-    for (int g = 0; g < groups_in_block; g++)
+    //int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
+    int temp_k = offset_k;
+    for (int g = 0; temp_k < end_k; g++)
     {
         int qscales[4];
         b_q_scale_.item4(qscales, group + g, n);
@@ -168,11 +232,12 @@ __global__ void gemm_half_q_half_kernel
         qscales[1]++;
         qscales[2]++;
         qscales[3]++;
-        float maxscale = __half2float(b_q_scale_max[group + g]);
-        scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale;
-        scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale;
-        scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale;
-        scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale;
+        half maxscale = b_q_scale_max[group + g];
+        scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale);
+        scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale);
+        scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale);
+        scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale);
+        temp_k += b_q_group_map[temp_k * 2 + 1];
     }
     // a, b offset
@@ -193,20 +258,20 @@ __global__ void gemm_half_q_half_kernel
     const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
     const half* a_ptr = &block_a[0][0];
-    int a_stride = BLOCK_KN_SIZE;
+    int a_stride = EXL2_BLOCK_KN_SIZE;
     // Initial group
     int scales_idx = 0;
-    float qs_f0 = scales[scales_idx][0];
-    float qs_f1 = scales[scales_idx][1];
-    float qs_f2 = scales[scales_idx][2];
-    float qs_f3 = scales[scales_idx][3];
-    int nextgroup = offset_k + groupsize;
+    half qs_h0 = scales[scales_idx][0];
+    half qs_h1 = scales[scales_idx][1];
+    half qs_h2 = scales[scales_idx][2];
+    half qs_h3 = scales[scales_idx][3];
+    int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
     // Column result
-    float block_c[m_count][4] = {};
+    half block_c[m_count][4] = {};
     // Dequantize groups
@@ -218,11 +283,11 @@ __global__ void gemm_half_q_half_kernel
         {
             group++;
             scales_idx++;
-            qs_f0 = scales[scales_idx][0];
-            qs_f1 = scales[scales_idx][1];
-            qs_f2 = scales[scales_idx][2];
-            qs_f3 = scales[scales_idx][3];
-            nextgroup += groupsize;
+            qs_h0 = scales[scales_idx][0];
+            qs_h1 = scales[scales_idx][1];
+            qs_h2 = scales[scales_idx][2];
+            qs_h3 = scales[scales_idx][3];
+            nextgroup += b_q_group_map[k * 2 + 1];
         }
         #pragma unroll
(The same group-advance change repeats in the per-bit-width loops at @@ -256,11 +322,11 @@, @@ -295,11 +362,11 @@, @@ -337,11 +405,11 @@ and @@ -374,11 +443,11 @@.)
@@ -240,10 +305,11 @@ __global__ void gemm_half_q_half_kernel
         for (int m = 0; m < m_count; m++)
         {
-            block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
-            block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
-            block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
-            block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
+            if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
+            block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
+            block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
+            block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
+            block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
         }
         a_ptr += 8;
     }
(The same use_r_weights skip plus the switch from the dot22_*_f accumulators to dot22_16_h / dot22_32_h repeats at @@ -279,10 +345,11 @@, @@ -320,10 +387,11 @@, @@ -358,10 +426,11 @@, @@ -397,10 +466,11 @@ and @@ -434,15 +504,16 @@.)
@@ -413,15 +483,15 @@ __global__ void gemm_half_q_half_kernel
         {
             group++;
             scales_idx++;
-            qs_f0 = scales[scales_idx][0];
-            qs_f1 = scales[scales_idx][1];
-            qs_f2 = scales[scales_idx][2];
-            qs_f3 = scales[scales_idx][3];
-            nextgroup += groupsize;
+            qs_h0 = scales[scales_idx][0];
+            qs_h1 = scales[scales_idx][1];
+            qs_h2 = scales[scales_idx][2];
+            qs_h3 = scales[scales_idx][3];
+            nextgroup += b_q_group_map[k * 2 + 1];
         }
         #pragma unroll
-        for (int j = 0; j < 2; j++)
+        for (int j = 0; j < 1; j++)
         {
             int4 load_int4[1];
             load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
@@ -434,15 +504,16 @@ __global__ void gemm_half_q_half_kernel
             a_ptr += 16;
         }
-        k += 32;
+        k += 16;
     }
     // Accumulate column sums in c
@@ -450,38 +521,60 @@ __global__ void gemm_half_q_half_kernel
     for (int m = 0; m < m_count; m++)
     {
         half2* out = (half2*) c_.item_ptr(offset_m + m, n);
-        half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
-        half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
+        half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
+        half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
+        if constexpr (mul_r_weights)
+        {
+            half2 w_mul2 = __half2half2(weights[m].as_half);
+            result01 = __hmul2(result01, w_mul2);
+            result23 = __hmul2(result23, w_mul2);
+        }
         atomicAdd(out    , result01);
         atomicAdd(out + 1, result23);
+        // *out = result01;
+        // *(out + 1) = result23;
     }
 }
-fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count)
-{
-    #if BLOCK_M_SIZE_MAX >= 1
-    if (m_count == 1) return gemm_half_q_half_kernel<true, 1>;
-    #endif
-    #if BLOCK_M_SIZE_MAX >= 2
-    if (m_count == 2) return gemm_half_q_half_kernel<true, 2>;
-    #endif
-    ... (the same pattern continues for m_count 3 through 8) ...
-    return NULL;
-}
+template <bool use_r_weights, bool mul_r_weights>
+struct map_m_count_exl2 {
+    static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count)
+    {
+        #if EXL2_BLOCK_M_SIZE_MAX >= 1
+        if (m_count == 1) return gemm_half_q_half_kernel<1, use_r_weights, mul_r_weights>;
+        #endif
+        #if EXL2_BLOCK_M_SIZE_MAX >= 2
+        if (m_count == 2) return gemm_half_q_half_kernel<2, use_r_weights, mul_r_weights>;
+        #endif
+        ... (likewise for m_count 3 through 8) ...
+        return NULL;
+    }
+};
+fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count, bool r_weights, bool mul_r_weights)
+{
+    if (!r_weights && !mul_r_weights) return map_m_count_exl2<false, false>::pick_gemm_half_q_half_kernel(m_count);
+    if (!r_weights &&  mul_r_weights) return map_m_count_exl2<false, true >::pick_gemm_half_q_half_kernel(m_count);
+    if ( r_weights && !mul_r_weights) return map_m_count_exl2<true,  false>::pick_gemm_half_q_half_kernel(m_count);
+    if ( r_weights &&  mul_r_weights) return map_m_count_exl2<true,  true >::pick_gemm_half_q_half_kernel(m_count);
+    return NULL;
+}
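The comment in the new dot22_8_h explains why it accumulates in FP32: at that point the dequantized weights are still unscaled integers of magnitude up to roughly 127, so eight such products against FP16 activations can exceed the largest finite FP16 value (about 65504) before the group scale is applied. Below is a small host-side illustration of that head-room argument; the saturating to_fp16_like helper and the sample numbers are assumptions used only to mimic FP16 clamping, not CUDA code.

#include <algorithm>
#include <cstdio>

// Toy model of FP16 accumulation: clamp magnitude to the largest finite half value.
static float to_fp16_like(float x) { return std::max(-65504.0f, std::min(65504.0f, x)); }

int main()
{
    float w[8], x[8];
    for (int i = 0; i < 8; i++) { w[i] = 127.0f; x[i] = 100.0f; }  // unscaled weights x activations

    float acc16 = 0.0f, acc32 = 0.0f;
    for (int i = 0; i < 8; i++)
    {
        acc16 = to_fp16_like(acc16 + to_fp16_like(w[i] * x[i]));   // fp16-style accumulation saturates
        acc32 += w[i] * x[i];                                       // fp32 accumulation keeps head-room
    }
    float scale = 1.0f / 4096.0f;  // group scale applied after the dot product, as in dot22_8_h
    std::printf("fp16-style result: %g   fp32-then-scale result: %g\n", acc16 * scale, acc32 * scale);
    return 0;
}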
server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh

@@ -18,6 +18,15 @@ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
     return __half2float(__low2half(result)) + __half2float(__high2half(result));
 }
+__forceinline__ __device__ half2 dot22_8_h2(half2(&dq)[4], const half* a_ptr)
+{
+    half2 result = {};
+    const half2* a2_ptr = (const half2*) a_ptr;
+    #pragma unroll
+    for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+    return result;
+}
 typedef void (*fp_gemm_half_q_half_gptq_kernel)
 (
     const half*,
     const half*,
@@ -32,10 +41,12 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)
     const int,
     const int,
     const uint16_t*,
     const int,
-    const bool
+    const bool,
+    const half*,
+    const int
 );
-template <bool first_block, int m_count>
+template <int m_count, bool use_r_weights, bool mul_r_weights>
 __global__ void gemm_half_q_half_gptq_kernel
 (
     const half* __restrict__ a,
@@ -50,7 +61,9 @@ __global__ void gemm_half_q_half_gptq_kernel
     const int groupsize,
     const uint16_t* __restrict__ b_q_perm,
     const int rows_4,
-    const bool clear
+    const bool clear,
+    const half* r_weights,
+    const int r_weights_stride
 )
 {
     MatrixView_half a_(a, size_m, size_k);
@@ -62,19 +75,35 @@ __global__ void gemm_half_q_half_gptq_kernel
     // Block
-    int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+    int offset_n = blockIdx.x * GPTQ_BLOCK_KN_SIZE * 4;
     int offset_m = blockIdx.y * m_count;
-    int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+    int offset_k = blockIdx.z * GPTQ_BLOCK_KN_SIZE;
-    int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
+    int end_n = min(offset_n + GPTQ_BLOCK_KN_SIZE * 4, size_n);
     int end_m = min(offset_m + m_count, size_m);
-    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
+    int end_k = min(offset_k + GPTQ_BLOCK_KN_SIZE, size_k);
     int n = offset_n + t * 4;
+    // Read weights
+    half_uint16 weights[MAX_Q_GEMM_WEIGHTS];
+    if constexpr (use_r_weights)
+    {
+        uint16_t any_w = 0;
+        const half* w_ptr = r_weights;
+        for (int m = 0; m < m_count; ++m)
+        {
+            weights[m].as_half = *w_ptr;
+            w_ptr += r_weights_stride;
+            any_w |= weights[m].as_uint16;
+        }
+        if (!any_w) return;  // Early exit if all weights are zero -- does not zero output (!!!)
+    }
     // Preload block_a
-    __shared__ half block_a[m_count][BLOCK_KN_SIZE];
+    __shared__ half block_a[m_count][GPTQ_BLOCK_KN_SIZE];
     if (offset_k + t < end_k)
     {
@@ -113,16 +142,16 @@ __global__ void gemm_half_q_half_gptq_kernel
     const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
     const half* a_ptr = &block_a[0][0];
-    int a_stride = BLOCK_KN_SIZE;
+    int a_stride = GPTQ_BLOCK_KN_SIZE;
     // Initial group
     int zeros[4];
-    float scales[4];
+    half2 scales[4];
     half2 z1z16[4][2];
     half2 y1y16[4][2];
     b_gptq_qzeros_.item4(zeros, group, n);
-    b_gptq_scales_.item4_f(scales, group, n);
+    b_gptq_scales_.item4_h2(scales, group, n);
     dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
     dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
     dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
@@ -132,7 +161,7 @@ __global__ void gemm_half_q_half_gptq_kernel
     // Column result
-    float block_c[m_count][4] = {};
+    half2 block_c[m_count][4] = {};
     // Dequantize and multiply
@@ -144,7 +173,7 @@ __global__ void gemm_half_q_half_gptq_kernel
             group++;
             nextgroup += groupsize;
             b_gptq_qzeros_.item4(zeros, group, n);
-            b_gptq_scales_.item4_f(scales, group, n);
+            b_gptq_scales_.item4_h2(scales, group, n);
             dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
             dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
             dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
@@ -166,10 +195,11 @@ __global__ void gemm_half_q_half_gptq_kernel
         #pragma unroll
         for (int m = 0; m < m_count; m++)
         {
-            block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
-            block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
-            block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
-            block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
+            if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
+            block_c[m][0] = __hfma2(dot22_8_h2(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
+            block_c[m][1] = __hfma2(dot22_8_h2(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
+            block_c[m][2] = __hfma2(dot22_8_h2(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
+            block_c[m][3] = __hfma2(dot22_8_h2(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
         }
         b_ptr += size_n;
@@ -182,38 +212,62 @@ __global__ void gemm_half_q_half_gptq_kernel
         for (int m = 0; m < m_count; m++)
         {
             half2* out = (half2*) c_.item_ptr(offset_m + m, n);
-            half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
-            half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
+            half result0 = __hadd(__low2half(block_c[m][0]), __high2half(block_c[m][0]));
+            half result1 = __hadd(__low2half(block_c[m][1]), __high2half(block_c[m][1]));
+            half result2 = __hadd(__low2half(block_c[m][2]), __high2half(block_c[m][2]));
+            half result3 = __hadd(__low2half(block_c[m][3]), __high2half(block_c[m][3]));
+            half2 result01 = __halves2half2(result0, result1);
+            half2 result23 = __halves2half2(result2, result3);
+            if constexpr (mul_r_weights)
+            {
+                half2 w_mul2 = __half2half2(weights[m].as_half);
+                result01 = __hmul2(result01, w_mul2);
+                result23 = __hmul2(result23, w_mul2);
+            }
             atomicAdd(out    , result01);
             atomicAdd(out + 1, result23);
         }
     }
 }
-fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
-{
-    #if BLOCK_M_SIZE_MAX >= 1
-    if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
-    #endif
-    #if BLOCK_M_SIZE_MAX >= 2
-    if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
-    #endif
-    ... (the same pattern continues for m_count 3 through 8) ...
-    return NULL;
-}
+template <bool use_r_weights, bool mul_r_weights>
+struct map_m_count_gptq {
+    static constexpr fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(int m_count)
+    {
+        #if GPTQ_BLOCK_M_SIZE_MAX >= 1
+        if (m_count == 1) return gemm_half_q_half_gptq_kernel<1, use_r_weights, mul_r_weights>;
+        #endif
+        #if GPTQ_BLOCK_M_SIZE_MAX >= 2
+        if (m_count == 2) return gemm_half_q_half_gptq_kernel<2, use_r_weights, mul_r_weights>;
+        #endif
+        ... (likewise for m_count 3 through 8) ...
+        return NULL;
+    }
+};
+fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(const int m_count, bool r_weights, bool mul_r_weights)
+{
+    if (!r_weights && !mul_r_weights) return map_m_count_gptq<false, false>::pick_gemm_half_q_half_gptq_kernel(m_count);
+    if (!r_weights &&  mul_r_weights) return map_m_count_gptq<false, true >::pick_gemm_half_q_half_gptq_kernel(m_count);
+    if ( r_weights && !mul_r_weights) return map_m_count_gptq<true,  false>::pick_gemm_half_q_half_gptq_kernel(m_count);
+    if ( r_weights &&  mul_r_weights) return map_m_count_gptq<true,  true >::pick_gemm_half_q_half_gptq_kernel(m_count);
+    return NULL;
+}
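Both kernel pickers now use the same two-stage dispatch: a struct templated on the two booleans maps the runtime m_count to a compile-time kernel instantiation, and the top-level function maps the runtime booleans to one of the four struct instantiations. Below is a minimal host-side sketch of that pattern with a dummy "kernel" and shortened names; the names here are illustrative, not the real CUDA kernels.

#include <cstdio>

typedef void (*kernel_fn)();

// Dummy stand-in templated the same way as gemm_half_q_half_gptq_kernel<m_count, use_r_weights, mul_r_weights>.
template <int m_count, bool use_r_weights, bool mul_r_weights>
void dummy_kernel()
{
    std::printf("m_count=%d use_r_weights=%d mul_r_weights=%d\n", m_count, use_r_weights, mul_r_weights);
}

// Stage 1: runtime m_count -> compile-time instantiation, for fixed boolean template arguments.
template <bool use_r_weights, bool mul_r_weights>
struct map_m_count
{
    static kernel_fn pick(int m_count)
    {
        if (m_count == 1) return dummy_kernel<1, use_r_weights, mul_r_weights>;
        if (m_count == 2) return dummy_kernel<2, use_r_weights, mul_r_weights>;
        return nullptr;  // unsupported m_count in this sketch
    }
};

// Stage 2: runtime booleans -> one of the four struct instantiations.
kernel_fn pick(int m_count, bool r_weights, bool mul_r_weights)
{
    if (!r_weights && !mul_r_weights) return map_m_count<false, false>::pick(m_count);
    if (!r_weights &&  mul_r_weights) return map_m_count<false, true >::pick(m_count);
    if ( r_weights && !mul_r_weights) return map_m_count<true,  false>::pick(m_count);
    return map_m_count<true, true>::pick(m_count);
}

int main() { pick(2, true, false)(); return 0; }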
server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu

@@ -57,6 +57,7 @@ QMatrix::QMatrix
     uint32_t* _q_scale,
     half* _q_scale_max,
     uint16_t* _q_groups,
+    uint16_t* _q_group_map,
     uint32_t* _gptq_qzeros,
     half* _gptq_scales,
@@ -80,13 +81,17 @@ QMatrix::QMatrix
     cuda_q_scale = _q_scale;
     cuda_q_scale_max = _q_scale_max;
     cuda_q_groups = _q_groups;
+    cuda_q_group_map = _q_group_map;
     cuda_gptq_qzeros = _gptq_qzeros;
     cuda_gptq_scales = _gptq_scales;
     is_gptq = (_gptq_qzeros != NULL);
-    groupsize = 1;
-    while (groupsize * groups < height) groupsize *= 2;
+    if (is_gptq)
+    {
+        gptq_groupsize = 1;
+        while (gptq_groupsize * groups < height) gptq_groupsize *= 2;
+    }
     // Create group map
@@ -102,15 +107,26 @@ QMatrix::QMatrix
         uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
         cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
+        int row = 0;
         for (int i = 0; i < groups; i++)
         {
             int bits = cpu_q_groups[i * 2];
-            if (bits == 8) rows_8 += groupsize;
-            if (bits == 6) rows_6 += groupsize;
-            if (bits == 5) rows_5 += groupsize;
-            if (bits == 4) rows_4 += groupsize;
-            if (bits == 3) rows_3 += groupsize;
-            if (bits == 2) rows_2 += groupsize;
+            int rows;
+            if (i < groups - 1)
+            {
+                int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1];
+                rows = qrows * 32 / bits;
+            }
+            else rows = height - row;
+            if (bits == 8) rows_8 += rows;
+            if (bits == 6) rows_6 += rows;
+            if (bits == 5) rows_5 += rows;
+            if (bits == 4) rows_4 += rows;
+            if (bits == 3) rows_3 += rows;
+            if (bits == 2) rows_2 += rows;
+            row += rows;
         }
         free(cpu_q_groups);
@@ -138,6 +154,13 @@ QMatrix::QMatrix
         }
     }
+    // DBGI(rows_8);
+    // DBGI(rows_6);
+    // DBGI(rows_5);
+    // DBGI(rows_4);
+    // DBGI(rows_3);
+    // DBGI(rows_2);
     // Shuffle quantized data
     dim3 blockDim, gridDim;
@@ -283,10 +306,10 @@ __global__ void reconstruct_kernel
     const uint16_t* __restrict__ b_q_perm,
     const uint32_t* __restrict__ b_q_scale,
     const half* __restrict__ b_q_scale_max,
-    //const uint16_t* __restrict__ b_q_groups,
+    const uint16_t* __restrict__ b_q_group_map,
     const int size_k,
     const int size_n,
-    const int groupsize,
+    //const int groupsize,
     const int groups,
     half* __restrict__ b,
     const int rows_8,
@@ -317,7 +340,8 @@ __global__ void reconstruct_kernel
     // Find initial group
-    int group = offset_k / groupsize;
+    // int group = offset_k / groupsize;
+    int group = b_q_group_map[offset_k * 2];
     int pre_rows_8 = min(rows_8, offset_k);
     int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
@@ -337,7 +361,7 @@ __global__ void reconstruct_kernel
     half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
     half2 qs_h2 = __halves2half2(qs_h, qs_h);
-    int nextgroup = offset_k + groupsize;
+    int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
     int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
     int k = offset_k;
@@ -347,7 +371,7 @@ __global__ void reconstruct_kernel
     while (k < rows_8 && k < end_k)
     {
-        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
+        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
         for (int p = 0; p < 4; p++)
         {
             half2 dq[4];
(The same one-line change, nextgroup += groupsize becoming nextgroup += b_q_group_map[k * 2 + 1], repeats in the rows_6, rows_5, rows_4 and rows_3 loops at @@ -363,7 +387,7 @@, @@ -380,7 +404,7 @@, @@ -399,7 +423,7 @@ and @@ -414,7 +438,7 @@.)
@@ -431,8 +455,8 @@ __global__ void reconstruct_kernel
     while (k < rows_2 && k < end_k)
     {
-        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
-        for (int p = 0; p < 2; p++)
+        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
+        for (int p = 0; p < 1; p++)
         {
             half2 dq[8];
             uint32_t q_0 = *b_ptr; b_ptr += size_n;
@@ -441,7 +465,7 @@ __global__ void reconstruct_kernel
             half* dqh = (half*) dq;
             for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
         }
-        k += 32;
+        k += 16;
     }
 }
@@ -461,10 +485,10 @@ void QMatrix::reconstruct(half* out)
         cuda_q_perm,
         cuda_q_scale,
         cuda_q_scale_max,
-        //cuda_q_groups,
+        cuda_q_group_map,
         height,
         width,
-        groupsize,
+        //groupsize,
         groups,
         out,
         rows_8,
@@ -487,7 +511,7 @@ void QMatrix::reconstruct(half* out)
         //const uint16_t* __restrict__ b_q_groups,
         height,
         width,
-        groupsize,
+        gptq_groupsize,
         groups,
         out,
         rows_4
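Throughout these kernels the new map is read as b_q_group_map[k * 2] (which group row k belongs to) and b_q_group_map[k * 2 + 1] (how many rows remain in that group starting at row k), which is what lets EXL2 tensors mix group sizes instead of assuming one fixed groupsize. The sketch below builds a host-side map with that layout from per-group row counts; the construction itself is an assumption for illustration, since the commit's own map-building code is not visible in these hunks.

#include <cstdint>
#include <cstdio>
#include <vector>

// Build map[k*2] = group index of row k, map[k*2+1] = rows remaining in that group from row k.
std::vector<uint16_t> build_group_map(const std::vector<int>& rows_per_group)
{
    std::vector<uint16_t> map;
    for (size_t g = 0; g < rows_per_group.size(); ++g)
        for (int r = 0; r < rows_per_group[g]; ++r)
        {
            map.push_back((uint16_t) g);                         // group of this row
            map.push_back((uint16_t) (rows_per_group[g] - r));   // rows until the next group starts
        }
    return map;
}

int main()
{
    std::vector<uint16_t> map = build_group_map({4, 2});  // two groups of different sizes
    for (size_t k = 0; k * 2 < map.size(); ++k)
        std::printf("row %zu: group %u, %u rows to next group\n", k, map[k * 2], map[k * 2 + 1]);
    return 0;
}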
server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh

@@ -18,7 +18,7 @@ public:
     int height;
     int width;
     int groups;
-    int groupsize;
+    int gptq_groupsize;
     int rows_8;
     int rows_6;
@@ -33,6 +33,7 @@ public:
     uint32_t* cuda_q_scale = NULL;
     half* cuda_q_scale_max = NULL;
     uint16_t* cuda_q_groups = NULL;
+    uint16_t* cuda_q_group_map = NULL;
     uint32_t* cuda_gptq_qzeros = NULL;
     half* cuda_gptq_scales = NULL;
@@ -53,6 +54,7 @@ public:
         uint32_t* _q_scale,
         half* _q_scale_max,
         uint16_t* _q_groups,
+        uint16_t* _q_group_map,
         uint32_t* _gptq_qzeros,
         half* _gptq_scales,
server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh

@@ -7,6 +7,7 @@ union half2_uint32
     half2 as_half2;
     __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
    __device__ half2_uint32(half2 val) : as_half2(val) {}
+    __device__ half2_uint32() : as_uint32(0) {}
 };

 union half_uint16
@@ -15,6 +16,7 @@ union half_uint16
     half as_half;
     __device__ half_uint16(uint16_t val) : as_uint16(val) {}
     __device__ half_uint16(half val) : as_half(val) {}
+    __device__ half_uint16() : as_uint16(0) {}
 };

 // Max_scale premultiplied by 1/256
server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh
View file @ 564199ba

#ifndef _util_cuh
#define _util_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>

#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
@@ -40,3 +48,7 @@ inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=
         if (abort) exit(code);
     }
 }
+
+void print_global_mem(const half* ptr, int rows, int columns, int stride);
+
 #endif
\ No newline at end of file
server/exllamav2_kernels/exllamav2_kernels/ext.cpp
View file @ 564199ba

@@ -31,6 +31,7 @@ uintptr_t make_q_matrix
     torch::Tensor q_scale,
     torch::Tensor q_scale_max,
     torch::Tensor q_groups,
+    torch::Tensor q_group_map,
     torch::Tensor gptq_qzeros,
     torch::Tensor gptq_scales,
     torch::Tensor gptq_g_idx,
@@ -43,6 +44,7 @@ uintptr_t make_q_matrix
     TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
     TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
     TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
+    TORCH_CHECK_DTYPE_OPT(q_group_map, kShort);
     TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
     TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
     TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
@@ -83,12 +85,15 @@ uintptr_t make_q_matrix
         q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
         q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
         q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
+        q_group_map.device().is_meta() ? NULL : (uint16_t*) q_group_map.data_ptr(),
         gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
         gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
         gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
         (half*) temp_dq.data_ptr()
     );
+
+    if (m->failed) throw std::runtime_error("CUDA out of memory");
+
     return reinterpret_cast<uintptr_t>(m);
 }
server/tests/utils/test_hub.py
View file @ 564199ba

@@ -32,10 +32,10 @@ def fresh_cache():
         current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
         huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
         text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d
-        os.environ['HUGGINGFACE_HUB_CACHE'] = d
+        os.environ["HUGGINGFACE_HUB_CACHE"] = d
         yield
         huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value
-        os.environ['HUGGINGFACE_HUB_CACHE'] = current_value
+        os.environ["HUGGINGFACE_HUB_CACHE"] = current_value
         text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value
@@ -47,7 +47,7 @@ def prefetched():
         revision="main",
         local_files_only=False,
         repo_type="model",
-        allow_patterns=["*.safetensors"]
+        allow_patterns=["*.safetensors"],
     )
     yield model_id
@@ -61,7 +61,7 @@ def test_weight_hub_files_offline_error(offline, fresh_cache):
 def test_weight_hub_files_offline_ok(prefetched, offline):
     # If the model is prefetched then we should be able to get the weight files from local cache
     filenames = weight_hub_files(prefetched)
-    assert filenames == ['model.safetensors']
+    assert filenames == ["model.safetensors"]


 def test_weight_hub_files():
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
View file @ 564199ba

@@ -71,7 +71,7 @@ def _load_multi_mqa_gptq(
         g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
         g_idx = g_idx.to(device=weights.device)

-        bits, groupsize = weights._get_gptq_params()
+        bits, groupsize, _ = weights._get_gptq_params()

         from text_generation_server.utils.layers import HAS_EXLLAMA
server/text_generation_server/utils/gptq/exllamav2.py
View file @ 564199ba

@@ -27,6 +27,32 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
     return output.view(output_shape)


+# Group map needed for irregular group sizes
+def make_group_map(q_groups, num_qrows):
+    gr = q_groups.tolist()
+    group_map = []
+    num_groups = len(gr) // 2
+
+    for i in range(num_groups):
+        bits = gr[i * 2]
+        if i < num_groups - 1:
+            qrows = gr[i * 2 + 3] - gr[i * 2 + 1]
+        else:
+            qrows = num_qrows - gr[i * 2 + 1]
+        rows = qrows * 32 // bits
+        for j in range(rows):
+            group_map += [i]
+            group_map += [rows - j]
+
+    return torch.tensor(group_map, dtype=torch.short, device=q_groups.device)
+
+
+# Create Q matrix
 def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
     """
     Create Q matrix
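A minimal sketch of how the new group map is laid out, assuming a toy q_groups tensor with two 4-bit groups (the input values below are invented for illustration, not taken from the commit): for every weight row the map stores the group index and the number of rows remaining in that group, which is exactly the quantity the updated reconstruct_kernel adds to nextgroup.

import torch

from text_generation_server.utils.gptq.exllamav2 import make_group_map

# Toy input: two 4-bit groups whose packed rows start at qrow 0 and qrow 4,
# with 8 packed rows overall, i.e. 32 weight rows per group.
q_groups = torch.tensor([4, 0, 4, 4], dtype=torch.short)
group_map = make_group_map(q_groups, num_qrows=8)

# Two shorts per weight row k:
#   group_map[2 * k]     -> index of the group row k belongs to
#   group_map[2 * k + 1] -> rows left in that group starting at row k
print(group_map[:4])     # tensor([ 0, 32,  0, 31], dtype=torch.int16)
print(group_map[64:68])  # tensor([ 1, 32,  1, 31], dtype=torch.int16)

Because irregular group sizes are allowed, the kernel can no longer advance by a fixed groupsize; it reads the per-row remainder from this map instead.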
@@ -37,6 +63,10 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
         w["q_scale_max"] /= 256
         w["q_perm"] = w["q_perm"].short()
         w["q_invperm"] = w["q_invperm"].short()
+
+        if "q_group_map" not in w:
+            w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0])
+
         return make_q_matrix(
             w["q_weight"],
             w["q_perm"],
@@ -44,6 +74,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
             w["q_scale"],
             w["q_scale_max"],
             w["q_groups"],
+            w["q_group_map"],
             none_tensor,
             none_tensor,
             none_tensor,
@@ -70,6 +101,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
             none_tensor,
             none_tensor,
             none_tensor,
+            none_tensor,
             w["qzeros"],
             w["scales"],
             w["g_idx"].cpu(),
@@ -84,6 +116,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
             none_tensor,
             none_tensor,
             none_tensor,
+            none_tensor,
             w["qzeros"],
             w["scales"],
             none_tensor,
server/text_generation_server/utils/hub.py
View file @ 564199ba

@@ -18,7 +18,9 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
 HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]


-def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) -> List[str]:
+def _cached_weight_files(
+    model_id: str, revision: Optional[str], extension: str
+) -> List[str]:
     """Guess weight files from the cached revision snapshot directory"""
     d = _get_cached_revision_directory(model_id, revision)
     if not d:
@@ -27,7 +29,9 @@ def _cached_weight_files(model_id: str, revision: Optional[str], extension: str)
     return filenames


-def _weight_hub_files_from_model_info(info: hf_api.ModelInfo, extension: str) -> List[str]:
+def _weight_hub_files_from_model_info(
+    info: hf_api.ModelInfo, extension: str
+) -> List[str]:
     return [
         s.rfilename
         for s in info.siblings
@@ -44,21 +48,27 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
     # see _weight_hub_files_from_model_info, that's also what is
     # done there with the len(s.rfilename.split("/")) == 1 condition
     root, _, files = next(os.walk(str(d)))
-    filenames = [f for f in files
-                 if f.endswith(extension)
-                 and "arguments" not in f
-                 and "args" not in f
-                 and "adapter" not in f
-                 and "training" not in f]
+    filenames = [
+        f
+        for f in files
+        if f.endswith(extension)
+        and "arguments" not in f
+        and "args" not in f
+        and "adapter" not in f
+        and "training" not in f
+    ]
     return filenames


-def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Optional[Path]:
+def _get_cached_revision_directory(
+    model_id: str, revision: Optional[str]
+) -> Optional[Path]:
     if revision is None:
         revision = "main"

     repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path(
-        file_download.repo_folder_name(repo_id=model_id, repo_type="model"))
+        file_download.repo_folder_name(repo_id=model_id, repo_type="model")
+    )

     if not repo_cache.is_dir():
         # No cache for this model
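For illustration, a small self-contained sketch of what the filter in _weight_files_from_dir keeps; the filenames below are made up and not part of the commit, the predicate itself is copied from the function above.

files = [
    "model.safetensors",
    "model-00001-of-00002.safetensors",
    "adapter_model.safetensors",
    "training_args.bin",
    "arguments.json",
]
extension = ".safetensors"

filenames = [
    f
    for f in files
    if f.endswith(extension)
    and "arguments" not in f
    and "args" not in f
    and "adapter" not in f
    and "training" not in f
]
print(filenames)  # ['model.safetensors', 'model-00001-of-00002.safetensors']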
@@ -86,7 +96,7 @@ def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Op
 def weight_hub_files(
     model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
 ) -> List[str]:
     """Get the weights filenames on the hub"""
     api = HfApi()
server/text_generation_server/utils/layers.py
View file @ 564199ba

@@ -19,6 +19,7 @@ from accelerate import init_empty_weights

 from text_generation_server.utils.gptq.quant_linear import QuantLinear
 from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.log import log_once

 HAS_AWQ = True
 try:
@@ -35,10 +36,11 @@ HAS_EXLLAMA = False
 CAN_EXLLAMA = major >= 8
 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
 if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
-    logger.warning(
+    V2 = False
+    log_once(
+        logger.warning,
         "Disabling exllama v2 and using v1 instead because there are issues when sharding"
     )
-    V2 = False

 if os.getenv("DISABLE_EXLLAMA") == "True":
     HAS_EXLLAMA = False
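A short sketch of how these switches interact, under the assumption that the flags are evaluated once at import time of text_generation_server.utils.layers (the variable names are the ones shown in the diff; nothing else is implied): exllama v2 is the default, it is turned off when the server is sharded across more than one rank, and it can be disabled outright.

import os

# Set before importing text_generation_server.utils.layers, since the
# module reads these variables when it is first imported.
os.environ["EXLLAMA_VERSION"] = "1"     # force the exllama v1 kernels
os.environ["DISABLE_EXLLAMA"] = "True"  # or disable exllama entirely
# With WORLD_SIZE > 1 the module itself falls back to v1 and emits the
# "Disabling exllama v2 ..." warning a single time, via log_once.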
server/text_generation_server/utils/log.py 0 → 100644
View file @ 564199ba

from functools import lru_cache


@lru_cache(10)
def log_once(log, msg: str):
    log(msg)
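A quick usage sketch of the new helper (the logger import below is illustrative, not part of this file): because log_once is wrapped in lru_cache, repeated calls with the same (log, msg) pair only emit the first one, which is why the per-weight exllama warnings no longer flood the logs.

from loguru import logger

from text_generation_server.utils.log import log_once

# Emitted once ...
log_once(logger.warning, "Disabling exllama because desc_act=True")
# ... and served from the cache (i.e. silently skipped) on later identical calls.
log_once(logger.warning, "Disabling exllama because desc_act=True")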
server/text_generation_server/utils/weights.py
View file @ 564199ba

@@ -6,6 +6,7 @@ import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download
 import json
+from text_generation_server.utils.log import log_once


 class Weights:
@@ -161,7 +162,7 @@ class Weights:
             else:
                 g_idx = None
-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, _ = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
@@ -211,10 +212,10 @@ class Weights:
             else:
                 g_idx = None

-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, desc_act = self._get_gptq_params()

             from text_generation_server.utils.layers import HAS_EXLLAMA

-            use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq"
+            use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act

             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -240,11 +241,15 @@ class Weights:
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
             use_exllama = True
-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, desc_act = self._get_gptq_params()
             if bits != 4:
                 use_exllama = False

+            if desc_act:
+                log_once(logger.warning, "Disabling exllama because desc_act=True")
+                use_exllama = False
+
             if self.process_group.size() > 1:
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
                 if g_idx is not None:
@@ -274,12 +279,18 @@ class Weights:
             if use_exllama:
                 if not HAS_EXLLAMA:
                     if CAN_EXLLAMA:
-                        logger.warning(
+                        log_once(
+                            logger.warning,
                             "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
                         )
                     use_exllama = False
                 else:
-                    logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
+                    log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
+
+                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)

             if use_exllama and groupsize != -1:
                 qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
@@ -288,14 +299,12 @@ class Weights:
                 qzeros = self.get_tensor(f"{prefix}.qzeros")
                 scales = self.get_tensor(f"{prefix}.scales")
-                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)

             if use_exllama:
                 g_idx = g_idx - g_idx[0]

             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, _ = self._get_gptq_params()

             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
@@ -314,18 +323,20 @@ class Weights:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight

-    def _get_gptq_params(self) -> Tuple[int, int]:
+    def _get_gptq_params(self) -> Tuple[int, int, int]:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
+            desc_act = False
         except (SafetensorError, RuntimeError) as e:
             try:
                 bits = self.gptq_bits
                 groupsize = self.gptq_groupsize
+                desc_act = getattr(self, "gptq_desc_act", False)
             except Exception:
                 raise e

-        return bits, groupsize
+        return bits, groupsize, desc_act

     def _set_gptq_params(self, model_id, revision):
         filename = "config.json"
@@ -340,6 +351,7 @@ class Weights:
                 data = json.load(f)
                 self.gptq_bits = data["quantization_config"]["bits"]
                 self.gptq_groupsize = data["quantization_config"]["group_size"]
+                self.gptq_desc_act = data["quantization_config"]["desc_act"]
         except Exception:
             filename = "quantize_config.json"
             try:
@@ -353,6 +365,7 @@ class Weights:
                 data = json.load(f)
                 self.gptq_bits = data["bits"]
                 self.gptq_groupsize = data["group_size"]
+                self.gptq_desc_act = data["desc_act"]
         except Exception:
             filename = "quant_config.json"
             try:
@@ -366,5 +379,6 @@ class Weights:
                 data = json.load(f)
                 self.gptq_bits = data["w_bit"]
                 self.gptq_groupsize = data["q_group_size"]
+                self.gptq_desc_act = data["desc_act"]
         except Exception:
             pass