Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7c25fe45
Unverified
Commit
7c25fe45
authored
Nov 23, 2024
by
kliuae
Committed by
GitHub
Nov 22, 2024
Browse files
[AMD] Add support for GGUF quantization on ROCm (#10254)
parent
02a43f82
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
234 additions
and
211 deletions
+234
-211
.buildkite/run-amd-test.sh
.buildkite/run-amd-test.sh
+0
-1
CMakeLists.txt
CMakeLists.txt
+1
-1
csrc/ops.h
csrc/ops.h
+2
-0
csrc/quantization/gguf/ggml-common.h
csrc/quantization/gguf/ggml-common.h
+16
-1
csrc/quantization/gguf/gguf_kernel.cu
csrc/quantization/gguf/gguf_kernel.cu
+4
-2
csrc/quantization/gguf/mmq.cuh
csrc/quantization/gguf/mmq.cuh
+35
-35
csrc/quantization/gguf/mmvq.cuh
csrc/quantization/gguf/mmvq.cuh
+2
-2
csrc/quantization/gguf/vecdotq.cuh
csrc/quantization/gguf/vecdotq.cuh
+143
-143
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+2
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+28
-25
vllm/config.py
vllm/config.py
+1
-1
No files found.
.buildkite/run-amd-test.sh
View file @
7c25fe45
...
...
@@ -85,7 +85,6 @@ if [[ $commands == *" kernels "* ]]; then
--ignore=kernels/test_encoder_decoder_attn.py
\
--ignore=kernels/test_flash_attn.py
\
--ignore=kernels/test_flashinfer.py
\
--ignore=kernels/test_gguf.py
\
--ignore=kernels/test_int8_quant.py
\
--ignore=kernels/test_machete_gemm.py
\
--ignore=kernels/test_mamba_ssm.py
\
...
...
CMakeLists.txt
View file @
7c25fe45
...
...
@@ -196,6 +196,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/torch_bindings.cpp"
)
...
...
@@ -237,7 +238,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
)
...
...
csrc/ops.h
View file @
7c25fe45
...
...
@@ -128,6 +128,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
int64_t
thx
,
int64_t
thy
);
torch
::
Tensor
permute_cols
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
perm
);
#endif
torch
::
Tensor
ggml_dequantize
(
torch
::
Tensor
W
,
int64_t
type
,
int64_t
m
,
int64_t
n
);
...
...
@@ -138,6 +139,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
torch
::
Tensor
ggml_mul_mat_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
int64_t
type
,
int64_t
row
);
#ifndef USE_ROCM
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
void
cutlass_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
...
...
csrc/quantization/gguf/ggml-common.h
View file @
7c25fe45
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
#define QK_K 256
#define K_QUANTS_PER_ITERATION 2
#define WARP_SIZE 32
#define WARP_SIZE
_GGUF
32
#define K_SCALE_SIZE 12
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
#define CUDA_QUANTIZE_BLOCK_SIZE 256
...
...
@@ -1112,4 +1112,19 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#endif
return
c
;
}
static
__device__
__forceinline__
uint32_t
__vcmpeq4
(
const
uint32_t
a
,
const
uint32_t
b
)
{
uint32_t
neq
=
a
^
b
;
return
!
(
neq
&
0xff000000
)
*
0xff000000
|
!
(
neq
&
0x00ff0000
)
*
0x00ff0000
|
!
(
neq
&
0x0000ff00
)
*
0x0000ff00
|
!
(
neq
&
0x000000ff
)
*
0x000000ff
;
}
static
__device__
__forceinline__
uint32_t
__vsub4
(
const
uint32_t
a
,
const
uint32_t
b
)
{
return
(
static_cast
<
uint8_t
>
(((
a
&
0xff000000
)
>>
24
)
-
((
b
&
0xff000000
)
>>
24
))
<<
24
)
+
(
static_cast
<
uint8_t
>
(((
a
&
0x00ff0000
)
>>
16
)
-
((
b
&
0x00ff0000
)
>>
16
))
<<
16
)
+
(
static_cast
<
uint8_t
>
(((
a
&
0x0000ff00
)
>>
8
)
-
((
b
&
0x0000ff00
)
>>
8
))
<<
8
)
+
(
static_cast
<
uint8_t
>
(((
a
&
0x000000ff
)
>>
0
)
-
((
b
&
0x000000ff
)
>>
0
))
<<
0
);
}
#endif // defined(USE_ROCM)
csrc/quantization/gguf/gguf_kernel.cu
View file @
7c25fe45
...
...
@@ -4,6 +4,8 @@
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "ggml-common.h"
#include "vecdotq.cuh"
#include "dequantize.cuh"
...
...
@@ -32,8 +34,8 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,
#pragma unroll
for
(
int
mask
=
16
;
mask
>
0
;
mask
>>=
1
)
{
amax
=
fmaxf
(
amax
,
__shfl_xor_sync
(
0xffffffff
,
amax
,
mask
,
32
));
sum
+=
__shfl_xor_sync
(
0xffffffff
,
sum
,
mask
,
32
);
amax
=
fmaxf
(
amax
,
VLLM_SHFL_XOR_SYNC_WIDTH
(
amax
,
mask
,
32
));
sum
+=
VLLM_SHFL_XOR_SYNC_WIDTH
(
sum
,
mask
,
32
);
}
const
float
d
=
amax
/
127
;
...
...
csrc/quantization/gguf/mmq.cuh
View file @
7c25fe45
...
...
@@ -10,7 +10,7 @@ static __device__ __forceinline__ void mul_mat_q(
const
int
blocks_per_row_x
=
ncols_x
/
qk
;
const
int
blocks_per_col_y
=
nrows_y
/
QK8_1
;
const
int
blocks_per_warp
=
WARP_SIZE
/
qi
;
const
int
blocks_per_warp
=
WARP_SIZE
_GGUF
/
qi
;
const
int
&
ncols_dst
=
ncols_y
;
...
...
@@ -27,10 +27,10 @@ static __device__ __forceinline__ void mul_mat_q(
allocate_tiles
(
&
tile_x_ql
,
&
tile_x_dm
,
&
tile_x_qh
,
&
tile_x_sc
);
__shared__
int
tile_y_qs
[
mmq_x
*
WARP_SIZE
];
__shared__
half2
tile_y_ds
[
mmq_x
*
WARP_SIZE
/
QI8_1
];
__shared__
int
tile_y_qs
[
mmq_x
*
WARP_SIZE
_GGUF
];
__shared__
half2
tile_y_ds
[
mmq_x
*
WARP_SIZE
_GGUF
/
QI8_1
];
float
sum
[
mmq_y
/
WARP_SIZE
][
mmq_x
/
nwarps
]
=
{{
0.0
f
}};
float
sum
[
mmq_y
/
WARP_SIZE
_GGUF
][
mmq_x
/
nwarps
]
=
{{
0.0
f
}};
for
(
int
ib0
=
0
;
ib0
<
blocks_per_row_x
;
ib0
+=
blocks_per_warp
)
{
...
...
@@ -39,26 +39,26 @@ static __device__ __forceinline__ void mul_mat_q(
#pragma unroll
for
(
int
ir
=
0
;
ir
<
qr
;
++
ir
)
{
const
int
kqs
=
ir
*
WARP_SIZE
+
threadIdx
.
x
;
const
int
kqs
=
ir
*
WARP_SIZE
_GGUF
+
threadIdx
.
x
;
const
int
kbxd
=
kqs
/
QI8_1
;
#pragma unroll
for
(
int
i
=
0
;
i
<
mmq_x
;
i
+=
nwarps
)
{
const
int
col_y_eff
=
min
(
col_y_0
+
threadIdx
.
y
+
i
,
ncols_y
-
1
);
// to prevent out-of-bounds memory accesses
const
block_q8_1
*
by0
=
&
y
[
col_y_eff
*
blocks_per_col_y
+
ib0
*
(
qk
/
QK8_1
)
+
kbxd
];
const
int
index_y
=
(
threadIdx
.
y
+
i
)
*
WARP_SIZE
+
kqs
%
WARP_SIZE
;
const
int
index_y
=
(
threadIdx
.
y
+
i
)
*
WARP_SIZE
_GGUF
+
kqs
%
WARP_SIZE
_GGUF
;
tile_y_qs
[
index_y
]
=
get_int_from_int8_aligned
(
by0
->
qs
,
threadIdx
.
x
%
QI8_1
);
}
#pragma unroll
for
(
int
ids0
=
0
;
ids0
<
mmq_x
;
ids0
+=
nwarps
*
QI8_1
)
{
const
int
ids
=
(
ids0
+
threadIdx
.
y
*
QI8_1
+
threadIdx
.
x
/
(
WARP_SIZE
/
QI8_1
))
%
mmq_x
;
const
int
kby
=
threadIdx
.
x
%
(
WARP_SIZE
/
QI8_1
);
const
int
ids
=
(
ids0
+
threadIdx
.
y
*
QI8_1
+
threadIdx
.
x
/
(
WARP_SIZE
_GGUF
/
QI8_1
))
%
mmq_x
;
const
int
kby
=
threadIdx
.
x
%
(
WARP_SIZE
_GGUF
/
QI8_1
);
const
int
col_y_eff
=
min
(
col_y_0
+
ids
,
ncols_y
-
1
);
// if the sum is not needed it's faster to transform the scale to f32 ahead of time
const
half2
*
dsi_src
=
&
y
[
col_y_eff
*
blocks_per_col_y
+
ib0
*
(
qk
/
QK8_1
)
+
ir
*
(
WARP_SIZE
/
QI8_1
)
+
kby
].
ds
;
half2
*
dsi_dst
=
&
tile_y_ds
[
ids
*
(
WARP_SIZE
/
QI8_1
)
+
kby
];
const
half2
*
dsi_src
=
&
y
[
col_y_eff
*
blocks_per_col_y
+
ib0
*
(
qk
/
QK8_1
)
+
ir
*
(
WARP_SIZE
_GGUF
/
QI8_1
)
+
kby
].
ds
;
half2
*
dsi_dst
=
&
tile_y_ds
[
ids
*
(
WARP_SIZE
_GGUF
/
QI8_1
)
+
kby
];
if
(
need_sum
)
{
*
dsi_dst
=
*
dsi_src
;
}
else
{
...
...
@@ -70,12 +70,12 @@ static __device__ __forceinline__ void mul_mat_q(
__syncthreads
();
// #pragma unroll // unrolling this loop causes too much register pressure
for
(
int
k
=
ir
*
WARP_SIZE
/
qr
;
k
<
(
ir
+
1
)
*
WARP_SIZE
/
qr
;
k
+=
vdr
)
{
for
(
int
k
=
ir
*
WARP_SIZE
_GGUF
/
qr
;
k
<
(
ir
+
1
)
*
WARP_SIZE
_GGUF
/
qr
;
k
+=
vdr
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
mmq_x
;
j
+=
nwarps
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
mmq_y
;
i
+=
WARP_SIZE
)
{
sum
[
i
/
WARP_SIZE
][
j
/
nwarps
]
+=
vec_dot
(
for
(
int
i
=
0
;
i
<
mmq_y
;
i
+=
WARP_SIZE
_GGUF
)
{
sum
[
i
/
WARP_SIZE
_GGUF
][
j
/
nwarps
]
+=
vec_dot
(
tile_x_ql
,
tile_x_dm
,
tile_x_qh
,
tile_x_sc
,
tile_y_qs
,
tile_y_ds
,
threadIdx
.
x
+
i
,
threadIdx
.
y
+
j
,
k
);
}
...
...
@@ -93,12 +93,12 @@ static __device__ __forceinline__ void mul_mat_q(
}
#pragma unroll
for
(
int
i
=
0
;
i
<
mmq_y
;
i
+=
WARP_SIZE
)
{
for
(
int
i
=
0
;
i
<
mmq_y
;
i
+=
WARP_SIZE
_GGUF
)
{
const
int
row_dst
=
row_dst_0
+
threadIdx
.
x
+
i
;
if
(
row_dst
>=
nrows_dst
)
{
continue
;
}
dst
[
col_dst
*
nrows_dst
+
row_dst
]
=
__float2half
(
sum
[
i
/
WARP_SIZE
][
j
/
nwarps
]);
dst
[
col_dst
*
nrows_dst
+
row_dst
]
=
__float2half
(
sum
[
i
/
WARP_SIZE
_GGUF
][
j
/
nwarps
]);
}
}
}
...
...
@@ -115,7 +115,7 @@ static __device__ __forceinline__ void mul_mat_q(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q4_0
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q4_0
,
2
)
#endif
mul_mat_q4_0
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -140,7 +140,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
@@ -165,7 +165,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q4_1
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q4_1
,
2
)
#endif
mul_mat_q4_1
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -190,7 +190,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
@@ -215,7 +215,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q5_0
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q5_0
,
2
)
#endif
mul_mat_q5_0
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -240,7 +240,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
@@ -265,7 +265,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q5_1
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q5_1
,
2
)
#endif
mul_mat_q5_1
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -289,7 +289,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
@@ -314,7 +314,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q8_0
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q8_0
,
2
)
#endif
mul_mat_q8_0
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -338,7 +338,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
@@ -363,7 +363,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q2_K
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q2_K
,
2
)
#endif
mul_mat_q2_K
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -387,7 +387,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
@@ -412,7 +412,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q3_K
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q3_K
,
2
)
#endif
mul_mat_q3_K
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -438,7 +438,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
@@ -463,7 +463,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q4_K
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q4_K
,
2
)
#endif
mul_mat_q4_K
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -487,7 +487,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
@@ -512,7 +512,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q5_K
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q5_K
,
2
)
#endif
mul_mat_q5_K
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -537,7 +537,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
@@ -562,7 +562,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
template
<
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE
*
NWARPS_Q6_K
,
2
)
__launch_bounds__
(
WARP_SIZE
_GGUF
*
NWARPS_Q6_K
,
2
)
#endif
mul_mat_q6_K
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
half
*
__restrict__
dst
,
...
...
@@ -586,7 +586,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
block_dims
(
WARP_SIZE
_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
...
...
csrc/quantization/gguf/mmvq.cuh
View file @
7c25fe45
...
...
@@ -28,8 +28,8 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
// sum up partial sums and write back result
#pragma unroll
for
(
int
mask
=
16
;
mask
>
0
;
mask
>>=
1
)
{
tmp
+=
__shfl_xor_sync
(
0xffffffff
,
tmp
,
mask
,
32
);
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>
0
;
mask
>>=
1
)
{
tmp
+=
VLLM_SHFL_XOR_SYNC
(
tmp
,
mask
);
}
if
(
threadIdx
.
x
==
0
)
{
...
...
csrc/quantization/gguf/vecdotq.cuh
View file @
7c25fe45
...
...
@@ -43,7 +43,7 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t *
template
<
int
vdr
>
static
__device__
__forceinline__
float
vec_dot_q4_0_q8_1_impl
(
const
int
*
v
,
const
int
*
u
,
const
float
&
d4
,
const
half2
&
ds8
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
int
sumi
=
0
;
#pragma unroll
...
...
@@ -68,7 +68,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp
template
<
int
vdr
>
static
__device__
__forceinline__
float
vec_dot_q4_1_q8_1_impl
(
const
int
*
v
,
const
int
*
u
,
const
half2
&
dm4
,
const
half2
&
ds8
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
int
sumi
=
0
;
#pragma unroll
...
...
@@ -95,7 +95,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
template
<
int
vdr
>
static
__device__
__forceinline__
float
vec_dot_q5_0_q8_1_impl
(
const
int
*
vl
,
const
int
*
vh
,
const
int
*
u
,
const
float
&
d5
,
const
half2
&
ds8
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
int
sumi
=
0
;
#pragma unroll
...
...
@@ -128,7 +128,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp
template
<
int
vdr
>
static
__device__
__forceinline__
float
vec_dot_q5_1_q8_1_impl
(
const
int
*
vl
,
const
int
*
vh
,
const
int
*
u
,
const
half2
&
dm5
,
const
half2
&
ds8
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
int
sumi
=
0
;
#pragma unroll
...
...
@@ -162,7 +162,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
template
<
int
vdr
>
static
__device__
__forceinline__
float
vec_dot_q8_0_q8_1_impl
(
const
int
*
v
,
const
int
*
u
,
const
float
&
d8_0
,
const
float
&
d8_1
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
int
sumi
=
0
;
#pragma unroll
...
...
@@ -176,7 +176,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp
template
<
int
vdr
>
static
__device__
__forceinline__
float
vec_dot_q8_1_q8_1_impl
(
const
int
*
v
,
const
int
*
u
,
const
half2
&
dm8
,
const
half2
&
ds8
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
int
sumi
=
0
;
...
...
@@ -202,7 +202,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
static
__device__
__forceinline__
float
vec_dot_q2_K_q8_1_impl_mmvq
(
const
int
&
v
,
const
int
*
__restrict__
u
,
const
uint8_t
*
__restrict__
scales
,
const
half2
&
dm2
,
const
float
*
__restrict__
d8
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
float
sumf_d
=
0.0
f
;
float
sumf_m
=
0.0
f
;
...
...
@@ -230,7 +230,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
static
__device__
__forceinline__
float
vec_dot_q2_K_q8_1_impl_mmq
(
const
int
*
__restrict__
v
,
const
int
*
__restrict__
u
,
const
uint8_t
*
__restrict__
scales
,
const
half2
&
dm2
,
const
float
&
d8
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
int
sumi_d
=
0
;
int
sumi_m
=
0
;
...
...
@@ -267,7 +267,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
static
__device__
__forceinline__
float
vec_dot_q3_K_q8_1_impl_mmvq
(
const
int
&
vl
,
const
int
&
vh
,
const
int
*
__restrict__
u
,
const
uint8_t
*
__restrict__
scales
,
const
int
&
scale_offset
,
const
float
&
d3
,
const
float
*
__restrict__
d8
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
float
sumf
=
0.0
f
;
...
...
@@ -301,7 +301,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
static
__device__
__forceinline__
float
vec_dot_q3_K_q8_1_impl_mmq
(
const
int
*
__restrict__
v
,
const
int
*
__restrict__
u
,
const
int8_t
*
__restrict__
scales
,
const
float
&
d3
,
const
float
&
d8
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
int
sumi
=
0
;
#pragma unroll
...
...
@@ -326,7 +326,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
static
__device__
__forceinline__
float
vec_dot_q4_K_q8_1_impl_vmmq
(
const
int
*
__restrict__
v
,
const
int
*
__restrict__
u
,
const
uint8_t
*
__restrict__
sc
,
const
uint8_t
*
__restrict__
m
,
const
half2
&
dm4
,
const
float
*
__restrict__
d8
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
float
sumf_d
=
0.0
f
;
float
sumf_m
=
0.0
f
;
...
...
@@ -351,7 +351,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
static
__device__
__forceinline__
float
vec_dot_q4_K_q8_1_impl_mmq
(
const
int
*
__restrict__
v
,
const
int
*
__restrict__
u
,
const
uint8_t
*
__restrict__
sc
,
const
uint8_t
*
__restrict__
m
,
const
half2
&
dm4
,
const
half2
*
__restrict__
ds8
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
float
sumf_d
=
0.0
f
;
float
sumf_m
=
0.0
f
;
...
...
@@ -382,7 +382,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
static
__device__
__forceinline__
float
vec_dot_q5_K_q8_1_impl_vmmq
(
const
int
*
__restrict__
vl
,
const
int
*
__restrict__
vh
,
const
int
*
__restrict__
u
,
const
uint8_t
*
__restrict__
sc
,
const
uint8_t
*
__restrict__
m
,
const
half2
&
dm5
,
const
float
*
__restrict__
d8
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
float
sumf_d
=
0.0
f
;
float
sumf_m
=
0.0
f
;
...
...
@@ -413,7 +413,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
static
__device__
__forceinline__
float
vec_dot_q5_K_q8_1_impl_mmq
(
const
int
*
__restrict__
v
,
const
int
*
__restrict__
u
,
const
uint8_t
*
__restrict__
sc
,
const
uint8_t
*
__restrict__
m
,
const
half2
&
dm4
,
const
half2
*
__restrict__
ds8
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
float
sumf_d
=
0.0
f
;
float
sumf_m
=
0.0
f
;
...
...
@@ -445,7 +445,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
static
__device__
__forceinline__
float
vec_dot_q6_K_q8_1_impl_mmvq
(
const
int
&
vl
,
const
int
&
vh
,
const
int
*
__restrict__
u
,
const
int8_t
*
__restrict__
scales
,
const
float
&
d
,
const
float
*
__restrict__
d8
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
float
sumf
=
0.0
f
;
#pragma unroll
...
...
@@ -465,7 +465,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
static
__device__
__forceinline__
float
vec_dot_q6_K_q8_1_impl_mmq
(
const
int
*
__restrict__
v
,
const
int
*
__restrict__
u
,
const
int8_t
*
__restrict__
sc
,
const
float
&
d6
,
const
float
*
__restrict__
d8
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
float
sumf_d
=
0.0
f
;
#pragma unroll
...
...
@@ -507,8 +507,8 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
}
template
<
int
mmq_y
>
static
__device__
__forceinline__
void
allocate_tiles_q4_0
(
int
**
x_ql
,
half2
**
x_dm
,
int
**
x_qh
,
int
**
x_sc
)
{
__shared__
int
tile_x_qs
[
mmq_y
*
(
WARP_SIZE
)
+
mmq_y
];
__shared__
float
tile_x_d
[
mmq_y
*
(
WARP_SIZE
/
QI4_0
)
+
mmq_y
/
QI4_0
];
__shared__
int
tile_x_qs
[
mmq_y
*
(
WARP_SIZE
_GGUF
)
+
mmq_y
];
__shared__
float
tile_x_d
[
mmq_y
*
(
WARP_SIZE
_GGUF
/
QI4_0
)
+
mmq_y
/
QI4_0
];
*
x_ql
=
tile_x_qs
;
*
x_dm
=
(
half2
*
)
tile_x_d
;
}
...
...
@@ -529,11 +529,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i
=
min
(
i
,
i_max
);
}
const
block_q4_0
*
bxi
=
bx0
+
i
*
blocks_per_row
+
kbx
;
x_ql
[
i
*
(
WARP_SIZE
+
1
)
+
k
]
=
get_int_from_uint8
(
bxi
->
qs
,
kqsx
);
// x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
x_ql
[
i
*
(
WARP_SIZE
_GGUF
+
1
)
+
k
]
=
get_int_from_uint8
(
bxi
->
qs
,
kqsx
);
// x_dmf[i * (WARP_SIZE
_GGUF
/QI4_0) + i / QI4_0 + kbx] = bxi->d;
}
const
int
blocks_per_tile_x_row
=
WARP_SIZE
/
QI4_0
;
const
int
blocks_per_tile_x_row
=
WARP_SIZE
_GGUF
/
QI4_0
;
const
int
kbxd
=
k
%
blocks_per_tile_x_row
;
#pragma unroll
...
...
@@ -543,7 +543,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i
=
min
(
i
,
i_max
);
}
const
block_q4_0
*
bxi
=
bx0
+
i
*
blocks_per_row
+
kbxd
;
x_dmf
[
i
*
(
WARP_SIZE
/
QI4_0
)
+
i
/
QI4_0
+
kbxd
]
=
__half2float
(
bxi
->
d
);
x_dmf
[
i
*
(
WARP_SIZE
_GGUF
/
QI4_0
)
+
i
/
QI4_0
+
kbxd
]
=
__half2float
(
bxi
->
d
);
}
}
...
...
@@ -559,13 +559,13 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
#pragma unroll
for
(
int
l
=
0
;
l
<
VDR_Q4_0_Q8_1_MMQ
;
++
l
)
{
u
[
2
*
l
+
0
]
=
y_qs
[
j
*
WARP_SIZE
+
(
kyqs
+
l
)
%
WARP_SIZE
];
u
[
2
*
l
+
1
]
=
y_qs
[
j
*
WARP_SIZE
+
(
kyqs
+
l
+
QI4_0
)
%
WARP_SIZE
];
u
[
2
*
l
+
0
]
=
y_qs
[
j
*
WARP_SIZE
_GGUF
+
(
kyqs
+
l
)
%
WARP_SIZE
_GGUF
];
u
[
2
*
l
+
1
]
=
y_qs
[
j
*
WARP_SIZE
_GGUF
+
(
kyqs
+
l
+
QI4_0
)
%
WARP_SIZE
_GGUF
];
}
return
vec_dot_q4_0_q8_1_impl
<
VDR_Q4_0_Q8_1_MMQ
>
(
&
x_ql
[
i
*
(
WARP_SIZE
+
1
)
+
k
],
u
,
x_dmf
[
i
*
(
WARP_SIZE
/
QI4_0
)
+
i
/
QI4_0
+
k
/
QI4_0
],
y_ds
[
j
*
(
WARP_SIZE
/
QI8_1
)
+
(
2
*
k
/
QI8_1
)
%
(
WARP_SIZE
/
QI8_1
)]);
(
&
x_ql
[
i
*
(
WARP_SIZE
_GGUF
+
1
)
+
k
],
u
,
x_dmf
[
i
*
(
WARP_SIZE
_GGUF
/
QI4_0
)
+
i
/
QI4_0
+
k
/
QI4_0
],
y_ds
[
j
*
(
WARP_SIZE
_GGUF
/
QI8_1
)
+
(
2
*
k
/
QI8_1
)
%
(
WARP_SIZE
_GGUF
/
QI8_1
)]);
}
static
__device__
__forceinline__
float
vec_dot_q4_1_q8_1
(
...
...
@@ -587,8 +587,8 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
}
template
<
int
mmq_y
>
static
__device__
__forceinline__
void
allocate_tiles_q4_1
(
int
**
x_ql
,
half2
**
x_dm
,
int
**
x_qh
,
int
**
x_sc
)
{
__shared__
int
tile_x_qs
[
mmq_y
*
(
WARP_SIZE
)
+
+
mmq_y
];
__shared__
half2
tile_x_dm
[
mmq_y
*
(
WARP_SIZE
/
QI4_1
)
+
mmq_y
/
QI4_1
];
__shared__
int
tile_x_qs
[
mmq_y
*
(
WARP_SIZE
_GGUF
)
+
+
mmq_y
];
__shared__
half2
tile_x_dm
[
mmq_y
*
(
WARP_SIZE
_GGUF
/
QI4_1
)
+
mmq_y
/
QI4_1
];
*
x_ql
=
tile_x_qs
;
*
x_dm
=
tile_x_dm
;
}
...
...
@@ -608,10 +608,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i
=
min
(
i
,
i_max
);
}
const
block_q4_1
*
bxi
=
bx0
+
i
*
blocks_per_row
+
kbx
;
x_ql
[
i
*
(
WARP_SIZE
+
1
)
+
k
]
=
get_int_from_uint8_aligned
(
bxi
->
qs
,
kqsx
);
x_ql
[
i
*
(
WARP_SIZE
_GGUF
+
1
)
+
k
]
=
get_int_from_uint8_aligned
(
bxi
->
qs
,
kqsx
);
}
const
int
blocks_per_tile_x_row
=
WARP_SIZE
/
QI4_1
;
const
int
blocks_per_tile_x_row
=
WARP_SIZE
_GGUF
/
QI4_1
;
const
int
kbxd
=
k
%
blocks_per_tile_x_row
;
#pragma unroll
...
...
@@ -621,7 +621,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i
=
min
(
i
,
i_max
);
}
const
block_q4_1
*
bxi
=
bx0
+
i
*
blocks_per_row
+
kbxd
;
x_dm
[
i
*
(
WARP_SIZE
/
QI4_1
)
+
i
/
QI4_1
+
kbxd
]
=
bxi
->
dm
;
x_dm
[
i
*
(
WARP_SIZE
_GGUF
/
QI4_1
)
+
i
/
QI4_1
+
kbxd
]
=
bxi
->
dm
;
}
}
...
...
@@ -634,13 +634,13 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
#pragma unroll
for
(
int
l
=
0
;
l
<
VDR_Q4_1_Q8_1_MMQ
;
++
l
)
{
u
[
2
*
l
+
0
]
=
y_qs
[
j
*
WARP_SIZE
+
(
kyqs
+
l
)
%
WARP_SIZE
];
u
[
2
*
l
+
1
]
=
y_qs
[
j
*
WARP_SIZE
+
(
kyqs
+
l
+
QI4_1
)
%
WARP_SIZE
];
u
[
2
*
l
+
0
]
=
y_qs
[
j
*
WARP_SIZE
_GGUF
+
(
kyqs
+
l
)
%
WARP_SIZE
_GGUF
];
u
[
2
*
l
+
1
]
=
y_qs
[
j
*
WARP_SIZE
_GGUF
+
(
kyqs
+
l
+
QI4_1
)
%
WARP_SIZE
_GGUF
];
}
return
vec_dot_q4_1_q8_1_impl
<
VDR_Q4_1_Q8_1_MMQ
>
(
&
x_ql
[
i
*
(
WARP_SIZE
+
1
)
+
k
],
u
,
x_dm
[
i
*
(
WARP_SIZE
/
QI4_1
)
+
i
/
QI4_1
+
k
/
QI4_1
],
y_ds
[
j
*
(
WARP_SIZE
/
QI8_1
)
+
(
2
*
k
/
QI8_1
)
%
(
WARP_SIZE
/
QI8_1
)]);
(
&
x_ql
[
i
*
(
WARP_SIZE
_GGUF
+
1
)
+
k
],
u
,
x_dm
[
i
*
(
WARP_SIZE
_GGUF
/
QI4_1
)
+
i
/
QI4_1
+
k
/
QI4_1
],
y_ds
[
j
*
(
WARP_SIZE
_GGUF
/
QI8_1
)
+
(
2
*
k
/
QI8_1
)
%
(
WARP_SIZE
_GGUF
/
QI8_1
)]);
}
static
__device__
__forceinline__
float
vec_dot_q5_0_q8_1
(
...
...
@@ -664,8 +664,8 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
}
template
<
int
mmq_y
>
static
__device__
__forceinline__
void
allocate_tiles_q5_0
(
int
**
x_ql
,
half2
**
x_dm
,
int
**
x_qh
,
int
**
x_sc
)
{
__shared__
int
tile_x_ql
[
mmq_y
*
(
2
*
WARP_SIZE
)
+
mmq_y
];
__shared__
float
tile_x_d
[
mmq_y
*
(
WARP_SIZE
/
QI5_0
)
+
mmq_y
/
QI5_0
];
__shared__
int
tile_x_ql
[
mmq_y
*
(
2
*
WARP_SIZE
_GGUF
)
+
mmq_y
];
__shared__
float
tile_x_d
[
mmq_y
*
(
WARP_SIZE
_GGUF
/
QI5_0
)
+
mmq_y
/
QI5_0
];
*
x_ql
=
tile_x_ql
;
*
x_dm
=
(
half2
*
)
tile_x_d
;
...
...
@@ -697,7 +697,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
qs0
|=
(
qh
<<
25
)
&
0x10000000
;
// 3 -> 28
qs0
=
__vsubss4
(
qs0
,
0x10101010
);
// subtract 16
x_ql
[
i
*
(
2
*
WARP_SIZE
+
1
)
+
2
*
k
+
0
]
=
qs0
;
x_ql
[
i
*
(
2
*
WARP_SIZE
_GGUF
+
1
)
+
2
*
k
+
0
]
=
qs0
;
int
qs1
=
(
ql
>>
4
)
&
0x0F0F0F0F
;
qs1
|=
(
qh
>>
12
)
&
0x00000010
;
// 16 -> 4
...
...
@@ -706,10 +706,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
qs1
|=
(
qh
<<
9
)
&
0x10000000
;
// 19 -> 28
qs1
=
__vsubss4
(
qs1
,
0x10101010
);
// subtract 16
x_ql
[
i
*
(
2
*
WARP_SIZE
+
1
)
+
2
*
k
+
1
]
=
qs1
;
x_ql
[
i
*
(
2
*
WARP_SIZE
_GGUF
+
1
)
+
2
*
k
+
1
]
=
qs1
;
}
const
int
blocks_per_tile_x_row
=
WARP_SIZE
/
QI5_0
;
const
int
blocks_per_tile_x_row
=
WARP_SIZE
_GGUF
/
QI5_0
;
const
int
kbxd
=
k
%
blocks_per_tile_x_row
;
float
*
x_dmf
=
(
float
*
)
x_dm
;
...
...
@@ -722,7 +722,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
const
block_q5_0
*
bxi
=
bx0
+
i
*
blocks_per_row
+
kbxd
;
x_dmf
[
i
*
(
WARP_SIZE
/
QI5_0
)
+
i
/
QI5_0
+
kbxd
]
=
__half2float
(
bxi
->
d
);
x_dmf
[
i
*
(
WARP_SIZE
_GGUF
/
QI5_0
)
+
i
/
QI5_0
+
kbxd
]
=
__half2float
(
bxi
->
d
);
}
}
...
...
@@ -730,7 +730,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
const
int
*
__restrict__
x_ql
,
const
half2
*
__restrict__
x_dm
,
const
int
*
__restrict__
x_qh
,
const
int
*
__restrict__
x_sc
,
const
int
*
__restrict__
y_qs
,
const
half2
*
__restrict__
y_ds
,
const
int
&
i
,
const
int
&
j
,
const
int
&
k
)
{
const
int
kyqs
=
k
%
(
QI8_1
/
2
)
+
QI8_1
*
(
k
/
(
QI8_1
/
2
));
const
int
index_bx
=
i
*
(
WARP_SIZE
/
QI5_0
)
+
i
/
QI5_0
+
k
/
QI5_0
;
const
int
index_bx
=
i
*
(
WARP_SIZE
_GGUF
/
QI5_0
)
+
i
/
QI5_0
+
k
/
QI5_0
;
const
float
*
x_dmf
=
(
const
float
*
)
x_dm
;
const
float
*
y_df
=
(
const
float
*
)
y_ds
;
...
...
@@ -738,12 +738,12 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
#pragma unroll
for
(
int
l
=
0
;
l
<
VDR_Q5_0_Q8_1_MMQ
;
++
l
)
{
u
[
2
*
l
+
0
]
=
y_qs
[
j
*
WARP_SIZE
+
(
kyqs
+
l
)
%
WARP_SIZE
];
u
[
2
*
l
+
1
]
=
y_qs
[
j
*
WARP_SIZE
+
(
kyqs
+
l
+
QI5_0
)
%
WARP_SIZE
];
u
[
2
*
l
+
0
]
=
y_qs
[
j
*
WARP_SIZE
_GGUF
+
(
kyqs
+
l
)
%
WARP_SIZE
_GGUF
];
u
[
2
*
l
+
1
]
=
y_qs
[
j
*
WARP_SIZE
_GGUF
+
(
kyqs
+
l
+
QI5_0
)
%
WARP_SIZE
_GGUF
];
}
return
vec_dot_q8_0_q8_1_impl
<
QR5_0
*
VDR_Q5_0_Q8_1_MMQ
>
(
&
x_ql
[
i
*
(
2
*
WARP_SIZE
+
1
)
+
2
*
k
],
u
,
x_dmf
[
index_bx
],
y_df
[
j
*
(
WARP_SIZE
/
QI8_1
)
+
(
2
*
k
/
QI8_1
)
%
(
WARP_SIZE
/
QI8_1
)]);
(
&
x_ql
[
i
*
(
2
*
WARP_SIZE
_GGUF
+
1
)
+
2
*
k
],
u
,
x_dmf
[
index_bx
],
y_df
[
j
*
(
WARP_SIZE
_GGUF
/
QI8_1
)
+
(
2
*
k
/
QI8_1
)
%
(
WARP_SIZE
_GGUF
/
QI8_1
)]);
}
static
__device__
__forceinline__
float
vec_dot_q5_1_q8_1
(
...
...
@@ -767,8 +767,8 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
}
template
<
int
mmq_y
>
static
__device__
__forceinline__
void
allocate_tiles_q5_1
(
int
**
x_ql
,
half2
**
x_dm
,
int
**
x_qh
,
int
**
x_sc
)
{
__shared__
int
tile_x_ql
[
mmq_y
*
(
2
*
WARP_SIZE
)
+
mmq_y
];
__shared__
half2
tile_x_dm
[
mmq_y
*
(
WARP_SIZE
/
QI5_1
)
+
mmq_y
/
QI5_1
];
__shared__
int
tile_x_ql
[
mmq_y
*
(
2
*
WARP_SIZE
_GGUF
)
+
mmq_y
];
__shared__
half2
tile_x_dm
[
mmq_y
*
(
WARP_SIZE
_GGUF
/
QI5_1
)
+
mmq_y
/
QI5_1
];
*
x_ql
=
tile_x_ql
;
*
x_dm
=
tile_x_dm
;
...
...
@@ -801,7 +801,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
qs0
|=
(
qh
<<
18
)
&
0x00100000
;
// 2 -> 20
qs0
|=
(
qh
<<
25
)
&
0x10000000
;
// 3 -> 28
x_ql
[
i
*
(
2
*
WARP_SIZE
+
1
)
+
2
*
k
+
0
]
=
qs0
;
x_ql
[
i
*
(
2
*
WARP_SIZE
_GGUF
+
1
)
+
2
*
k
+
0
]
=
qs0
;
int
qs1
=
(
ql
>>
4
)
&
0x0F0F0F0F
;
qs1
|=
(
qh
>>
12
)
&
0x00000010
;
// 16 -> 4
...
...
@@ -809,10 +809,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
qs1
|=
(
qh
<<
2
)
&
0x00100000
;
// 18 -> 20
qs1
|=
(
qh
<<
9
)
&
0x10000000
;
// 19 -> 28
x_ql
[
i
*
(
2
*
WARP_SIZE
+
1
)
+
2
*
k
+
1
]
=
qs1
;
x_ql
[
i
*
(
2
*
WARP_SIZE
_GGUF
+
1
)
+
2
*
k
+
1
]
=
qs1
;
}
const
int
blocks_per_tile_x_row
=
WARP_SIZE
/
QI5_1
;
const
int
blocks_per_tile_x_row
=
WARP_SIZE
_GGUF
/
QI5_1
;
const
int
kbxd
=
k
%
blocks_per_tile_x_row
;
#pragma unroll
...
...
@@ -825,7 +825,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const
block_q5_1
*
bxi
=
bx0
+
i
*
blocks_per_row
+
kbxd
;
x_dm
[
i
*
(
WARP_SIZE
/
QI5_1
)
+
i
/
QI5_1
+
kbxd
]
=
bxi
->
dm
;
x_dm
[
i
*
(
WARP_SIZE
_GGUF
/
QI5_1
)
+
i
/
QI5_1
+
kbxd
]
=
bxi
->
dm
;
}
}
...
...
@@ -833,18 +833,18 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
const
int
*
__restrict__
x_ql
,
const
half2
*
__restrict__
x_dm
,
const
int
*
__restrict__
x_qh
,
const
int
*
__restrict__
x_sc
,
const
int
*
__restrict__
y_qs
,
const
half2
*
__restrict__
y_ds
,
const
int
&
i
,
const
int
&
j
,
const
int
&
k
)
{
const
int
kyqs
=
k
%
(
QI8_1
/
2
)
+
QI8_1
*
(
k
/
(
QI8_1
/
2
));
const
int
index_bx
=
i
*
(
WARP_SIZE
/
QI5_1
)
+
+
i
/
QI5_1
+
k
/
QI5_1
;
const
int
index_bx
=
i
*
(
WARP_SIZE
_GGUF
/
QI5_1
)
+
+
i
/
QI5_1
+
k
/
QI5_1
;
int
u
[
2
*
VDR_Q5_1_Q8_1_MMQ
];
#pragma unroll
for
(
int
l
=
0
;
l
<
VDR_Q5_1_Q8_1_MMQ
;
++
l
)
{
u
[
2
*
l
+
0
]
=
y_qs
[
j
*
WARP_SIZE
+
(
kyqs
+
l
)
%
WARP_SIZE
];
u
[
2
*
l
+
1
]
=
y_qs
[
j
*
WARP_SIZE
+
(
kyqs
+
l
+
QI5_1
)
%
WARP_SIZE
];
u
[
2
*
l
+
0
]
=
y_qs
[
j
*
WARP_SIZE
_GGUF
+
(
kyqs
+
l
)
%
WARP_SIZE
_GGUF
];
u
[
2
*
l
+
1
]
=
y_qs
[
j
*
WARP_SIZE
_GGUF
+
(
kyqs
+
l
+
QI5_1
)
%
WARP_SIZE
_GGUF
];
}
return
vec_dot_q8_1_q8_1_impl
<
QR5_1
*
VDR_Q5_1_Q8_1_MMQ
>
(
&
x_ql
[
i
*
(
2
*
WARP_SIZE
+
1
)
+
2
*
k
],
u
,
x_dm
[
index_bx
],
y_ds
[
j
*
(
WARP_SIZE
/
QI8_1
)
+
(
2
*
k
/
QI8_1
)
%
(
WARP_SIZE
/
QI8_1
)]);
(
&
x_ql
[
i
*
(
2
*
WARP_SIZE
_GGUF
+
1
)
+
2
*
k
],
u
,
x_dm
[
index_bx
],
y_ds
[
j
*
(
WARP_SIZE
_GGUF
/
QI8_1
)
+
(
2
*
k
/
QI8_1
)
%
(
WARP_SIZE
_GGUF
/
QI8_1
)]);
}
static
__device__
__forceinline__
float
vec_dot_q8_0_q8_1
(
...
...
@@ -865,8 +865,8 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
}
template
<
int
mmq_y
>
static
__device__
__forceinline__
void
allocate_tiles_q8_0
(
int
**
x_ql
,
half2
**
x_dm
,
int
**
x_qh
,
int
**
x_sc
)
{
__shared__
int
tile_x_qs
[
mmq_y
*
(
WARP_SIZE
)
+
mmq_y
];
__shared__
float
tile_x_d
[
mmq_y
*
(
WARP_SIZE
/
QI8_0
)
+
mmq_y
/
QI8_0
];
__shared__
int
tile_x_qs
[
mmq_y
*
(
WARP_SIZE
_GGUF
)
+
mmq_y
];
__shared__
float
tile_x_d
[
mmq_y
*
(
WARP_SIZE
_GGUF
/
QI8_0
)
+
mmq_y
/
QI8_0
];
*
x_ql
=
tile_x_qs
;
*
x_dm
=
(
half2
*
)
tile_x_d
;
...
...
@@ -889,10 +889,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i
=
min
(
i
,
i_max
);
}
const
block_q8_0
*
bxi
=
bx0
+
i
*
blocks_per_row
+
kbx
;
x_ql
[
i
*
(
WARP_SIZE
+
1
)
+
k
]
=
get_int_from_int8
(
bxi
->
qs
,
kqsx
);
x_ql
[
i
*
(
WARP_SIZE
_GGUF
+
1
)
+
k
]
=
get_int_from_int8
(
bxi
->
qs
,
kqsx
);
}
const
int
blocks_per_tile_x_row
=
WARP_SIZE
/
QI8_0
;
const
int
blocks_per_tile_x_row
=
WARP_SIZE
_GGUF
/
QI8_0
;
const
int
kbxd
=
k
%
blocks_per_tile_x_row
;
#pragma unroll
...
...
@@ -903,7 +903,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i
=
min
(
i
,
i_max
);
}
const
block_q8_0
*
bxi
=
bx0
+
i
*
blocks_per_row
+
kbxd
;
x_dmf
[
i
*
(
WARP_SIZE
/
QI8_0
)
+
i
/
QI8_0
+
kbxd
]
=
__half2float
(
bxi
->
d
);
x_dmf
[
i
*
(
WARP_SIZE
_GGUF
/
QI8_0
)
+
i
/
QI8_0
+
kbxd
]
=
__half2float
(
bxi
->
d
);
}
}
...
...
@@ -914,8 +914,8 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
const
float
*
y_df
=
(
const
float
*
)
y_ds
;
return
vec_dot_q8_0_q8_1_impl
<
VDR_Q8_0_Q8_1_MMQ
>
(
&
x_ql
[
i
*
(
WARP_SIZE
+
1
)
+
k
],
&
y_qs
[
j
*
WARP_SIZE
+
k
],
x_dmf
[
i
*
(
WARP_SIZE
/
QI8_0
)
+
i
/
QI8_0
+
k
/
QI8_0
],
y_df
[
j
*
(
WARP_SIZE
/
QI8_1
)
+
k
/
QI8_1
]);
(
&
x_ql
[
i
*
(
WARP_SIZE
_GGUF
+
1
)
+
k
],
&
y_qs
[
j
*
WARP_SIZE
_GGUF
+
k
],
x_dmf
[
i
*
(
WARP_SIZE
_GGUF
/
QI8_0
)
+
i
/
QI8_0
+
k
/
QI8_0
],
y_df
[
j
*
(
WARP_SIZE
_GGUF
/
QI8_1
)
+
k
/
QI8_1
]);
}
static
__device__
__forceinline__
float
vec_dot_q2_K_q8_1
(
...
...
@@ -942,9 +942,9 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
}
template
<
int
mmq_y
>
static
__device__
__forceinline__
void
allocate_tiles_q2_K
(
int
**
x_ql
,
half2
**
x_dm
,
int
**
x_qh
,
int
**
x_sc
)
{
__shared__
int
tile_x_ql
[
mmq_y
*
(
WARP_SIZE
)
+
mmq_y
];
__shared__
half2
tile_x_dm
[
mmq_y
*
(
WARP_SIZE
/
QI2_K
)
+
mmq_y
/
QI2_K
];
__shared__
int
tile_x_sc
[
mmq_y
*
(
WARP_SIZE
/
4
)
+
mmq_y
/
4
];
__shared__
int
tile_x_ql
[
mmq_y
*
(
WARP_SIZE
_GGUF
)
+
mmq_y
];
__shared__
half2
tile_x_dm
[
mmq_y
*
(
WARP_SIZE
_GGUF
/
QI2_K
)
+
mmq_y
/
QI2_K
];
__shared__
int
tile_x_sc
[
mmq_y
*
(
WARP_SIZE
_GGUF
/
4
)
+
mmq_y
/
4
];
*
x_ql
=
tile_x_ql
;
*
x_dm
=
tile_x_dm
;
...
...
@@ -967,10 +967,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i
=
min
(
i
,
i_max
);
}
const
block_q2_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
kbx
;
x_ql
[
i
*
(
WARP_SIZE
+
1
)
+
k
]
=
get_int_from_uint8_aligned
(
bxi
->
qs
,
kqsx
);
x_ql
[
i
*
(
WARP_SIZE
_GGUF
+
1
)
+
k
]
=
get_int_from_uint8_aligned
(
bxi
->
qs
,
kqsx
);
}
const
int
blocks_per_tile_x_row
=
WARP_SIZE
/
QI2_K
;
const
int
blocks_per_tile_x_row
=
WARP_SIZE
_GGUF
/
QI2_K
;
const
int
kbxd
=
k
%
blocks_per_tile_x_row
;
#pragma unroll
...
...
@@ -981,18 +981,18 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i
=
min
(
i
,
i_max
);
}
const
block_q2_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
kbxd
;
x_dm
[
i
*
(
WARP_SIZE
/
QI2_K
)
+
i
/
QI2_K
+
kbxd
]
=
bxi
->
dm
;
x_dm
[
i
*
(
WARP_SIZE
_GGUF
/
QI2_K
)
+
i
/
QI2_K
+
kbxd
]
=
bxi
->
dm
;
}
#pragma unroll
for
(
int
i0
=
0
;
i0
<
mmq_y
;
i0
+=
nwarps
*
4
)
{
int
i
=
i0
+
i_offset
*
4
+
k
/
(
WARP_SIZE
/
4
);
int
i
=
i0
+
i_offset
*
4
+
k
/
(
WARP_SIZE
_GGUF
/
4
);
if
(
need_check
)
{
i
=
min
(
i
,
i_max
);
}
const
block_q2_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
(
k
%
(
WARP_SIZE
/
4
))
/
(
QI2_K
/
4
);
x_sc
[
i
*
(
WARP_SIZE
/
4
)
+
i
/
4
+
k
%
(
WARP_SIZE
/
4
)]
=
get_int_from_uint8_aligned
(
bxi
->
scales
,
k
%
(
QI2_K
/
4
));
const
block_q2_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
(
k
%
(
WARP_SIZE
_GGUF
/
4
))
/
(
QI2_K
/
4
);
x_sc
[
i
*
(
WARP_SIZE
_GGUF
/
4
)
+
i
/
4
+
k
%
(
WARP_SIZE
_GGUF
/
4
)]
=
get_int_from_uint8_aligned
(
bxi
->
scales
,
k
%
(
QI2_K
/
4
));
}
}
...
...
@@ -1005,7 +1005,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
int
v
[
QR2_K
*
VDR_Q2_K_Q8_1_MMQ
];
const
int
kqsx
=
i
*
(
WARP_SIZE
+
1
)
+
kbx
*
QI2_K
+
(
QI2_K
/
2
)
*
(
ky
/
(
2
*
QI2_K
))
+
ky
%
(
QI2_K
/
2
);
const
int
kqsx
=
i
*
(
WARP_SIZE
_GGUF
+
1
)
+
kbx
*
QI2_K
+
(
QI2_K
/
2
)
*
(
ky
/
(
2
*
QI2_K
))
+
ky
%
(
QI2_K
/
2
);
const
int
shift
=
2
*
((
ky
%
(
2
*
QI2_K
))
/
(
QI2_K
/
2
));
#pragma unroll
...
...
@@ -1013,10 +1013,10 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
v
[
l
]
=
(
x_ql
[
kqsx
+
l
]
>>
shift
)
&
0x03030303
;
}
const
uint8_t
*
scales
=
((
const
uint8_t
*
)
&
x_sc
[
i
*
(
WARP_SIZE
/
4
)
+
i
/
4
+
kbx
*
4
])
+
ky
/
4
;
const
uint8_t
*
scales
=
((
const
uint8_t
*
)
&
x_sc
[
i
*
(
WARP_SIZE
_GGUF
/
4
)
+
i
/
4
+
kbx
*
4
])
+
ky
/
4
;
const
int
index_y
=
j
*
WARP_SIZE
+
(
QR2_K
*
k
)
%
WARP_SIZE
;
return
vec_dot_q2_K_q8_1_impl_mmq
(
v
,
&
y_qs
[
index_y
],
scales
,
x_dm
[
i
*
(
WARP_SIZE
/
QI2_K
)
+
i
/
QI2_K
+
kbx
],
y_df
[
index_y
/
QI8_1
]);
const
int
index_y
=
j
*
WARP_SIZE
_GGUF
+
(
QR2_K
*
k
)
%
WARP_SIZE
_GGUF
;
return
vec_dot_q2_K_q8_1_impl_mmq
(
v
,
&
y_qs
[
index_y
],
scales
,
x_dm
[
i
*
(
WARP_SIZE
_GGUF
/
QI2_K
)
+
i
/
QI2_K
+
kbx
],
y_df
[
index_y
/
QI8_1
]);
}
static
__device__
__forceinline__
float
vec_dot_q3_K_q8_1
(
...
...
@@ -1047,10 +1047,10 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
}
template
<
int
mmq_y
>
static
__device__
__forceinline__
void
allocate_tiles_q3_K
(
int
**
x_ql
,
half2
**
x_dm
,
int
**
x_qh
,
int
**
x_sc
)
{
__shared__
int
tile_x_ql
[
mmq_y
*
(
WARP_SIZE
)
+
mmq_y
];
__shared__
half2
tile_x_dm
[
mmq_y
*
(
WARP_SIZE
/
QI3_K
)
+
mmq_y
/
QI3_K
];
__shared__
int
tile_x_qh
[
mmq_y
*
(
WARP_SIZE
/
2
)
+
mmq_y
/
2
];
__shared__
int
tile_x_sc
[
mmq_y
*
(
WARP_SIZE
/
4
)
+
mmq_y
/
4
];
__shared__
int
tile_x_ql
[
mmq_y
*
(
WARP_SIZE
_GGUF
)
+
mmq_y
];
__shared__
half2
tile_x_dm
[
mmq_y
*
(
WARP_SIZE
_GGUF
/
QI3_K
)
+
mmq_y
/
QI3_K
];
__shared__
int
tile_x_qh
[
mmq_y
*
(
WARP_SIZE
_GGUF
/
2
)
+
mmq_y
/
2
];
__shared__
int
tile_x_sc
[
mmq_y
*
(
WARP_SIZE
_GGUF
/
4
)
+
mmq_y
/
4
];
*
x_ql
=
tile_x_ql
;
*
x_dm
=
tile_x_dm
;
...
...
@@ -1073,10 +1073,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i
=
min
(
i
,
i_max
);
}
const
block_q3_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
kbx
;
x_ql
[
i
*
(
WARP_SIZE
+
1
)
+
k
]
=
get_int_from_uint8
(
bxi
->
qs
,
kqsx
);
x_ql
[
i
*
(
WARP_SIZE
_GGUF
+
1
)
+
k
]
=
get_int_from_uint8
(
bxi
->
qs
,
kqsx
);
}
const
int
blocks_per_tile_x_row
=
WARP_SIZE
/
QI3_K
;
const
int
blocks_per_tile_x_row
=
WARP_SIZE
_GGUF
/
QI3_K
;
const
int
kbxd
=
k
%
blocks_per_tile_x_row
;
float
*
x_dmf
=
(
float
*
)
x_dm
;
...
...
@@ -1087,27 +1087,27 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i
=
min
(
i
,
i_max
);
}
const
block_q3_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
kbxd
;
x_dmf
[
i
*
(
WARP_SIZE
/
QI3_K
)
+
i
/
QI3_K
+
kbxd
]
=
__half2float
(
bxi
->
d
);
x_dmf
[
i
*
(
WARP_SIZE
_GGUF
/
QI3_K
)
+
i
/
QI3_K
+
kbxd
]
=
__half2float
(
bxi
->
d
);
}
#pragma unroll
for
(
int
i0
=
0
;
i0
<
mmq_y
;
i0
+=
nwarps
*
2
)
{
int
i
=
i0
+
i_offset
*
2
+
k
/
(
WARP_SIZE
/
2
);
int
i
=
i0
+
i_offset
*
2
+
k
/
(
WARP_SIZE
_GGUF
/
2
);
if
(
need_check
)
{
i
=
min
(
i
,
i_max
);
}
const
block_q3_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
(
k
%
(
WARP_SIZE
/
2
))
/
(
QI3_K
/
2
);
const
block_q3_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
(
k
%
(
WARP_SIZE
_GGUF
/
2
))
/
(
QI3_K
/
2
);
// invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
x_qh
[
i
*
(
WARP_SIZE
/
2
)
+
i
/
2
+
k
%
(
WARP_SIZE
/
2
)]
=
~
get_int_from_uint8
(
bxi
->
hmask
,
k
%
(
QI3_K
/
2
));
x_qh
[
i
*
(
WARP_SIZE
_GGUF
/
2
)
+
i
/
2
+
k
%
(
WARP_SIZE
_GGUF
/
2
)]
=
~
get_int_from_uint8
(
bxi
->
hmask
,
k
%
(
QI3_K
/
2
));
}
#pragma unroll
for
(
int
i0
=
0
;
i0
<
mmq_y
;
i0
+=
nwarps
*
4
)
{
int
i
=
i0
+
i_offset
*
4
+
k
/
(
WARP_SIZE
/
4
);
int
i
=
i0
+
i_offset
*
4
+
k
/
(
WARP_SIZE
_GGUF
/
4
);
if
(
need_check
)
{
i
=
min
(
i
,
i_max
);
}
const
block_q3_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
(
k
%
(
WARP_SIZE
/
4
))
/
(
QI3_K
/
4
);
const
block_q3_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
(
k
%
(
WARP_SIZE
_GGUF
/
4
))
/
(
QI3_K
/
4
);
const
int
ksc
=
k
%
(
QI3_K
/
4
);
...
...
@@ -1121,7 +1121,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const
int
sc
=
__vsubss4
(
sc_low
|
sc_high
,
0x20202020
);
x_sc
[
i
*
(
WARP_SIZE
/
4
)
+
i
/
4
+
k
%
(
WARP_SIZE
/
4
)]
=
sc
;
x_sc
[
i
*
(
WARP_SIZE
_GGUF
/
4
)
+
i
/
4
+
k
%
(
WARP_SIZE
_GGUF
/
4
)]
=
sc
;
}
}
...
...
@@ -1134,24 +1134,24 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
const
float
*
x_dmf
=
(
const
float
*
)
x_dm
;
const
float
*
y_df
=
(
const
float
*
)
y_ds
;
const
int8_t
*
scales
=
((
const
int8_t
*
)
(
x_sc
+
i
*
(
WARP_SIZE
/
4
)
+
i
/
4
+
kbx
*
4
))
+
ky
/
4
;
const
int8_t
*
scales
=
((
const
int8_t
*
)
(
x_sc
+
i
*
(
WARP_SIZE
_GGUF
/
4
)
+
i
/
4
+
kbx
*
4
))
+
ky
/
4
;
int
v
[
QR3_K
*
VDR_Q3_K_Q8_1_MMQ
];
#pragma unroll
for
(
int
l
=
0
;
l
<
QR3_K
*
VDR_Q3_K_Q8_1_MMQ
;
++
l
)
{
const
int
kqsx
=
i
*
(
WARP_SIZE
+
1
)
+
kbx
*
QI3_K
+
(
QI3_K
/
2
)
*
(
ky
/
(
2
*
QI3_K
))
+
ky
%
(
QI3_K
/
2
);
const
int
kqsx
=
i
*
(
WARP_SIZE
_GGUF
+
1
)
+
kbx
*
QI3_K
+
(
QI3_K
/
2
)
*
(
ky
/
(
2
*
QI3_K
))
+
ky
%
(
QI3_K
/
2
);
const
int
shift
=
2
*
((
ky
%
32
)
/
8
);
const
int
vll
=
(
x_ql
[
kqsx
+
l
]
>>
shift
)
&
0x03030303
;
const
int
vh
=
x_qh
[
i
*
(
WARP_SIZE
/
2
)
+
i
/
2
+
kbx
*
(
QI3_K
/
2
)
+
(
ky
+
l
)
%
8
]
>>
((
ky
+
l
)
/
8
);
const
int
vh
=
x_qh
[
i
*
(
WARP_SIZE
_GGUF
/
2
)
+
i
/
2
+
kbx
*
(
QI3_K
/
2
)
+
(
ky
+
l
)
%
8
]
>>
((
ky
+
l
)
/
8
);
const
int
vlh
=
(
vh
<<
2
)
&
0x04040404
;
v
[
l
]
=
__vsubss4
(
vll
,
vlh
);
}
const
int
index_y
=
j
*
WARP_SIZE
+
(
k
*
QR3_K
)
%
WARP_SIZE
;
return
vec_dot_q3_K_q8_1_impl_mmq
(
v
,
&
y_qs
[
index_y
],
scales
,
x_dmf
[
i
*
(
WARP_SIZE
/
QI3_K
)
+
i
/
QI3_K
+
kbx
],
y_df
[
index_y
/
QI8_1
]);
const
int
index_y
=
j
*
WARP_SIZE
_GGUF
+
(
k
*
QR3_K
)
%
WARP_SIZE
_GGUF
;
return
vec_dot_q3_K_q8_1_impl_mmq
(
v
,
&
y_qs
[
index_y
],
scales
,
x_dmf
[
i
*
(
WARP_SIZE
_GGUF
/
QI3_K
)
+
i
/
QI3_K
+
kbx
],
y_df
[
index_y
/
QI8_1
]);
}
static
__device__
__forceinline__
float
vec_dot_q4_K_q8_1
(
...
...
@@ -1200,9 +1200,9 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
}
template
<
int
mmq_y
>
static
__device__
__forceinline__
void
allocate_tiles_q4_K
(
int
**
x_ql
,
half2
**
x_dm
,
int
**
x_qh
,
int
**
x_sc
)
{
__shared__
int
tile_x_ql
[
mmq_y
*
(
WARP_SIZE
)
+
mmq_y
];
__shared__
half2
tile_x_dm
[
mmq_y
*
(
WARP_SIZE
/
QI4_K
)
+
mmq_y
/
QI4_K
];
__shared__
int
tile_x_sc
[
mmq_y
*
(
WARP_SIZE
/
8
)
+
mmq_y
/
8
];
__shared__
int
tile_x_ql
[
mmq_y
*
(
WARP_SIZE
_GGUF
)
+
mmq_y
];
__shared__
half2
tile_x_dm
[
mmq_y
*
(
WARP_SIZE
_GGUF
/
QI4_K
)
+
mmq_y
/
QI4_K
];
__shared__
int
tile_x_sc
[
mmq_y
*
(
WARP_SIZE
_GGUF
/
8
)
+
mmq_y
/
8
];
*
x_ql
=
tile_x_ql
;
*
x_dm
=
tile_x_dm
;
...
...
@@ -1225,10 +1225,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i
=
min
(
i
,
i_max
);
}
const
block_q4_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
kbx
;
x_ql
[
i
*
(
WARP_SIZE
+
1
)
+
k
]
=
get_int_from_uint8_aligned
(
bxi
->
qs
,
kqsx
);
x_ql
[
i
*
(
WARP_SIZE
_GGUF
+
1
)
+
k
]
=
get_int_from_uint8_aligned
(
bxi
->
qs
,
kqsx
);
}
const
int
blocks_per_tile_x_row
=
WARP_SIZE
/
QI4_K
;
// == 1 if QK_K == 256
const
int
blocks_per_tile_x_row
=
WARP_SIZE
_GGUF
/
QI4_K
;
// == 1 if QK_K == 256
const
int
kbxd
=
k
%
blocks_per_tile_x_row
;
// == 0 if QK_K == 256
#pragma unroll
...
...
@@ -1238,27 +1238,27 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
i
=
min
(
i
,
i_max
);
}
const
block_q4_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
kbxd
;
x_dm
[
i
*
(
WARP_SIZE
/
QI4_K
)
+
i
/
QI4_K
+
kbxd
]
=
bxi
->
dm
;
x_dm
[
i
*
(
WARP_SIZE
_GGUF
/
QI4_K
)
+
i
/
QI4_K
+
kbxd
]
=
bxi
->
dm
;
}
#pragma unroll
for
(
int
i0
=
0
;
i0
<
mmq_y
;
i0
+=
nwarps
*
8
)
{
int
i
=
(
i0
+
i_offset
*
8
+
k
/
(
WARP_SIZE
/
8
))
%
mmq_y
;
int
i
=
(
i0
+
i_offset
*
8
+
k
/
(
WARP_SIZE
_GGUF
/
8
))
%
mmq_y
;
if
(
need_check
)
{
i
=
min
(
i
,
i_max
);
}
const
block_q4_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
(
k
%
(
WARP_SIZE
/
8
))
/
(
QI4_K
/
8
);
const
block_q4_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
(
k
%
(
WARP_SIZE
_GGUF
/
8
))
/
(
QI4_K
/
8
);
const
int
*
scales
=
(
const
int
*
)
bxi
->
scales
;
const
int
ksc
=
k
%
(
WARP_SIZE
/
8
);
const
int
ksc
=
k
%
(
WARP_SIZE
_GGUF
/
8
);
// scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
int
scales8
=
(
scales
[(
ksc
%
2
)
+
(
ksc
!=
0
)]
>>
(
4
*
(
ksc
&
(
ksc
/
2
))))
&
0x0F0F0F0F
;
// lower 4 bits
scales8
|=
(
scales
[
ksc
/
2
]
>>
(
2
*
(
ksc
%
2
)))
&
0x30303030
;
// upper 2 bits
x_sc
[
i
*
(
WARP_SIZE
/
8
)
+
i
/
8
+
ksc
]
=
scales8
;
x_sc
[
i
*
(
WARP_SIZE
_GGUF
/
8
)
+
i
/
8
+
ksc
]
=
scales8
;
}
}
...
...
@@ -1267,11 +1267,11 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
const
int
*
__restrict__
y_qs
,
const
half2
*
__restrict__
y_ds
,
const
int
&
i
,
const
int
&
j
,
const
int
&
k
)
{
(
void
)
x_qh
;
const
uint8_t
*
sc
=
((
const
uint8_t
*
)
&
x_sc
[
i
*
(
WARP_SIZE
/
8
)
+
i
/
8
+
k
/
16
])
+
2
*
((
k
%
16
)
/
8
);
const
uint8_t
*
sc
=
((
const
uint8_t
*
)
&
x_sc
[
i
*
(
WARP_SIZE
_GGUF
/
8
)
+
i
/
8
+
k
/
16
])
+
2
*
((
k
%
16
)
/
8
);
const
int
index_y
=
j
*
WARP_SIZE
+
(
QR4_K
*
k
)
%
WARP_SIZE
;
return
vec_dot_q4_K_q8_1_impl_mmq
(
&
x_ql
[
i
*
(
WARP_SIZE
+
1
)
+
k
],
&
y_qs
[
index_y
],
sc
,
sc
+
8
,
x_dm
[
i
*
(
WARP_SIZE
/
QI4_K
)
+
i
/
QI4_K
],
&
y_ds
[
index_y
/
QI8_1
]);
const
int
index_y
=
j
*
WARP_SIZE
_GGUF
+
(
QR4_K
*
k
)
%
WARP_SIZE
_GGUF
;
return
vec_dot_q4_K_q8_1_impl_mmq
(
&
x_ql
[
i
*
(
WARP_SIZE
_GGUF
+
1
)
+
k
],
&
y_qs
[
index_y
],
sc
,
sc
+
8
,
x_dm
[
i
*
(
WARP_SIZE
_GGUF
/
QI4_K
)
+
i
/
QI4_K
],
&
y_ds
[
index_y
/
QI8_1
]);
}
static
__device__
__forceinline__
float
vec_dot_q5_K_q8_1
(
...
...
@@ -1321,9 +1321,9 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
}
template
<
int
mmq_y
>
static
__device__
__forceinline__
void
allocate_tiles_q5_K
(
int
**
x_ql
,
half2
**
x_dm
,
int
**
x_qh
,
int
**
x_sc
)
{
__shared__
int
tile_x_ql
[
mmq_y
*
(
2
*
WARP_SIZE
)
+
mmq_y
];
__shared__
half2
tile_x_dm
[
mmq_y
*
(
WARP_SIZE
/
QI5_K
)
+
mmq_y
/
QI5_K
];
__shared__
int
tile_x_sc
[
mmq_y
*
(
WARP_SIZE
/
8
)
+
mmq_y
/
8
];
__shared__
int
tile_x_ql
[
mmq_y
*
(
2
*
WARP_SIZE
_GGUF
)
+
mmq_y
];
__shared__
half2
tile_x_dm
[
mmq_y
*
(
WARP_SIZE
_GGUF
/
QI5_K
)
+
mmq_y
/
QI5_K
];
__shared__
int
tile_x_sc
[
mmq_y
*
(
WARP_SIZE
_GGUF
/
8
)
+
mmq_y
/
8
];
*
x_ql
=
tile_x_ql
;
*
x_dm
=
tile_x_dm
;
...
...
@@ -1360,11 +1360,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const
int
kq0
=
ky
-
ky
%
(
QI5_K
/
2
)
+
k
%
(
QI5_K
/
4
)
+
0
;
const
int
kq1
=
ky
-
ky
%
(
QI5_K
/
2
)
+
k
%
(
QI5_K
/
4
)
+
(
QI5_K
/
4
);
x_ql
[
i
*
(
2
*
WARP_SIZE
+
1
)
+
kq0
]
=
ql0
|
qh0
;
x_ql
[
i
*
(
2
*
WARP_SIZE
+
1
)
+
kq1
]
=
ql1
|
qh1
;
x_ql
[
i
*
(
2
*
WARP_SIZE
_GGUF
+
1
)
+
kq0
]
=
ql0
|
qh0
;
x_ql
[
i
*
(
2
*
WARP_SIZE
_GGUF
+
1
)
+
kq1
]
=
ql1
|
qh1
;
}
const
int
blocks_per_tile_x_row
=
WARP_SIZE
/
QI5_K
;
// == 1 if QK_K == 256
const
int
blocks_per_tile_x_row
=
WARP_SIZE
_GGUF
/
QI5_K
;
// == 1 if QK_K == 256
const
int
kbxd
=
k
%
blocks_per_tile_x_row
;
// == 0 if QK_K == 256
#pragma unroll
...
...
@@ -1376,40 +1376,40 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
const
block_q5_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
kbxd
;
x_dm
[
i
*
(
WARP_SIZE
/
QI5_K
)
+
i
/
QI5_K
+
kbxd
]
=
bxi
->
dm
;
x_dm
[
i
*
(
WARP_SIZE
_GGUF
/
QI5_K
)
+
i
/
QI5_K
+
kbxd
]
=
bxi
->
dm
;
}
#pragma unroll
for
(
int
i0
=
0
;
i0
<
mmq_y
;
i0
+=
nwarps
*
8
)
{
int
i
=
(
i0
+
i_offset
*
8
+
k
/
(
WARP_SIZE
/
8
))
%
mmq_y
;
int
i
=
(
i0
+
i_offset
*
8
+
k
/
(
WARP_SIZE
_GGUF
/
8
))
%
mmq_y
;
if
(
need_check
)
{
i
=
min
(
i
,
i_max
);
}
const
block_q5_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
(
k
%
(
WARP_SIZE
/
8
))
/
(
QI5_K
/
8
);
const
block_q5_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
(
k
%
(
WARP_SIZE
_GGUF
/
8
))
/
(
QI5_K
/
8
);
const
int
*
scales
=
(
const
int
*
)
bxi
->
scales
;
const
int
ksc
=
k
%
(
WARP_SIZE
/
8
);
const
int
ksc
=
k
%
(
WARP_SIZE
_GGUF
/
8
);
// scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
int
scales8
=
(
scales
[(
ksc
%
2
)
+
(
ksc
!=
0
)]
>>
(
4
*
(
ksc
&
(
ksc
/
2
))))
&
0x0F0F0F0F
;
// lower 4 bits
scales8
|=
(
scales
[
ksc
/
2
]
>>
(
2
*
(
ksc
%
2
)))
&
0x30303030
;
// upper 2 bits
x_sc
[
i
*
(
WARP_SIZE
/
8
)
+
i
/
8
+
ksc
]
=
scales8
;
x_sc
[
i
*
(
WARP_SIZE
_GGUF
/
8
)
+
i
/
8
+
ksc
]
=
scales8
;
}
}
static
__device__
__forceinline__
float
vec_dot_q5_K_q8_1_mul_mat
(
const
int
*
__restrict__
x_ql
,
const
half2
*
__restrict__
x_dm
,
const
int
*
__restrict__
x_qh
,
const
int
*
__restrict__
x_sc
,
const
int
*
__restrict__
y_qs
,
const
half2
*
__restrict__
y_ds
,
const
int
&
i
,
const
int
&
j
,
const
int
&
k
)
{
const
uint8_t
*
sc
=
((
const
uint8_t
*
)
&
x_sc
[
i
*
(
WARP_SIZE
/
8
)
+
i
/
8
+
k
/
16
])
+
2
*
((
k
%
16
)
/
8
);
const
uint8_t
*
sc
=
((
const
uint8_t
*
)
&
x_sc
[
i
*
(
WARP_SIZE
_GGUF
/
8
)
+
i
/
8
+
k
/
16
])
+
2
*
((
k
%
16
)
/
8
);
const
int
index_x
=
i
*
(
QR5_K
*
WARP_SIZE
+
1
)
+
QR5_K
*
k
;
const
int
index_y
=
j
*
WARP_SIZE
+
(
QR5_K
*
k
)
%
WARP_SIZE
;
const
int
index_x
=
i
*
(
QR5_K
*
WARP_SIZE
_GGUF
+
1
)
+
QR5_K
*
k
;
const
int
index_y
=
j
*
WARP_SIZE
_GGUF
+
(
QR5_K
*
k
)
%
WARP_SIZE
_GGUF
;
return
vec_dot_q5_K_q8_1_impl_mmq
(
&
x_ql
[
index_x
],
&
y_qs
[
index_y
],
sc
,
sc
+
8
,
x_dm
[
i
*
(
WARP_SIZE
/
QI5_K
)
+
i
/
QI5_K
],
&
y_ds
[
index_y
/
QI8_1
]);
x_dm
[
i
*
(
WARP_SIZE
_GGUF
/
QI5_K
)
+
i
/
QI5_K
],
&
y_ds
[
index_y
/
QI8_1
]);
}
static
__device__
__forceinline__
float
vec_dot_q6_K_q8_1
(
...
...
@@ -1439,9 +1439,9 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
}
template
<
int
mmq_y
>
static
__device__
__forceinline__
void
allocate_tiles_q6_K
(
int
**
x_ql
,
half2
**
x_dm
,
int
**
x_qh
,
int
**
x_sc
)
{
__shared__
int
tile_x_ql
[
mmq_y
*
(
2
*
WARP_SIZE
)
+
mmq_y
];
__shared__
half2
tile_x_dm
[
mmq_y
*
(
WARP_SIZE
/
QI6_K
)
+
mmq_y
/
QI6_K
];
__shared__
int
tile_x_sc
[
mmq_y
*
(
WARP_SIZE
/
8
)
+
mmq_y
/
8
];
__shared__
int
tile_x_ql
[
mmq_y
*
(
2
*
WARP_SIZE
_GGUF
)
+
mmq_y
];
__shared__
half2
tile_x_dm
[
mmq_y
*
(
WARP_SIZE
_GGUF
/
QI6_K
)
+
mmq_y
/
QI6_K
];
__shared__
int
tile_x_sc
[
mmq_y
*
(
WARP_SIZE
_GGUF
/
8
)
+
mmq_y
/
8
];
*
x_ql
=
tile_x_ql
;
*
x_dm
=
tile_x_dm
;
...
...
@@ -1478,11 +1478,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const
int
kq0
=
ky
-
ky
%
QI6_K
+
k
%
(
QI6_K
/
2
)
+
0
;
const
int
kq1
=
ky
-
ky
%
QI6_K
+
k
%
(
QI6_K
/
2
)
+
(
QI6_K
/
2
);
x_ql
[
i
*
(
2
*
WARP_SIZE
+
1
)
+
kq0
]
=
__vsubss4
(
ql0
|
qh0
,
0x20202020
);
x_ql
[
i
*
(
2
*
WARP_SIZE
+
1
)
+
kq1
]
=
__vsubss4
(
ql1
|
qh1
,
0x20202020
);
x_ql
[
i
*
(
2
*
WARP_SIZE
_GGUF
+
1
)
+
kq0
]
=
__vsubss4
(
ql0
|
qh0
,
0x20202020
);
x_ql
[
i
*
(
2
*
WARP_SIZE
_GGUF
+
1
)
+
kq1
]
=
__vsubss4
(
ql1
|
qh1
,
0x20202020
);
}
const
int
blocks_per_tile_x_row
=
WARP_SIZE
/
QI6_K
;
// == 1 if QK_K == 256
const
int
blocks_per_tile_x_row
=
WARP_SIZE
_GGUF
/
QI6_K
;
// == 1 if QK_K == 256
const
int
kbxd
=
k
%
blocks_per_tile_x_row
;
// == 0 if QK_K == 256
float
*
x_dmf
=
(
float
*
)
x_dm
;
...
...
@@ -1496,20 +1496,20 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const
block_q6_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
kbxd
;
x_dmf
[
i
*
(
WARP_SIZE
/
QI6_K
)
+
i
/
QI6_K
+
kbxd
]
=
__half2float
(
bxi
->
d
);
x_dmf
[
i
*
(
WARP_SIZE
_GGUF
/
QI6_K
)
+
i
/
QI6_K
+
kbxd
]
=
__half2float
(
bxi
->
d
);
}
#pragma unroll
for
(
int
i0
=
0
;
i0
<
mmq_y
;
i0
+=
nwarps
*
8
)
{
int
i
=
(
i0
+
i_offset
*
8
+
k
/
(
WARP_SIZE
/
8
))
%
mmq_y
;
int
i
=
(
i0
+
i_offset
*
8
+
k
/
(
WARP_SIZE
_GGUF
/
8
))
%
mmq_y
;
if
(
need_check
)
{
i
=
min
(
i
,
i_max
);
}
const
block_q6_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
(
k
%
(
WARP_SIZE
/
8
))
/
4
;
const
block_q6_K
*
bxi
=
bx0
+
i
*
blocks_per_row
+
(
k
%
(
WARP_SIZE
_GGUF
/
8
))
/
4
;
x_sc
[
i
*
(
WARP_SIZE
/
8
)
+
i
/
8
+
k
%
(
WARP_SIZE
/
8
)]
=
get_int_from_int8
(
bxi
->
scales
,
k
%
(
QI6_K
/
8
));
x_sc
[
i
*
(
WARP_SIZE
_GGUF
/
8
)
+
i
/
8
+
k
%
(
WARP_SIZE
_GGUF
/
8
)]
=
get_int_from_int8
(
bxi
->
scales
,
k
%
(
QI6_K
/
8
));
}
}
...
...
@@ -1519,11 +1519,11 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
const
float
*
x_dmf
=
(
const
float
*
)
x_dm
;
const
float
*
y_df
=
(
const
float
*
)
y_ds
;
const
int8_t
*
sc
=
((
const
int8_t
*
)
&
x_sc
[
i
*
(
WARP_SIZE
/
8
)
+
i
/
8
+
k
/
8
]);
const
int8_t
*
sc
=
((
const
int8_t
*
)
&
x_sc
[
i
*
(
WARP_SIZE
_GGUF
/
8
)
+
i
/
8
+
k
/
8
]);
const
int
index_x
=
i
*
(
QR6_K
*
WARP_SIZE
+
1
)
+
QR6_K
*
k
;
const
int
index_y
=
j
*
WARP_SIZE
+
(
QR6_K
*
k
)
%
WARP_SIZE
;
return
vec_dot_q6_K_q8_1_impl_mmq
(
&
x_ql
[
index_x
],
&
y_qs
[
index_y
],
sc
,
x_dmf
[
i
*
(
WARP_SIZE
/
QI6_K
)
+
i
/
QI6_K
],
&
y_df
[
index_y
/
QI8_1
]);
const
int
index_x
=
i
*
(
QR6_K
*
WARP_SIZE
_GGUF
+
1
)
+
QR6_K
*
k
;
const
int
index_y
=
j
*
WARP_SIZE
_GGUF
+
(
QR6_K
*
k
)
%
WARP_SIZE
_GGUF
;
return
vec_dot_q6_K_q8_1_impl_mmq
(
&
x_ql
[
index_x
],
&
y_qs
[
index_y
],
sc
,
x_dmf
[
i
*
(
WARP_SIZE
_GGUF
/
QI6_K
)
+
i
/
QI6_K
],
&
y_df
[
index_y
/
QI8_1
]);
}
static
__device__
__forceinline__
float
vec_dot_iq2_xxs_q8_1
(
...
...
@@ -1582,7 +1582,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
static
__device__
__forceinline__
float
vec_dot_iq2_s_q8_1
(
const
void
*
__restrict__
vbq
,
const
block_q8_1
*
__restrict__
bq8_1
,
const
int
&
iqs
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
const
block_iq2_s
*
bq2
=
(
const
block_iq2_s
*
)
vbq
;
const
int
ib32
=
iqs
;
...
...
@@ -1619,7 +1619,7 @@ static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
static
__device__
__forceinline__
float
vec_dot_iq3_xxs_q8_1
(
const
void
*
__restrict__
vbq
,
const
block_q8_1
*
__restrict__
bq8_1
,
const
int
&
iqs
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
const
block_iq3_xxs
*
bq2
=
(
const
block_iq3_xxs
*
)
vbq
;
const
int
ib32
=
iqs
;
...
...
@@ -1646,7 +1646,7 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
static
__device__
__forceinline__
float
vec_dot_iq3_s_q8_1
(
const
void
*
__restrict__
vbq
,
const
block_q8_1
*
__restrict__
bq8_1
,
const
int
&
iqs
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
const
block_iq3_s
*
bq2
=
(
const
block_iq3_s
*
)
vbq
;
const
int
ib32
=
iqs
;
...
...
@@ -1671,7 +1671,7 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
static
__device__
__forceinline__
float
vec_dot_iq1_s_q8_1
(
const
void
*
__restrict__
vbq
,
const
block_q8_1
*
__restrict__
bq8_1
,
const
int
&
iqs
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
const
block_iq1_s
*
bq1
=
(
const
block_iq1_s
*
)
vbq
;
const
int
qs_packed
=
get_int_b2
(
bq1
->
qs
,
iqs
);
...
...
@@ -1703,7 +1703,7 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
static
__device__
__forceinline__
float
vec_dot_iq1_m_q8_1
(
const
void
*
__restrict__
vbq
,
const
block_q8_1
*
__restrict__
bq8_1
,
const
int
&
iqs
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
const
block_iq1_m
*
bq1
=
(
const
block_iq1_m
*
)
vbq
;
...
...
@@ -1763,7 +1763,7 @@ static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4
static
__device__
__forceinline__
float
vec_dot_iq4_nl_q8_1
(
const
void
*
__restrict__
vbq
,
const
block_q8_1
*
__restrict__
bq8_1
,
const
int
&
iqs
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
const
block_iq4_nl
*
bq
=
(
const
block_iq4_nl
*
)
vbq
;
...
...
@@ -1788,7 +1788,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
static
__device__
__forceinline__
float
vec_dot_iq4_xs_q8_1
(
const
void
*
__restrict__
vbq
,
const
block_q8_1
*
__restrict__
bq8_1
,
const
int
&
iqs
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
|| defined USE_ROCM
const
block_iq4_xs
*
bq4
=
(
const
block_iq4_xs
*
)
vbq
;
const
uint8_t
*
values
=
(
const
uint8_t
*
)
kvalues_iq4nl
;
...
...
csrc/torch_bindings.cpp
View file @
7c25fe45
...
...
@@ -258,6 +258,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
"SymInt size_n, int num_bits) -> Tensor"
);
// conditionally compiled so impl registrations are in source file
#endif
// Dequantization for GGML.
ops
.
def
(
"ggml_dequantize(Tensor W, int type, SymInt m, SymInt n) -> Tensor"
);
...
...
@@ -274,6 +275,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor"
);
ops
.
impl
(
"ggml_mul_mat_a8"
,
torch
::
kCUDA
,
&
ggml_mul_mat_a8
);
#ifndef USE_ROCM
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops
.
def
(
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
...
...
vllm/_custom_ops.py
View file @
7c25fe45
...
...
@@ -344,31 +344,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
is_zp_float
:
bool
=
False
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
,
size_n
),
device
=
a
.
device
,
dtype
=
a
.
dtype
)
@
register_fake
(
"_C::ggml_dequantize"
)
def
_ggml_dequantize_fake
(
W
:
torch
.
Tensor
,
quant_type
:
int
,
m
:
torch
.
SymInt
,
n
:
torch
.
SymInt
)
->
torch
.
Tensor
:
return
torch
.
empty
((
m
,
n
),
dtype
=
torch
.
float16
,
device
=
W
.
device
)
@
register_fake
(
"_C::ggml_mul_mat_vec_a8"
)
def
_ggml_mul_mat_vec_a8_fake
(
W
:
torch
.
Tensor
,
X
:
torch
.
Tensor
,
quant_type
:
int
,
row
:
torch
.
SymInt
,
)
->
torch
.
Tensor
:
return
torch
.
empty
((
1
,
row
),
dtype
=
torch
.
float16
,
device
=
W
.
device
)
@
register_fake
(
"_C::ggml_mul_mat_a8"
)
def
_ggml_mul_mat_a8_fake
(
W
:
torch
.
Tensor
,
X
:
torch
.
Tensor
,
quant_type
:
int
,
row
:
torch
.
SymInt
,
)
->
torch
.
Tensor
:
batch
=
X
.
size
(
0
)
return
torch
.
empty
((
batch
,
row
),
dtype
=
torch
.
float16
,
device
=
W
.
device
)
@
register_fake
(
"_C::marlin_qqq_gemm"
)
def
_marlin_qqq_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
s_tok
:
torch
.
Tensor
,
s_ch
:
torch
.
Tensor
,
...
...
@@ -468,6 +443,34 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
memory_format
=
torch
.
contiguous_format
)
if
hasattr
(
torch
.
ops
.
_C
,
"ggml_dequantize"
):
@
register_fake
(
"_C::ggml_dequantize"
)
def
_ggml_dequantize_fake
(
W
:
torch
.
Tensor
,
quant_type
:
int
,
m
:
torch
.
SymInt
,
n
:
torch
.
SymInt
)
->
torch
.
Tensor
:
return
torch
.
empty
((
m
,
n
),
dtype
=
torch
.
float16
,
device
=
W
.
device
)
@
register_fake
(
"_C::ggml_mul_mat_vec_a8"
)
def
_ggml_mul_mat_vec_a8_fake
(
W
:
torch
.
Tensor
,
X
:
torch
.
Tensor
,
quant_type
:
int
,
row
:
torch
.
SymInt
,
)
->
torch
.
Tensor
:
return
torch
.
empty
((
1
,
row
),
dtype
=
torch
.
float16
,
device
=
W
.
device
)
@
register_fake
(
"_C::ggml_mul_mat_a8"
)
def
_ggml_mul_mat_a8_fake
(
W
:
torch
.
Tensor
,
X
:
torch
.
Tensor
,
quant_type
:
int
,
row
:
torch
.
SymInt
,
)
->
torch
.
Tensor
:
batch
=
X
.
size
(
0
)
return
torch
.
empty
((
batch
,
row
),
dtype
=
torch
.
float16
,
device
=
W
.
device
)
# cutlass
def
cutlass_scaled_mm_supports_fp8
(
cuda_device_capability
:
int
)
->
bool
:
return
torch
.
ops
.
_C
.
cutlass_scaled_mm_supports_fp8
(
cuda_device_capability
)
...
...
vllm/config.py
View file @
7c25fe45
...
...
@@ -387,7 +387,7 @@ class ModelConfig:
supported_quantization
=
QUANTIZATION_METHODS
rocm_supported_quantization
=
[
"awq"
,
"gptq"
,
"fp8"
,
"compressed_tensors"
,
"compressed-tensors"
,
"fbgemm_fp8"
"fbgemm_fp8"
,
"gguf"
]
optimized_quantization_methods
=
[
"fp8"
,
"marlin"
,
"modelopt"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment