change / sglang · Commits

Commit 8fdcd98e (unverified)
Authored Oct 12, 2025 by PGFLMG; committed by GitHub Oct 11, 2025
Parent: b5dcfd41

[7/n] decouple quantization impl from vllm dependency - gguf kernel (#11019)

Showing 19 changed files with 7936 additions and 1 deletion (+7936, -1)
  sgl-kernel/CMakeLists.txt                               +3     -0
  sgl-kernel/csrc/common_extension.cc                     +37    -0
  sgl-kernel/csrc/moe/moe_sum.cu                          +66    -0
  sgl-kernel/csrc/quantization/gguf/dequantize.cuh        +583   -0
  sgl-kernel/csrc/quantization/gguf/ggml-common.h         +1029  -0
  sgl-kernel/csrc/quantization/gguf/gguf_kernel.cu        +836   -0
  sgl-kernel/csrc/quantization/gguf/mmq.cuh               +881   -0
  sgl-kernel/csrc/quantization/gguf/mmvq.cuh              +352   -0
  sgl-kernel/csrc/quantization/gguf/moe.cuh               +1379  -0
  sgl-kernel/csrc/quantization/gguf/moe_vec.cuh           +413   -0
  sgl-kernel/csrc/quantization/gguf/vecdotq.cuh           +2037  -0
  sgl-kernel/include/sgl_kernel_ops.h                     +28    -0
  sgl-kernel/include/utils.h                              +20    -0
  sgl-kernel/python/sgl_kernel/__init__.py                +9     -0
  sgl-kernel/python/sgl_kernel/moe.py                     +10    -0
  sgl-kernel/python/sgl_kernel/quantization/__init__.py   +8     -0
  sgl-kernel/python/sgl_kernel/quantization/gguf.py       +62    -0
  sgl-kernel/tests/test_gguf.py                           +160   -0
  sgl-kernel/tests/test_moe_align.py                      +23    -1
sgl-kernel/CMakeLists.txt

@@ -271,6 +271,8 @@ set(SOURCES
     "csrc/elementwise/topk.cu"
     "csrc/common_extension.cc"
+    "csrc/quantization/gguf/gguf_kernel.cu"
     "csrc/gemm/awq_kernel.cu"
     "csrc/gemm/bmm_fp8.cu"
     "csrc/gemm/dsv3_fused_a_gemm.cu"
     ...
@@ -306,6 +308,7 @@ set(SOURCES
     "csrc/moe/marlin_moe_wna16/ops.cu"
     "csrc/moe/moe_align_kernel.cu"
     "csrc/moe/moe_fused_gate.cu"
+    "csrc/moe/moe_sum.cu"
     "csrc/moe/moe_sum_reduce.cu"
     "csrc/moe/moe_topk_softmax_kernels.cu"
     "csrc/moe/nvfp4_blockwise_moe.cu"
     ...
sgl-kernel/csrc/common_extension.cc

@@ -114,6 +114,37 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
      "cu_seqlens_q) -> ()");
  m.impl("fast_topk_transform_fused", torch::kCUDA, &fast_topk_transform_interface);

  /*
   * From gguf quantization
   */
  m.def("ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? dtype) -> Tensor");
  m.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);

  m.def("ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
  m.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);

  m.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
  m.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);

  m.def(
      "ggml_moe_a8(Tensor X, Tensor W, "
      "Tensor sorted_token_ids, Tensor expert_ids, Tensor num_tokens_post_padded, "
      "int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
  m.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);

  m.def(
      "ggml_moe_a8_vec(Tensor X, Tensor W, "
      "Tensor topk_ids, int top_k, "
      "int type, SymInt row, SymInt tokens) -> Tensor");
  m.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec);

  m.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);

  /*
   * From csrc/gemm
   */
  ...

@@ -226,17 +257,23 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  m.def("moe_sum_reduce(Tensor input, Tensor output, float routed_scaling_factor) -> ()");
  m.impl("moe_sum_reduce", torch::kCUDA, &moe_sum_reduce);

  m.def("moe_sum(Tensor input, Tensor! output) -> ()");
  m.impl("moe_sum", torch::kCUDA, &moe_sum);

  m.def(
      "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
      "num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> "
      "(Tensor[])");
  m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);

  m.def(
      "fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor "
      "a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
      "stride_a, Tensor stride_b, Tensor stride_c, Tensor layout_sfa, Tensor layout_sfb, Tensor problem_sizes, Tensor "
      "expert_offsets, Tensor workspace) -> ()");
  m.impl("fp8_blockwise_scaled_grouped_mm", torch::kCUDA, &fp8_blockwise_scaled_grouped_mm);

  m.def(
      "prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor? blockscale_offsets, Tensor problem_sizes1,"
      " Tensor problem_sizes2, Tensor input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> "
      ...
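The schemas above are enough to sketch how the new ops surface in Python once the extension is built. This is a minimal sketch, not taken from the diff: the import path, the packed-byte layout expected for W, and the uint8 tensor shape are all assumptions; the diff's python/sgl_kernel/quantization/gguf.py presumably adds friendlier wrappers.

  import torch
  import sgl_kernel  # assumed import; loads the sgl_kernel library fragment

  Q4_0 = 2         # ggml type id, matching the switch in dequantize.cuh below
  m, n = 128, 256  # logical weight shape

  # block_q4_0 packs 32 weights into 18 bytes (fp16 scale + 16 nibble bytes);
  # the exact tensor layout ggml_dequantize expects is assumed here.
  packed = torch.empty(m * (n // 32) * 18, dtype=torch.uint8, device="cuda")

  W = torch.ops.sgl_kernel.ggml_dequantize(packed, Q4_0, m, n, torch.float16)
  print(W.shape)  # expected: torch.Size([128, 256])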
sgl-kernel/csrc/moe/moe_sum.cu (new file, mode 100644)

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include <ATen/cuda/Atomic.cuh>
#include <cub/cub.cuh>

#include "utils.h"

template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., topk, d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    scalar_t x = 0.0;
#pragma unroll
    for (int k = 0; k < TOPK; ++k) {
      x += SGLANG_LDG(&input[token_idx * TOPK * d + k * d + idx]);
    }
    out[token_idx * d + idx] = x;
  }
}

void moe_sum(
    torch::Tensor& input,    // [num_tokens, topk, hidden_size]
    torch::Tensor& output) { // [num_tokens, hidden_size]
  const int hidden_size = input.size(-1);
  const auto num_tokens = output.numel() / hidden_size;
  const int topk = input.size(1);

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  switch (topk) {
    case 2:
      DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        moe_sum_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
      });
      break;
    case 3:
      DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        moe_sum_kernel<scalar_t, 3><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
      });
      break;
    case 4:
      DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        moe_sum_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
      });
      break;
    default:
      at::sum_out(output, input, 1);
      break;
  }
}
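Because the default branch falls back to at::sum_out over dim 1, the fused TOPK kernels can be checked against a plain PyTorch reduction. A minimal sketch, assuming the op is reachable as torch.ops.sgl_kernel.moe_sum once sgl_kernel is installed (invocation path assumed, not taken from the diff):

  import torch
  import sgl_kernel  # assumed to register torch.ops.sgl_kernel.moe_sum

  num_tokens, topk, hidden = 64, 2, 4096  # topk in {2, 3, 4} hits the fused kernel
  x = torch.randn(num_tokens, topk, hidden, device="cuda", dtype=torch.float16)
  out = torch.empty(num_tokens, hidden, device="cuda", dtype=torch.float16)

  torch.ops.sgl_kernel.moe_sum(x, out)
  torch.testing.assert_close(out, x.sum(dim=1))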
sgl-kernel/csrc/quantization/gguf/dequantize.cuh (new file, mode 100644)

// copied from
// https://github.com/vllm-project/vllm/blob/4492e3a55428e161ca8db381edc28263e5da4c8d/csrc/quantization/gguf/dequantize.cuh
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/convert.cu
// Dequant functions
static __device__ __forceinline__ void dequantize_q4_0(const void* vx, const int ib, const int iqs, dfloat2& v) {
  const block_q4_0* x = (const block_q4_0*)vx;

  const dfloat d = x[ib].d;
  const int vui = x[ib].qs[iqs];

  v.x = __int2half_rn(vui & 0xF);
  v.y = __int2half_rn(vui >> 4);

  v = __hsub2(v, __floats2half2_rn(8.0f, 8.0f));
  v = __hmul2(v, {d, d});
}

static __device__ __forceinline__ void dequantize_q4_1(const void* vx, const int ib, const int iqs, dfloat2& v) {
  const block_q4_1* x = (const block_q4_1*)vx;

  const dfloat d = __low2half(x[ib].dm);
  const dfloat m = __high2half(x[ib].dm);
  const int vui = x[ib].qs[iqs];

  v.x = __int2half_rn(vui & 0xF);
  v.y = __int2half_rn(vui >> 4);

  v = __hmul2(v, {d, d});
  v = __hadd2(v, {m, m});
}

static __device__ __forceinline__ void dequantize_q5_0(const void* vx, const int ib, const int iqs, dfloat2& v) {
  const block_q5_0* x = (const block_q5_0*)vx;

  const dfloat d = x[ib].d;

  uint32_t qh;
  memcpy(&qh, x[ib].qh, sizeof(qh));

  const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
  const int xh_1 = ((qh >> (iqs + 12))) & 0x10;

  v.x = __int2half_rn((x[ib].qs[iqs] & 0xf) | xh_0);
  v.y = __int2half_rn((x[ib].qs[iqs] >> 4) | xh_1);

  v = __hsub2(v, __floats2half2_rn(16.0f, 16.0f));
  v = __hmul2(v, {d, d});
}

static __device__ __forceinline__ void dequantize_q5_1(const void* vx, const int ib, const int iqs, dfloat2& v) {
  const block_q5_1* x = (const block_q5_1*)vx;

  const dfloat d = __low2half(x[ib].dm);
  const dfloat m = __high2half(x[ib].dm);

  uint32_t qh;
  memcpy(&qh, x[ib].qh, sizeof(qh));

  const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
  const int xh_1 = ((qh >> (iqs + 12))) & 0x10;

  v.x = __int2half_rn((x[ib].qs[iqs] & 0xf) | xh_0);
  v.y = __int2half_rn((x[ib].qs[iqs] >> 4) | xh_1);

  v = __hmul2(v, {d, d});
  v = __hadd2(v, {m, m});
}

static __device__ __forceinline__ void dequantize_q8_0(const void* vx, const int ib, const int iqs, dfloat2& v) {
  const block_q8_0* x = (const block_q8_0*)vx;

  const dfloat d = x[ib].d;

  v.x = __int2half_rn(x[ib].qs[iqs + 0]);
  v.y = __int2half_rn(x[ib].qs[iqs + 1]);

  v = __hmul2(v, {d, d});
}
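For reference, a q4_0 block packs 32 weights as one fp16 scale d plus 16 nibble bytes, and each weight decodes to d * (nibble - 8), the same arithmetic the __hsub2/__hmul2 pair above performs two values at a time. A minimal host-side sketch (not part of the diff; d stands in for the block's fp16 scale, and the output ordering mirrors the y_offset = qk/2 write pattern of dequantize_block below):

  import numpy as np

  def dequant_q4_0_block(d: float, qs: bytes) -> np.ndarray:
      """Reference decode of one 32-weight q4_0 block: w = d * (nibble - 8)."""
      assert len(qs) == 16
      out = np.empty(32, dtype=np.float32)
      for j, byte in enumerate(qs):
          out[j] = d * ((byte & 0xF) - 8)        # low nibbles fill slots 0..15
          out[j + 16] = d * ((byte >> 4) - 8)    # high nibbles fill slots 16..31
      return out

  # a block whose nibbles are all 0x9 decodes to d * (9 - 8) = d everywhere
  print(dequant_q4_0_block(0.5, bytes([0x99] * 16)))  # -> 32 values of 0.5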
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void dequantize_block(const void* __restrict__ vx, dst_t* __restrict__ y, const int k) {
  const int i = 2 * (blockDim.x * blockIdx.x + threadIdx.x);

  if (i >= k) {
    return;
  }

  const int ib = i / qk;          // block index
  const int iqs = (i % qk) / qr;  // quant index
  const int iybs = i - i % qk;    // y block start index
  const int y_offset = qr == 1 ? 1 : qk / 2;

  // dequantize
  dfloat2 v;
  dequantize_kernel(vx, ib, iqs, v);
  y[iybs + iqs + 0] = convert_from_half<dst_t>(v.x);
  y[iybs + iqs + y_offset] = convert_from_half<dst_t>(v.y);
}
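Each thread of dequantize_block handles one quant byte, i.e. two output values. A quick check of that index arithmetic for q4_0 geometry, mirroring the kernel above (illustrative values only):

  # Index arithmetic of dequantize_block for q4_0 (qk = 32, qr = 2), checked
  # for the even pair index i = 70.
  qk, qr, i = 32, 2, 70
  ib = i // qk                          # quantized block index      -> 2
  iqs = (i % qk) // qr                  # byte index inside the block -> 3
  iybs = i - i % qk                     # first output slot of block -> 64
  y_offset = 1 if qr == 1 else qk // 2  #                            -> 16
  assert (ib, iqs, iybs, y_offset) == (2, 3, 64, 16)
  # so this thread writes y[67] (low nibble) and y[83] (high nibble) of block 2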
template <typename dst_t>
static __global__ void dequantize_block_q2_K(const void* __restrict__ vx, dst_t* __restrict__ yy) {
  const auto i = blockIdx.x;
  const block_q2_K* x = (const block_q2_K*)vx;

  const auto tid = threadIdx.x;
  const int n = tid / 32;
  const int l = tid - 32 * n;
  const int is = 8 * n + l / 16;

  const uint8_t q = x[i].qs[32 * n + l];
  dst_t* y = yy + i * QK_K + 128 * n;

  half dall = __low2half(x[i].dm);
  half dmin = __high2half(x[i].dm);
  y[l + 0] = convert_from_half<dst_t>(__hsub(
      __hmul(dall, __int2half_rn((x[i].scales[is + 0] & 0xF) * ((q >> 0) & 3))),
      __hmul(dmin, __int2half_rn(x[i].scales[is + 0] >> 4))));
  y[l + 32] = convert_from_half<dst_t>(__hsub(
      __hmul(dall, __int2half_rn((x[i].scales[is + 2] & 0xF) * ((q >> 2) & 3))),
      __hmul(dmin, __int2half_rn(x[i].scales[is + 2] >> 4))));
  y[l + 64] = convert_from_half<dst_t>(__hsub(
      __hmul(dall, __int2half_rn((x[i].scales[is + 4] & 0xF) * ((q >> 4) & 3))),
      __hmul(dmin, __int2half_rn(x[i].scales[is + 4] >> 4))));
  y[l + 96] = convert_from_half<dst_t>(__hsub(
      __hmul(dall, __int2half_rn((x[i].scales[is + 6] & 0xF) * ((q >> 6) & 3))),
      __hmul(dmin, __int2half_rn(x[i].scales[is + 6] >> 4))));
}

template <typename dst_t>
static __global__ void dequantize_block_q3_K(const void* __restrict__ vx, dst_t* __restrict__ yy) {
  const auto i = blockIdx.x;
  const block_q3_K* x = (const block_q3_K*)vx;

  const auto r = threadIdx.x / 4;
  const int tid = r / 2;
  const int is0 = r % 2;
  const int l0 = 16 * is0 + 4 * (threadIdx.x % 4);
  const int n = tid / 4;
  const int j = tid - 4 * n;

  uint8_t m = 1 << (4 * n + j);
  int is = 8 * n + 2 * j + is0;
  int shift = 2 * j;

  int8_t us = is < 4    ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 8] >> 0) & 3) << 4)
              : is < 8  ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 4] >> 2) & 3) << 4)
              : is < 12 ? (x[i].scales[is - 8] >> 4) | (((x[i].scales[is + 0] >> 4) & 3) << 4)
                        : (x[i].scales[is - 8] >> 4) | (((x[i].scales[is - 4] >> 6) & 3) << 4);
  half d_all = x[i].d;
  half dl = __hmul(d_all, __int2half_rn(us - 32));

  dst_t* y = yy + i * QK_K + 128 * n + 32 * j;
  const uint8_t* q = x[i].qs + 32 * n;
  const uint8_t* hm = x[i].hmask;

  for (int l = l0; l < l0 + 4; ++l) {
    y[l] = convert_from_half<dst_t>(
        __hmul(dl, __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4))));
  }
}
static inline __device__ void get_scale_min_k4(int j, const uint8_t* q, uint8_t& d, uint8_t& m) {
  if (j < 4) {
    d = q[j] & 63;
    m = q[j + 4] & 63;
  } else {
    d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
    m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
  }
}
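The 12-byte K-quant scales array packs eight 6-bit (scale, min) pairs: pairs 0 to 3 sit in the low 6 bits of bytes 0 to 7, while pairs 4 to 7 split their low nibbles into bytes 8 to 11 and their top 2 bits into the high bits of bytes 0 to 7, which is what the two branches above unpack. A small round-trip check of the j >= 4 branch (illustrative values only, not from the diff):

  # Pack and unpack one (scale, min) pair for j = 4, which uses q[8], q[0], q[4].
  q = [0] * 12
  scale, mn = 42, 19                        # 6-bit values to hide in the layout
  q[8] = (scale & 0xF) | ((mn & 0xF) << 4)  # low nibbles share byte 8
  q[0] |= (scale >> 4) << 6                 # top 2 scale bits ride on byte 0
  q[4] |= (mn >> 4) << 6                    # top 2 min bits ride on byte 4

  d = (q[8] & 0xF) | ((q[0] >> 6) << 4)     # the j >= 4 branch above
  m = (q[8] >> 4) | ((q[4] >> 6) << 4)
  assert (d, m) == (42, 19)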
template <typename dst_t>
static __global__ void dequantize_block_q4_K(const void* __restrict__ vx, dst_t* __restrict__ yy) {
  const block_q4_K* x = (const block_q4_K*)vx;

  const auto i = blockIdx.x;

  // assume 32 threads
  const auto tid = threadIdx.x;
  const int il = tid / 8;
  const int ir = tid % 8;
  const int is = 2 * il;
  const int n = 4;

  dst_t* y = yy + i * QK_K + 64 * il + n * ir;

  const half dall = __low2half(x[i].dm);
  const half dmin = __high2half(x[i].dm);

  const uint8_t* q = x[i].qs + 32 * il + n * ir;

  uint8_t sc, m;
  get_scale_min_k4(is + 0, x[i].scales, sc, m);
  const half d1 = __hmul(dall, __int2half_rn(sc));
  const half m1 = __hmul(dmin, __int2half_rn(m));
  get_scale_min_k4(is + 1, x[i].scales, sc, m);
  const half d2 = __hmul(dall, __int2half_rn(sc));
  const half m2 = __hmul(dmin, __int2half_rn(m));
  for (int l = 0; l < n; ++l) {
    y[l + 0] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1));
    y[l + 32] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn(q[l] >> 4)), m2));
  }
}

template <typename dst_t>
static __global__ void dequantize_block_q5_K(const void* __restrict__ vx, dst_t* __restrict__ yy) {
  const block_q5_K* x = (const block_q5_K*)vx;

  const auto i = blockIdx.x;

  // assume 64 threads - this is very slightly better than the one below
  const auto tid = threadIdx.x;
  const int il = tid / 16;  // il is in 0...3
  const int ir = tid % 16;  // ir is in 0...15
  const int is = 2 * il;    // is is in 0...6

  dst_t* y = yy + i * QK_K + 64 * il + 2 * ir;

  const half dall = __low2half(x[i].dm);
  const half dmin = __high2half(x[i].dm);

  const uint8_t* ql = x[i].qs + 32 * il + 2 * ir;
  const uint8_t* qh = x[i].qh + 2 * ir;

  uint8_t sc, m;
  get_scale_min_k4(is + 0, x[i].scales, sc, m);
  const half d1 = __hmul(dall, __int2half_rn(sc));
  const half m1 = __hmul(dmin, __int2half_rn(m));
  get_scale_min_k4(is + 1, x[i].scales, sc, m);
  const half d2 = __hmul(dall, __int2half_rn(sc));
  const half m2 = __hmul(dmin, __int2half_rn(m));

  uint8_t hm = 1 << (2 * il);
  y[0] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0))), m1));
  y[1] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0))), m1));
  hm <<= 1;
  y[32] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn((ql[0] >> 4) + (qh[0] & hm ? 16 : 0))), m2));
  y[33] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn((ql[1] >> 4) + (qh[1] & hm ? 16 : 0))), m2));
}
template <typename dst_t>
static __global__ void dequantize_block_q6_K(const void* __restrict__ vx, dst_t* __restrict__ yy) {
  const block_q6_K* x = (const block_q6_K*)vx;

  const auto i = blockIdx.x;

  // assume 64 threads - this is very slightly better than the one below
  const auto tid = threadIdx.x;
  const int ip = tid / 32;       // ip is 0 or 1
  const int il = tid - 32 * ip;  // 0...32
  const int is = 8 * ip + il / 16;

  dst_t* y = yy + i * QK_K + 128 * ip + il;

  const half d = x[i].d;

  const uint8_t* ql = x[i].ql + 64 * ip + il;
  const uint8_t qh = x[i].qh[32 * ip + il];
  const int8_t* sc = x[i].scales + is;

  y[0] = convert_from_half<dst_t>(
      __hmul(d, __int2half_rn(sc[0] * ((int8_t)((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))));
  y[32] = convert_from_half<dst_t>(
      __hmul(d, __int2half_rn(sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))));
  y[64] = convert_from_half<dst_t>(
      __hmul(d, __int2half_rn(sc[4] * ((int8_t)((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))));
  y[96] = convert_from_half<dst_t>(
      __hmul(d, __int2half_rn(sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))));
}
template <typename dst_t>
static __global__ void dequantize_block_iq2_xxs(const void* __restrict__ vx, dst_t* __restrict__ yy) {
  const auto i = blockIdx.x;
  const block_iq2_xxs* x = (const block_iq2_xxs*)vx;

  const auto tid = threadIdx.x;
  const int il = tid / 8;  // 0...3
  const int ib = tid % 8;  // 0...7
  dst_t* y = yy + i * QK_K + 32 * ib + 8 * il;
  const uint16_t* q2 = x[i].qs + 4 * ib;
  const uint8_t* aux8 = (const uint8_t*)q2;
  const uint8_t* grid = (const uint8_t*)(iq2xxs_grid + aux8[il]);
  const uint32_t aux32 = q2[2] | (q2[3] << 16);
  const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.25f;
  const uint8_t signs = ksigns_iq2xs[(aux32 >> 7 * il) & 127];
  for (int j = 0; j < 8; ++j)
    y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}

template <typename dst_t>
static __global__ void dequantize_block_iq2_xs(const void* __restrict__ vx, dst_t* __restrict__ yy) {
  const auto i = blockIdx.x;
  const block_iq2_xs* x = (const block_iq2_xs*)vx;

  const auto tid = threadIdx.x;
  const int il = tid / 8;  // 0...3
  const int ib = tid % 8;  // 0...7
  dst_t* y = yy + i * QK_K + 32 * ib + 8 * il;
  const uint16_t* q2 = x[i].qs + 4 * ib;
  const uint8_t* grid = (const uint8_t*)(iq2xs_grid + (q2[il] & 511));
  const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4 * (il / 2)) & 0xf)) * 0.25f;
  const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
  for (int j = 0; j < 8; ++j)
    y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}

template <typename dst_t>
static __global__ void dequantize_block_iq2_s(const void* __restrict__ vx, dst_t* __restrict__ yy) {
  const auto i = blockIdx.x;
  const block_iq2_s* x = (const block_iq2_s*)vx;

  const auto tid = threadIdx.x;
  const int il = tid / 8;  // 0...3
  const int ib = tid % 8;  // 0...7
  dst_t* y = yy + i * QK_K + 32 * ib + 8 * il;
  const uint8_t* grid =
      (const uint8_t*)(iq2s_grid + (x[i].qs[4 * ib + il] | ((x[i].qh[ib] << (8 - 2 * il)) & 0x300)));
  const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4 * (il / 2)) & 0xf)) * 0.25f;
  const uint8_t signs = x[i].qs[QK_K / 8 + 4 * ib + il];
  for (int j = 0; j < 8; ++j)
    y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
template <typename dst_t>
static __global__ void dequantize_block_iq3_xxs(const void* __restrict__ vx, dst_t* __restrict__ yy) {
  const auto i = blockIdx.x;
  const block_iq3_xxs* x = (const block_iq3_xxs*)vx;

  const auto tid = threadIdx.x;
  const int il = tid / 8;  // 0...3
  const int ib = tid % 8;  // 0...7
  dst_t* y = yy + i * QK_K + 32 * ib + 8 * il;
  const uint8_t* q3 = x[i].qs + 8 * ib;
  const uint16_t* gas = (const uint16_t*)(x[i].qs + QK_K / 4) + 2 * ib;
  const uint8_t* grid1 = (const uint8_t*)(iq3xxs_grid + q3[2 * il + 0]);
  const uint8_t* grid2 = (const uint8_t*)(iq3xxs_grid + q3[2 * il + 1]);
  const uint32_t aux32 = gas[0] | (gas[1] << 16);
  const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.5f;
  const uint8_t signs = ksigns_iq2xs[(aux32 >> 7 * il) & 127];
  for (int j = 0; j < 4; ++j) {
    y[j + 0] = d * grid1[j] * (signs & kmask_iq2xs[j + 0] ? -1.f : 1.f);
    y[j + 4] = d * grid2[j] * (signs & kmask_iq2xs[j + 4] ? -1.f : 1.f);
  }
}

template <typename dst_t>
static __global__ void dequantize_block_iq3_s(const void* __restrict__ vx, dst_t* __restrict__ yy) {
  const auto i = blockIdx.x;
  const block_iq3_s* x = (const block_iq3_s*)vx;

  const auto tid = threadIdx.x;
  const int il = tid / 8;  // 0...3
  const int ib = tid % 8;  // 0...7
  dst_t* y = yy + i * QK_K + 32 * ib + 8 * il;
  const uint8_t* qs = x[i].qs + 8 * ib;
  const uint8_t* grid1 =
      (const uint8_t*)(iq3xs_grid + (qs[2 * il + 0] | ((x[i].qh[ib] << (8 - 2 * il)) & 256)));
  const uint8_t* grid2 =
      (const uint8_t*)(iq3xs_grid + (qs[2 * il + 1] | ((x[i].qh[ib] << (7 - 2 * il)) & 256)));
  const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib / 2] >> 4 * (ib % 2)) & 0xf)) * 0.5f;
  const uint8_t signs = x[i].signs[4 * ib + il];
  for (int j = 0; j < 4; ++j) {
    y[j + 0] = d * grid1[j] * (signs & kmask_iq2xs[j + 0] ? -1.f : 1.f);
    y[j + 4] = d * grid2[j] * (signs & kmask_iq2xs[j + 4] ? -1.f : 1.f);
  }
}
template <typename dst_t>
static __global__ void dequantize_block_iq1_s(const void* __restrict__ vx, dst_t* __restrict__ yy) {
  const int64_t i = blockIdx.x;
  const block_iq1_s* x = (const block_iq1_s*)vx;

  const int64_t tid = threadIdx.x;
  const int64_t il = tid / 8;  // 0...3
  const int64_t ib = tid % 8;  // 0...7
  dst_t* y = yy + i * QK_K + 32 * ib + 8 * il;
  const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
  const float d = __half2float(x[i].d) * (2 * ((x[i].qh[ib] >> 12) & 7) + 1);
  uint32_t grid32[2];
  const int8_t* q = (const int8_t*)grid32;
  grid32[0] = iq1s_grid_gpu[x[i].qs[4 * ib + il] | (((x[i].qh[ib] >> 3 * il) & 7) << 8)];
  grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
  grid32[0] &= 0x0f0f0f0f;
  for (int j = 0; j < 8; ++j) {
    y[j] = d * (q[j] + delta);
  }
}

template <typename dst_t>
static __global__ void dequantize_block_iq1_m(const void* __restrict__ vx, dst_t* __restrict__ yy) {
  const int64_t i = blockIdx.x;
  const block_iq1_m* x = (const block_iq1_m*)vx;

  const int64_t tid = threadIdx.x;
  const int64_t il = tid / 8;  // 0...3
  const int64_t ib = tid % 8;  // 0...7
  dst_t* y = yy + i * QK_K + 32 * ib + 8 * il;
  const uint16_t* sc = (const uint16_t*)x[i].scales;
  iq1m_scale_t scale;
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
  const int64_t ib16 = 2 * ib + il / 2;  // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
  const float d = __half2float(scale.f16) * (2 * ((sc[ib16 / 4] >> 3 * (ib16 % 4)) & 0x7) + 1);
  const float delta =
      x[i].qh[2 * ib + il / 2] & (0x08 << 4 * (il % 2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
  uint32_t grid32[2];
  const int8_t* q = (const int8_t*)grid32;
  grid32[0] =
      iq1s_grid_gpu[x[i].qs[4 * ib + il] | (((x[i].qh[2 * ib + il / 2] >> 4 * (il % 2)) & 7) << 8)];
  grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
  grid32[0] &= 0x0f0f0f0f;
  for (int j = 0; j < 8; ++j) {
    y[j] = d * (q[j] + delta);
  }
}
template <typename dst_t>
static __global__ void dequantize_block_iq4_nl(const void* __restrict__ vx, dst_t* __restrict__ yy) {
  const auto i = blockIdx.x;
  const block_iq4_nl* x = (const block_iq4_nl*)vx + i * (QK_K / QK4_NL);

  const auto tid = threadIdx.x;
  const int il = tid / 8;  // 0...3
  const int ib = tid % 8;  // 0...7
  dst_t* y = yy + i * QK_K + 32 * ib + 4 * il;
  const uint8_t* q4 = x[ib].qs + 4 * il;
  const float d = __half2float(x[ib].d);
  for (int j = 0; j < 4; ++j) {
    y[j + 0] = d * kvalues_iq4nl[q4[j] & 0xf];
    y[j + 16] = d * kvalues_iq4nl[q4[j] >> 4];
  }
}

template <typename dst_t>
static __global__ void dequantize_block_iq4_xs(const void* __restrict__ vx, dst_t* __restrict__ yy) {
  const auto i = blockIdx.x;
  const block_iq4_xs* x = (const block_iq4_xs*)vx;

  const auto tid = threadIdx.x;
  const int il = tid / 8;  // 0...3
  const int ib = tid % 8;  // 0...7
  dst_t* y = yy + i * QK_K + 32 * ib + 4 * il;
  const uint8_t* q4 = x[i].qs + 16 * ib + 4 * il;
  const float d =
      __half2float(x[i].d) *
      ((((x[i].scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((x[i].scales_h >> 2 * ib) & 3) << 4)) - 32);
  for (int j = 0; j < 4; ++j) {
    y[j + 0] = d * kvalues_iq4nl[q4[j] & 0xf];
    y[j + 16] = d * kvalues_iq4nl[q4[j] >> 4];
  }
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void* __restrict__ vx, dst_t* __restrict__ y, const int k, cudaStream_t stream) {
  const int num_blocks = (k + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
  dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
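Since each thread emits two values, one block of CUDA_DEQUANTIZE_BLOCK_SIZE = 256 threads covers 512 values and the launch rounds up. A quick check of that arithmetic for a representative size:

  CUDA_DEQUANTIZE_BLOCK_SIZE = 256  # from ggml-common.h below
  k = 4096                          # values to dequantize
  num_blocks = (k + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1) // (2 * CUDA_DEQUANTIZE_BLOCK_SIZE)
  assert num_blocks == 8            # 512 values per block, rounded up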
template <typename dst_t>
static void dequantize_row_q2_K_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
  const int nb = k / QK_K;
  dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
}

template <typename dst_t>
static void dequantize_row_q3_K_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
  const int nb = k / QK_K;
  dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
}

template <typename dst_t>
static void dequantize_row_q4_K_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
  const int nb = k / QK_K;
  dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
}

template <typename dst_t>
static void dequantize_row_q5_K_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
  const int nb = k / QK_K;
  dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
}

template <typename dst_t>
static void dequantize_row_q6_K_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
  const int nb = k / QK_K;
  dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
}

template <typename dst_t>
static void dequantize_row_iq2_xxs_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
  const int nb = k / QK_K;
  dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
}

template <typename dst_t>
static void dequantize_row_iq2_xs_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
  const int nb = k / QK_K;
  dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
}

template <typename dst_t>
static void dequantize_row_iq2_s_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
  const int nb = k / QK_K;
  dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
}

template <typename dst_t>
static void dequantize_row_iq3_xxs_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
  const int nb = k / QK_K;
  dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
}

template <typename dst_t>
static void dequantize_row_iq3_s_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
  const int nb = k / QK_K;
  dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
}

template <typename dst_t>
static void dequantize_row_iq1_s_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
  const int nb = k / QK_K;
  dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
}

template <typename dst_t>
static void dequantize_row_iq1_m_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
  const int nb = k / QK_K;
  dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
}

template <typename dst_t>
static void dequantize_row_iq4_nl_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
  const int nb = (k + QK_K - 1) / QK_K;
  dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
}

template <typename dst_t>
static void dequantize_row_iq4_xs_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
  const int nb = (k + QK_K - 1) / QK_K;
  dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
}
template <typename dst_t>
static to_cuda_ggml_t<dst_t> ggml_get_to_cuda(int64_t type) {
  switch (type) {
    case 2:
      return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
    case 3:
      return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
    case 6:
      return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
    case 7:
      return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
    case 8:
      return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
    case 10:
      return dequantize_row_q2_K_cuda;
    case 11:
      return dequantize_row_q3_K_cuda;
    case 12:
      return dequantize_row_q4_K_cuda;
    case 13:
      return dequantize_row_q5_K_cuda;
    case 14:
      return dequantize_row_q6_K_cuda;
    case 16:
      return dequantize_row_iq2_xxs_cuda;
    case 17:
      return dequantize_row_iq2_xs_cuda;
    case 18:
      return dequantize_row_iq3_xxs_cuda;
    case 19:
      return dequantize_row_iq1_s_cuda;
    case 20:
      return dequantize_row_iq4_nl_cuda;
    case 21:
      return dequantize_row_iq3_s_cuda;
    case 22:
      return dequantize_row_iq2_s_cuda;
    case 23:
      return dequantize_row_iq4_xs_cuda;
    case 29:
      return dequantize_row_iq1_m_cuda;
    default:
      return nullptr;
  }
}
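The integer cases are the `type` argument of the ops registered in common_extension.cc above and mirror ggml's type enumeration. A sketch of the subset this dispatcher handles; the names follow ggml's enum and are believed correct, but they are not defined anywhere in this diff:

  # ggml type ids handled by ggml_get_to_cuda above (names assumed from ggml).
  GGML_TYPE = {
      "Q4_0": 2,  "Q4_1": 3,  "Q5_0": 6,  "Q5_1": 7,  "Q8_0": 8,
      "Q2_K": 10, "Q3_K": 11, "Q4_K": 12, "Q5_K": 13, "Q6_K": 14,
      "IQ2_XXS": 16, "IQ2_XS": 17, "IQ3_XXS": 18, "IQ1_S": 19, "IQ4_NL": 20,
      "IQ3_S": 21, "IQ2_S": 22, "IQ4_XS": 23, "IQ1_M": 29,
  }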
sgl-kernel/csrc/quantization/gguf/ggml-common.h (new file, mode 100644)

// adapted from
// https://github.com/vllm-project/vllm/blob/4492e3a55428e161ca8db381edc28263e5da4c8d/csrc/quantization/gguf/ggml-common.h
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
#define QK_K 256
#define K_QUANTS_PER_ITERATION 2
#define WARP_SIZE_GGUF 32
#define K_SCALE_SIZE 12
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
#define CUDA_QUANTIZE_BLOCK_SIZE 256
#define GGML_CUDA_DMMV_X 32
#define GGML_CUDA_MMV_Y 1

// Data Structures
// QK = number of values after dequantization
// QR = QK / number of values before dequantization
// QI = number of 32 bit integers before dequantization

#define QK4_0 32
#define QR4_0 2
#define QI4_0 (QK4_0 / (4 * QR4_0))
typedef struct {
  half d;                 // delta
  uint8_t qs[QK4_0 / 2];  // nibbles / quants
} block_q4_0;

#define QK4_1 32
#define QR4_1 2
#define QI4_1 (QK4_1 / (4 * QR4_1))
typedef struct {
  half2 dm;               // dm.x = delta, dm.y = min
  uint8_t qs[QK4_1 / 2];  // nibbles / quants
} block_q4_1;

#define QK5_0 32
#define QR5_0 2
#define QI5_0 (QK5_0 / (4 * QR5_0))
typedef struct {
  half d;                 // delta
  uint8_t qh[4];          // 5-th bit of quants
  uint8_t qs[QK5_0 / 2];  // nibbles / quants
} block_q5_0;

#define QK5_1 32
#define QR5_1 2
#define QI5_1 (QK5_1 / (4 * QR5_1))
typedef struct {
  half2 dm;               // dm.x = delta, dm.y = min
  uint8_t qh[4];          // 5-th bit of quants
  uint8_t qs[QK5_1 / 2];  // nibbles / quants
} block_q5_1;

#define QK8_0 32
#define QR8_0 1
#define QI8_0 (QK8_0 / (4 * QR8_0))
typedef struct {
  half d;            // delta
  int8_t qs[QK8_0];  // quants
} block_q8_0;

#define QK8_1 32
#define QR8_1 1
#define QI8_1 (QK8_1 / (4 * QR8_1))
typedef struct {
  half2 ds;          // ds.x = delta, ds.y = sum
  int8_t qs[QK8_0];  // quants
} block_q8_1;

#define QR2_K 4
#define QI2_K (QK_K / (4 * QR2_K))
typedef struct {
  uint8_t scales[QK_K / 16];  // scales and mins, quantized with 4 bits
  uint8_t qs[QK_K / 4];       // quants
  half2 dm;                   // super-block scale for quantized scales/mins
} block_q2_K;

#define QR3_K 4
#define QI3_K (QK_K / (4 * QR3_K))
typedef struct {
  uint8_t hmask[QK_K / 8];       // quants - high bit
  uint8_t qs[QK_K / 4];          // quants - low 2 bits
  uint8_t scales[K_SCALE_SIZE];  // scales, quantized with 6 bits
  half d;                        // super-block scale
} block_q3_K;

#define QR4_K 2
#define QI4_K (QK_K / (4 * QR4_K))
typedef struct {
  half2 dm;                       // super-block scale for quantized scales/mins
  uint8_t scales[3 * QK_K / 64];  // scales, quantized with 6 bits
  uint8_t qs[QK_K / 2];           // 4-bit quants
} block_q4_K;

#define QR5_K 2
#define QI5_K (QK_K / (4 * QR5_K))
typedef struct {
  half2 dm;                      // super-block scale for quantized scales/mins
  uint8_t scales[K_SCALE_SIZE];  // scales and mins, quantized with 6 bits
  uint8_t qh[QK_K / 8];          // quants, high bit
  uint8_t qs[QK_K / 2];          // quants, low 4 bits
} block_q5_K;

#define QR6_K 2
#define QI6_K (QK_K / (4 * QR6_K))
typedef struct {
  uint8_t ql[QK_K / 2];      // quants, lower 4 bits
  uint8_t qh[QK_K / 4];      // quants, upper 2 bits
  int8_t scales[QK_K / 16];  // scales
  half d;                    // delta
} block_q6_K;

#define QR2_XXS 8
#define QI2_XXS (QK_K / (4 * QR2_XXS))
typedef struct {
  half d;
  uint16_t qs[QK_K / 8];
} block_iq2_xxs;

#define QR2_XS 8
#define QI2_XS (QK_K / (4 * QR2_XS))
typedef struct {
  half d;
  uint16_t qs[QK_K / 8];
  uint8_t scales[QK_K / 32];
} block_iq2_xs;

#define QR2_S 8
#define QI2_S (QK_K / (4 * QR2_S))
typedef struct {
  half d;
  uint8_t qs[QK_K / 4];
  uint8_t qh[QK_K / 32];
  uint8_t scales[QK_K / 32];
} block_iq2_s;

#define QR3_XXS 8
#define QI3_XXS (QK_K / (4 * QR3_XXS))
typedef struct {
  half d;
  uint8_t qs[3 * (QK_K / 8)];
} block_iq3_xxs;

#define QR3_XS 8
#define QI3_XS (QK_K / (4 * QR3_XS))
#define IQ3S_N_SCALE QK_K / 64
typedef struct {
  half d;
  uint8_t qs[QK_K / 4];
  uint8_t qh[QK_K / 32];
  uint8_t signs[QK_K / 8];
  uint8_t scales[IQ3S_N_SCALE];
} block_iq3_s;

// 1.5625 bpw
#define QR1_S 8
#define QI1_S (QK_K / (4 * QR1_S))
typedef struct {
  half d;
  uint8_t qs[QK_K / 8];
  uint16_t qh[QK_K / 32];
} block_iq1_s;

// 1.75 bpw
#define QR1_M 8
#define QI1_M (QK_K / (4 * QR1_M))
typedef struct {
  uint8_t qs[QK_K / 8];       // grid index, low 8 bits
  uint8_t qh[QK_K / 16];      // grid index, high 3 bits + grid shift bit (for two groups of 8)
  uint8_t scales[QK_K / 32];  // 3-bit block scales (4-bit if QK_K == 64)
} block_iq1_m;

// Used by IQ1_M quants
typedef union {
  half f16;
  uint16_t u16;
} iq1m_scale_t;

#define QK4_NL 32
#define QR4_NL 2
#define QI4_NL (QK4_NL / (4 * QR4_NL))
typedef struct {
  half d;
  uint8_t qs[QK4_NL / 2];
} block_iq4_nl;

#define QR4_XS 8
#define QI4_XS (QK_K / (4 * QR4_XS))
typedef struct {
  half d;
  uint16_t scales_h;
  uint8_t scales_l[QK_K / 64];
  uint8_t qs[QK_K / 2];
} block_iq4_xs;
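These block layouts fix each format's bits per weight: block_q4_0 spends 2 bytes of scale plus 16 nibble bytes on 32 weights, and block_q6_K spends 128 + 64 + 16 + 2 = 210 bytes on 256 weights. A quick check of that arithmetic (half taken as 2 bytes):

  # Bits-per-weight implied by two of the layouts above.
  q4_0_bytes = 2 + 32 // 2                           # d + qs[QK4_0/2] = 18 bytes
  assert q4_0_bytes * 8 / 32 == 4.5                  # 4.5 bpw for 32 weights

  q6_k_bytes = 256 // 2 + 256 // 4 + 256 // 16 + 2   # ql + qh + scales + d = 210
  assert q6_k_bytes * 8 / 256 == 6.5625              # 6.5625 bpw for 256 weights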
static const __device__ uint64_t iq2xxs_grid[256] = {
    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
};
static const __device__ uint64_t iq2xs_grid[512] = {
    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
    0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
    0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
    0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
    0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
    0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
    0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
    0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
    0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
    0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
    0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
    0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
    0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
    0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
    0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
    0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
    0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
    0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
    0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
    0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
    0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
    0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
    0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
    0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
    0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
    0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
    0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
    0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
    0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
    0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
    0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
    0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
    0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
    0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
    0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
    0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
    0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
    0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
    0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
    0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
    0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
    0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
    0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
    0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
    0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
    0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
    0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
    0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
    0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
    0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
    0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
    0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
    0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
    0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
    0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
    0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
    0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
    0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
    0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
    0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
    0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
    0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
    0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
    0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
    0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
    0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
    0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
    0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
    0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
    0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
    0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
    0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
    0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
    0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
    0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
    0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
    0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
    0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
    0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
    0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
    0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
    0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
    0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
    0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
    0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
    0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
    0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
    0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
    0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
    0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
    0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
    0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
    0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
    0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
    0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
    0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
    0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
    0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
    0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
    0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
    0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
    0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
    0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
    0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
    0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
    0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
    0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
    0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
    0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
    0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
    0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
    0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
    0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
    0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
    0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
    0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
    0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
    0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
    0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
    0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
    0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
    0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
    0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
    0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
    0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
    0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
    0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
};
static const __device__ uint64_t iq2s_grid[1024] = {
    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
    0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
    0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
    0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
    0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b,
    0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919,
    0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808,
    0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908,
    0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b,
    0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908,
    0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08,
    0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19,
    0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819,
    0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919,
    0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b,
    0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908,
    0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908,
    0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919,
    0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908,
    0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b,
    0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919,
    0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b,
    0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919,
    0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908,
    0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b,
    0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b,
    0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08,
    0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b,
    0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819,
    0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808,
    0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908,
    0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b,
    0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908,
    0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08,
    0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819,
    0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808,
    0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08,
    0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819,
    0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b,
    0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908,
    0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919,
    0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b,
    0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919,
    0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808,
    0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819,
    0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919,
    0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919,
    0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808,
    0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819,
    0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b,
    0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908,
    0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b,
    0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908,
    0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919,
    0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b,
    0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919,
    0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b,
    0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819,
    0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919,
    0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908,
    0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b,
    0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908,
    0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b,
    0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908,
    0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08,
    0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908,
    0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819,
    0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819,
    0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808,
    0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08,
    0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19,
    0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819,
    0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808,
    0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819,
    0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919,
    0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808,
    0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19,
    0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08,
    0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b,
    0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908,
    0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808,
    0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819,
    0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908,
    0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819,
    0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808,
    0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808,
    0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819,
    0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908,
    0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08,
    0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819,
    0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b,
    0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b,
    0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08,
    0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19,
    0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819,
    0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919,
    0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908,
    0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808,
    0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808,
    0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908,
    0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808
,
0x082b082b08080808
,
0x082b082b08082b2b
,
0x082b082b082b082b
,
0x082b082b082b2b08
,
0x082b082b082b2b2b
,
0x082b082b19081908
,
0x082b082b19190808
,
0x082b082b2b082b08
,
0x082b082b2b082b2b
,
0x082b082b2b2b2b08
,
0x082b190808080819
,
0x082b190808081908
,
0x082b19080808192b
,
0x082b190808082b19
,
0x082b190808190808
,
0x082b190808191919
,
0x082b190808192b08
,
0x082b1908082b0819
,
0x082b1908082b1908
,
0x082b190819080808
,
0x082b19081908082b
,
0x082b190819081919
,
0x082b190819082b08
,
0x082b190819190819
,
0x082b190819191908
,
0x082b1908192b0808
,
0x082b19082b080819
,
0x082b19082b081908
,
0x082b19082b190808
,
0x082b191908080808
,
0x082b191908081919
,
0x082b191908082b08
,
0x082b191908190819
,
0x082b191908191908
,
0x082b1919082b0808
,
0x082b191919080819
,
0x082b191919081908
,
0x082b191919190808
,
0x082b1919192b192b
,
0x082b19192b080808
,
0x082b192b08080819
,
0x082b192b08081908
,
0x082b192b08190808
,
0x082b192b19080808
,
0x082b192b19192b19
,
0x082b2b0808080808
,
0x082b2b0808081919
,
0x082b2b0808190819
,
0x082b2b0808191908
,
0x082b2b0819080819
,
0x082b2b0819081908
,
0x082b2b0819190808
,
0x082b2b082b082b2b
,
0x082b2b082b2b2b2b
,
0x082b2b1908080819
,
0x082b2b1908081908
,
0x082b2b1908190808
,
0x082b2b192b191919
,
0x082b2b2b08082b2b
,
0x082b2b2b082b082b
,
0x082b2b2b192b1908
,
0x082b2b2b2b082b08
,
0x082b2b2b2b082b2b
,
0x1908080808080819
,
0x1908080808081908
,
0x190808080808192b
,
0x1908080808082b19
,
0x1908080808190808
,
0x190808080819082b
,
0x1908080808191919
,
0x1908080808192b08
,
0x1908080808192b2b
,
0x19080808082b0819
,
0x19080808082b1908
,
0x19080808082b192b
,
0x1908080819080808
,
0x190808081908082b
,
0x1908080819081919
,
0x1908080819082b08
,
0x1908080819082b2b
,
0x1908080819190819
,
0x1908080819191908
,
0x190808081919192b
,
0x1908080819192b19
,
0x19080808192b0808
,
0x19080808192b082b
,
0x19080808192b1919
,
0x190808082b080819
,
0x190808082b081908
,
0x190808082b190808
,
0x190808082b191919
,
0x190808082b192b08
,
0x190808082b2b0819
,
0x190808082b2b1908
,
0x1908081908080808
,
0x190808190808082b
,
0x1908081908081919
,
0x1908081908082b08
,
0x1908081908190819
,
0x1908081908191908
,
0x190808190819192b
,
0x1908081908192b19
,
0x19080819082b0808
,
0x19080819082b082b
,
0x19080819082b1919
,
0x1908081919080819
,
0x1908081919081908
,
0x190808191908192b
,
0x1908081919082b19
,
0x1908081919190808
,
0x190808191919082b
,
0x1908081919191919
,
0x1908081919192b08
,
0x19080819192b0819
,
0x19080819192b1908
,
0x190808192b080808
,
0x190808192b08082b
,
0x190808192b081919
,
0x190808192b082b08
,
0x190808192b190819
,
0x190808192b191908
,
0x190808192b2b0808
,
0x1908082b08080819
,
0x1908082b08081908
,
0x1908082b08190808
,
0x1908082b0819082b
,
0x1908082b08191919
,
0x1908082b08192b08
,
0x1908082b082b1908
,
0x1908082b19080808
,
0x1908082b19081919
,
0x1908082b19082b08
,
0x1908082b19190819
,
0x1908082b19191908
,
0x1908082b192b0808
,
0x1908082b2b080819
,
0x1908082b2b081908
,
0x1908190808080808
,
0x190819080808082b
,
0x1908190808081919
,
0x1908190808082b08
,
0x1908190808082b2b
,
0x1908190808190819
,
0x1908190808191908
,
0x190819080819192b
,
0x1908190808192b19
,
0x19081908082b0808
,
0x19081908082b082b
,
0x19081908082b1919
,
0x19081908082b2b08
,
0x1908190819080819
,
0x1908190819081908
,
0x190819081908192b
,
0x1908190819082b19
,
0x1908190819190808
,
0x190819081919082b
,
0x1908190819191919
,
0x1908190819192b08
,
0x19081908192b0819
,
0x19081908192b1908
,
0x190819082b080808
,
0x190819082b08082b
,
0x190819082b081919
,
0x190819082b082b08
,
0x190819082b190819
,
0x190819082b191908
,
0x190819082b2b0808
,
0x1908191908080819
,
0x1908191908081908
,
0x190819190808192b
,
0x1908191908082b19
,
0x1908191908190808
,
0x190819190819082b
,
0x1908191908191919
,
0x1908191908192b08
,
0x19081919082b0819
,
0x19081919082b1908
,
0x1908191919080808
,
0x190819191908082b
,
0x1908191919081919
,
0x1908191919082b08
,
0x1908191919190819
,
0x1908191919191908
,
0x19081919192b0808
,
0x19081919192b2b2b
,
0x190819192b080819
,
0x190819192b081908
,
0x190819192b190808
,
0x1908192b08080808
,
0x1908192b0808082b
,
0x1908192b08081919
,
0x1908192b08082b08
,
0x1908192b08190819
,
0x1908192b08191908
,
0x1908192b082b0808
,
0x1908192b19080819
,
0x1908192b19081908
,
0x1908192b19190808
,
0x1908192b2b080808
,
0x1908192b2b2b1919
,
0x19082b0808080819
,
0x19082b0808081908
,
0x19082b0808082b19
,
0x19082b0808190808
,
0x19082b080819082b
,
0x19082b0808191919
,
0x19082b0808192b08
,
0x19082b08082b0819
,
0x19082b08082b1908
,
0x19082b0819080808
,
0x19082b081908082b
,
0x19082b0819081919
,
0x19082b0819082b08
,
0x19082b0819190819
,
0x19082b0819191908
,
0x19082b08192b0808
,
0x19082b082b081908
,
0x19082b082b190808
,
0x19082b1908080808
,
0x19082b190808082b
,
0x19082b1908081919
,
0x19082b1908082b08
,
0x19082b1908190819
,
0x19082b1908191908
,
0x19082b19082b0808
,
0x19082b1919080819
,
0x19082b1919081908
,
0x19082b1919190808
,
0x19082b192b080808
,
0x19082b192b19192b
,
0x19082b2b08080819
,
0x19082b2b08081908
,
0x19082b2b08190808
,
0x19082b2b19080808
,
0x1919080808080808
,
0x191908080808082b
,
0x1919080808081919
,
0x1919080808082b08
,
0x1919080808190819
,
0x1919080808191908
,
0x191908080819192b
,
0x1919080808192b19
,
0x19190808082b0808
,
0x19190808082b082b
,
0x19190808082b1919
,
0x19190808082b2b08
,
0x1919080819080819
,
0x1919080819081908
,
0x191908081908192b
,
0x1919080819082b19
,
0x1919080819190808
,
0x191908081919082b
,
0x1919080819191919
,
0x1919080819192b08
,
0x19190808192b0819
,
0x19190808192b1908
,
0x191908082b080808
,
0x191908082b08082b
,
0x191908082b081919
,
0x191908082b082b08
,
0x191908082b190819
,
0x191908082b191908
,
0x1919081908080819
,
0x1919081908081908
,
0x191908190808192b
,
0x1919081908082b19
,
0x1919081908190808
,
0x191908190819082b
,
0x1919081908191919
,
0x1919081908192b08
,
0x19190819082b0819
,
0x19190819082b1908
,
0x1919081919080808
,
0x191908191908082b
,
0x1919081919081919
,
0x1919081919082b08
,
0x1919081919190819
,
0x1919081919191908
,
0x19190819192b0808
,
0x191908192b080819
,
0x191908192b081908
,
0x191908192b190808
,
0x1919082b08080808
,
0x1919082b08081919
,
0x1919082b08082b08
,
0x1919082b08190819
,
0x1919082b08191908
,
0x1919082b082b0808
,
0x1919082b19080819
,
0x1919082b19081908
,
0x1919082b19190808
,
0x1919082b192b2b19
,
0x1919082b2b080808
,
0x1919190808080819
,
0x1919190808081908
,
0x191919080808192b
,
0x1919190808082b19
,
0x1919190808190808
,
0x191919080819082b
,
0x1919190808191919
,
0x1919190808192b08
,
0x19191908082b0819
,
0x19191908082b1908
,
0x1919190819080808
,
0x191919081908082b
,
0x1919190819081919
,
0x1919190819082b08
,
0x1919190819190819
,
0x1919190819191908
,
0x19191908192b0808
,
0x191919082b080819
,
0x191919082b081908
,
0x191919082b190808
,
0x1919191908080808
,
0x191919190808082b
,
0x1919191908081919
,
0x1919191908082b08
,
0x1919191908190819
,
0x1919191908191908
,
0x19191919082b0808
,
0x1919191919080819
,
0x1919191919081908
,
0x1919191919190808
,
0x191919192b080808
,
0x1919192b08080819
,
0x1919192b08081908
,
0x1919192b08190808
,
0x1919192b082b192b
,
0x1919192b19080808
,
0x19192b0808080808
,
0x19192b080808082b
,
0x19192b0808081919
,
0x19192b0808082b08
,
0x19192b0808190819
,
0x19192b0808191908
,
0x19192b08082b0808
,
0x19192b0819080819
,
0x19192b0819081908
,
0x19192b0819190808
,
0x19192b0819192b2b
,
0x19192b082b080808
,
0x19192b1908080819
,
0x19192b1908081908
,
0x19192b1908190808
,
0x19192b1919080808
,
0x19192b2b08080808
,
0x19192b2b08192b19
,
0x19192b2b2b081919
,
0x19192b2b2b2b2b08
,
0x192b080808080819
,
0x192b080808081908
,
0x192b08080808192b
,
0x192b080808190808
,
0x192b08080819082b
,
0x192b080808191919
,
0x192b080808192b08
,
0x192b0808082b0819
,
0x192b0808082b1908
,
0x192b080819080808
,
0x192b080819081919
,
0x192b080819082b08
,
0x192b080819190819
,
0x192b080819191908
,
0x192b0808192b0808
,
0x192b08082b081908
,
0x192b08082b190808
,
0x192b081908080808
,
0x192b08190808082b
,
0x192b081908081919
,
0x192b081908082b08
,
0x192b081908190819
,
0x192b081908191908
,
0x192b0819082b0808
,
0x192b081919080819
,
0x192b081919081908
,
0x192b081919190808
,
0x192b08192b080808
,
0x192b08192b192b19
,
0x192b082b08081908
,
0x192b082b08190808
,
0x192b082b19080808
,
0x192b082b1919192b
,
0x192b082b2b2b0819
,
0x192b190808080808
,
0x192b190808081919
,
0x192b190808082b08
,
0x192b190808190819
,
0x192b190808191908
,
0x192b1908082b0808
,
0x192b190819080819
,
0x192b190819081908
,
0x192b190819190808
,
0x192b19082b080808
,
0x192b191908080819
,
0x192b191908081908
,
0x192b191908190808
,
0x192b191919080808
,
0x192b191919082b2b
,
0x192b1919192b2b08
,
0x192b19192b19082b
,
0x192b192b08080808
,
0x192b192b2b191908
,
0x192b2b0808080819
,
0x192b2b0808081908
,
0x192b2b0808190808
,
0x192b2b08192b1919
,
0x192b2b082b192b08
,
0x192b2b1908080808
,
0x192b2b19082b2b2b
,
0x192b2b2b1908082b
,
0x192b2b2b2b2b0819
,
0x2b08080808080808
,
0x2b0808080808082b
,
0x2b08080808081919
,
0x2b08080808082b08
,
0x2b08080808190819
,
0x2b08080808191908
,
0x2b08080808192b19
,
0x2b080808082b0808
,
0x2b080808082b1919
,
0x2b08080819080819
,
0x2b08080819081908
,
0x2b08080819190808
,
0x2b0808081919082b
,
0x2b08080819191919
,
0x2b08080819192b08
,
0x2b080808192b0819
,
0x2b0808082b080808
,
0x2b0808082b081919
,
0x2b0808082b190819
,
0x2b0808082b191908
,
0x2b08081908080819
,
0x2b08081908081908
,
0x2b08081908082b19
,
0x2b08081908190808
,
0x2b0808190819082b
,
0x2b08081908191919
,
0x2b08081908192b08
,
0x2b080819082b0819
,
0x2b080819082b1908
,
0x2b08081919080808
,
0x2b0808191908082b
,
0x2b08081919081919
,
0x2b08081919082b08
,
0x2b08081919190819
,
0x2b08081919191908
,
0x2b0808192b080819
,
0x2b0808192b081908
,
0x2b0808192b190808
,
0x2b0808192b2b2b19
,
0x2b08082b08080808
,
0x2b08082b08081919
,
0x2b08082b08082b2b
,
0x2b08082b08190819
,
0x2b08082b08191908
,
0x2b08082b19080819
,
0x2b08082b19081908
,
0x2b08082b19190808
,
0x2b08190808080819
,
0x2b08190808081908
,
0x2b0819080808192b
,
0x2b08190808082b19
,
0x2b08190808190808
,
0x2b0819080819082b
,
0x2b08190808191919
,
0x2b08190808192b08
,
0x2b081908082b0819
,
0x2b08190819080808
,
0x2b0819081908082b
,
0x2b08190819081919
,
0x2b08190819082b08
,
0x2b08190819190819
,
0x2b08190819191908
,
0x2b081908192b0808
,
0x2b0819082b080819
,
0x2b0819082b081908
,
0x2b0819082b190808
,
0x2b08191908080808
,
0x2b0819190808082b
,
0x2b08191908081919
,
0x2b08191908082b08
,
0x2b08191908190819
,
0x2b08191908191908
,
0x2b081919082b0808
,
0x2b08191919080819
,
0x2b08191919081908
,
0x2b08191919190808
,
0x2b0819192b080808
,
0x2b0819192b082b2b
,
0x2b08192b08080819
,
0x2b08192b08081908
,
0x2b08192b08190808
,
0x2b08192b082b2b19
,
0x2b08192b19080808
,
0x2b082b0808080808
,
0x2b082b0808081919
,
0x2b082b0808190819
,
0x2b082b0808191908
,
0x2b082b0819080819
,
0x2b082b0819081908
,
0x2b082b0819190808
,
0x2b082b082b2b082b
,
0x2b082b1908080819
,
0x2b082b1908081908
,
0x2b082b1919080808
,
0x2b082b19192b1919
,
0x2b082b2b082b082b
,
0x2b082b2b19192b08
,
0x2b082b2b19192b2b
,
0x2b082b2b2b08082b
,
0x2b082b2b2b2b082b
,
0x2b19080808080819
,
0x2b19080808081908
,
0x2b19080808082b19
,
0x2b19080808190808
,
0x2b1908080819082b
,
0x2b19080808191919
,
0x2b19080808192b08
,
0x2b190808082b1908
,
0x2b19080819080808
,
0x2b1908081908082b
,
0x2b19080819081919
,
0x2b19080819082b08
,
0x2b19080819190819
,
0x2b19080819191908
,
0x2b190808192b0808
,
0x2b1908082b080819
,
0x2b1908082b081908
,
0x2b1908082b190808
,
0x2b19081908080808
,
0x2b19081908081919
,
0x2b19081908190819
,
0x2b19081908191908
,
0x2b19081919080819
,
0x2b19081919081908
,
0x2b19081919190808
,
0x2b19081919192b2b
,
0x2b19082b08080819
,
0x2b19082b08081908
,
0x2b19082b08190808
,
0x2b19082b19080808
,
0x2b19082b2b2b192b
,
0x2b19190808080808
,
0x2b1919080808082b
,
0x2b19190808081919
,
0x2b19190808082b08
,
0x2b19190808190819
,
0x2b19190808191908
,
0x2b191908082b0808
,
0x2b19190819080819
,
0x2b19190819081908
,
0x2b19190819190808
,
0x2b1919082b080808
,
0x2b1919082b19192b
,
0x2b19191908080819
,
0x2b19191908081908
,
0x2b19191908190808
,
0x2b19191919080808
,
0x2b1919192b192b08
,
0x2b1919192b2b0819
,
0x2b19192b08080808
,
0x2b19192b1908192b
,
0x2b19192b192b1908
,
0x2b192b0808080819
,
0x2b192b0808081908
,
0x2b192b0808190808
,
0x2b192b08082b192b
,
0x2b192b0819080808
,
0x2b192b082b2b2b19
,
0x2b192b1908080808
,
0x2b192b1919082b19
,
0x2b192b191919082b
,
0x2b192b2b2b190808
,
0x2b2b080808080808
,
0x2b2b080808081919
,
0x2b2b080808082b2b
,
0x2b2b080808191908
,
0x2b2b0808082b082b
,
0x2b2b0808082b2b2b
,
0x2b2b080819080819
,
0x2b2b080819081908
,
0x2b2b080819190808
,
0x2b2b08082b2b082b
,
0x2b2b08082b2b2b2b
,
0x2b2b081919080808
,
0x2b2b0819192b1919
,
0x2b2b082b0808082b
,
0x2b2b082b08082b2b
,
0x2b2b082b082b082b
,
0x2b2b082b082b2b08
,
0x2b2b082b082b2b2b
,
0x2b2b082b2b08082b
,
0x2b2b082b2b082b08
,
0x2b2b082b2b082b2b
,
0x2b2b082b2b2b2b08
,
0x2b2b190808080819
,
0x2b2b190808081908
,
0x2b2b190808190808
,
0x2b2b190819080808
,
0x2b2b19082b082b19
,
0x2b2b19082b2b1908
,
0x2b2b191908080808
,
0x2b2b191908192b19
,
0x2b2b192b19190819
,
0x2b2b2b0808082b2b
,
0x2b2b2b08082b2b08
,
0x2b2b2b082b2b082b
,
0x2b2b2b1919191908
,
0x2b2b2b192b08192b
,
0x2b2b2b2b08082b08
,
0x2b2b2b2b08082b2b
,
0x2b2b2b2b082b0808
,
0x2b2b2b2b082b082b
,
0x2b2b2b2b082b2b08
,
0x2b2b2b2b2b082b08
,
0x2b2b2b2b2b2b2b2b
,
};
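// The iq*_grid tables in this header are codebooks for the GGUF "i-quant" formats: each entry
// packs a group of byte-wide quantized magnitudes (e.g. eight per uint64_t row), and the indices
// stored in a quantized block select rows of these codebooks at dequantization time. The exact
// index/packing layout per type lives in dequantize.cuh and vecdotq.cuh; this note is a summary,
// not a spec.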
static const __device__ uint32_t iq3xxs_grid[256] = {
    0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
    0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
    0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
    0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
    0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
    0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
    0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
    0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
    0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
    0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
    0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
    0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
    0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
    0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
    0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
    0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
    0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
    0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
    0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
    0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
    0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
    0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
    0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
    0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
    0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
    0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
    0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
    0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
    0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
    0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
    0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
    0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
};
static const __device__ uint32_t iq3xs_grid[512] = {
    0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
    0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
    0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
    0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
    0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
    0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
    0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
    0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
    0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
    0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
    0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
    0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
    0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
    0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
    0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
    0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
    0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
    0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
    0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
    0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
    0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
    0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
    0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
    0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
    0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
    0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
    0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
    0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
    0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
    0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
    0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
    0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
    0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
    0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
    0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
    0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
    0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
    0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
    0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
    0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
    0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
    0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
    0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
    0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
    0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
    0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
    0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
    0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
    0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
    0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
    0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
    0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
    0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
    0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
    0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
    0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
    0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
    0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
    0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
    0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
    0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
    0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
    0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
    0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
};
#define IQ1S_DELTA 0.125f
#define IQ1M_DELTA 0.125f
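// Assumption (reader's note): in the ggml IQ1 scheme these deltas are the small constant offset
// applied to each grid value at dequantization, with the sign chosen per group, i.e. roughly
// y = d * (grid_value +/- IQ1S_DELTA). The authoritative formula is in the IQ1 dequant paths.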
static const __device__ uint64_t iq1s_grid_gpu[2048] = {
    0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,
    0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,
    0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200,
    0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212,
    0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011,
    0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111,
    0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220,
    0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022,
    0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220,
    0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101,
    0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110,
    0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111,
    0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010,
    0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210,
    0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221,
    0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021,
    0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002,
    0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101,
    0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101,
    0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211,
    0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110,
    0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022,
    0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121,
    0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220,
    0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001,
    0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101,
    0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102,
    0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012,
    0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010,
    0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111,
    0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122,
    0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222,
    0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001,
    0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102,
    0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101,
    0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000,
    0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101,
    0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112,
    0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110,
    0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211,
    0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012,
    0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111,
    0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120,
    0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122,
    0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121,
    0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221,
    0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001,
    0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101,
    0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101,
    0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011,
    0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111,
    0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011,
    0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122,
    0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121,
    0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222,
    0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101,
    0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000,
    0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200,
    0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110,
    0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112,
    0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222,
    0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021,
    0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121,
    0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201,
    0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200,
    0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101,
    0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011,
    0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010,
    0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211,
    0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121,
    0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000,
    0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202,
    0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202,
    0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211,
    0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112,
    0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020,
    0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121,
    0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222,
    0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102,
    0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100,
    0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110,
    0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011,
    0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111,
    0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110,
    0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121,
    0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222,
    0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201,
    0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102,
    0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201,
    0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012,
    0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010,
    0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010,
    0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110,
    0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011,
    0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212,
    0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021,
    0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021,
    0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021,
    0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101,
    0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101,
    0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100,
    0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010,
    0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111,
    0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010,
    0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111,
    0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120,
    0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120,
    0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101,
    0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001,
    0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201,
    0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210,
    0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211,
    0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111,
    0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112,
    0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211,
    0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010,
    0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021,
    0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122,
    0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221,
    0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102,
    0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100,
    0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101,
    0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101,
    0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101,
    0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012,
    0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110,
    0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112,
    0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210,
    0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210,
    0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210,
    0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010,
    0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110,
    0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122,
    0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020,
    0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021,
    0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022,
    0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120,
    0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222,
    0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221,
    0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001,
    0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102,
    0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201,
    0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012,
    0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111,
    0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012,
    0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110,
    0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110,
    0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121,
    0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221,
    0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220,
    0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222,
    0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000,
    0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201,
    0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012,
    0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011,
    0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212,
    0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221,
    0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121,
    0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202,
    0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202,
    0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002,
    0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101,
    0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210,
    0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112,
    0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011,
    0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011,
    0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210,
    0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020,
    0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220,
    0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222,
    0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222,
    0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001,
    0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010,
    0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111,
    0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010,
    0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110,
    0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221,
    0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122,
    0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202,
    0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100,
    0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101,
    0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112,
    0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111,
    0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211,
    0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222,
    0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221,
    0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022,
    0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101,
    0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211,
    0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111,
    0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111,
    0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010,
    0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121,
    0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222,
    0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000,
    0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202,
    0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000,
    0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202,
    0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110,
    0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110,
    0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222,
    0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120,
    0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022,
    0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101,
    0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202,
    0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110,
    0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110,
    0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111,
    0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111,
    0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120,
    0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121,
    0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001,
    0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202,
    0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001,
    0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200,
    0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011,
    0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212,
    0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012,
    0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110,
    0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012,
    0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111,
    0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020,
    0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121,
    0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222,
    0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102,
    0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102,
    0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101,
    0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212,
    0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210,
    0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111,
    0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212,
    0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221,
    0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121,
    0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002,
    0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000,
    0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202,
    0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112,
    0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111,
    0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020,
    0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221,
    0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022,
    0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100,
    0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201,
    0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112,
    0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211,
    0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012,
    0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121,
    0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020,
    0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120,
    0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200,
    0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200,
    0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110,
    0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011,
    0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222,
    0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020,
    0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,
};
static const __device__ uint8_t ksigns_iq2xs[128] = {
    0,   129, 130, 3,   132, 5,   6,   135, 136, 9,   10,  139, 12,  141, 142, 15,
    144, 17,  18,  147, 20,  149, 150, 23,  24,  153, 154, 27,  156, 29,  30,  159,
    160, 33,  34,  163, 36,  165, 166, 39,  40,  169, 170, 43,  172, 45,  46,  175,
    48,  177, 178, 51,  180, 53,  54,  183, 184, 57,  58,  187, 60,  189, 190, 63,
    192, 65,  66,  195, 68,  197, 198, 71,  72,  201, 202, 75,  204, 77,  78,  207,
    80,  209, 210, 83,  212, 85,  86,  215, 216, 89,  90,  219, 92,  221, 222, 95,
    96,  225, 226, 99,  228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
};
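// ksigns_iq2xs[i] keeps the 7 sign bits of the index in its low bits and adds an odd-parity bit
// in bit 7, so every entry has an even popcount. A host-side check of that invariant (purely
// illustrative, not part of the kernels):
//
//   for (int i = 0; i < 128; ++i) {
//     assert((ksigns_iq2xs[i] & 0x7f) == i);                 // low 7 bits are the index itself
//     assert(__builtin_popcount(ksigns_iq2xs[i]) % 2 == 0);  // bit 7 restores even parity
//   }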
static const __device__ uint64_t ksigns64[128] = {
    0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
    0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
    0xff000000ff000000, 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff,
    0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, 0x00000000ffffffff,
    0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff,
    0x000000ff00ff0000, 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff,
    0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff,
    0xff0000ffffff0000, 0x000000ffffff00ff, 0x000000ffffffff00, 0xff0000ffffffffff,
    0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff,
    0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff,
    0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff,
    0xff00ff00ffff0000, 0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff,
    0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, 0x0000ffff0000ffff,
    0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00, 0xff00ffff00ffffff,
    0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff,
    0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff,
    0xffff000000000000, 0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff,
    0x00ff000000ff0000, 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff,
    0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, 0x00ff0000ff00ffff,
    0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff,
    0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff,
    0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff,
    0xffff00ffff000000, 0x00ff00ffff0000ff, 0x00ff00ffff00ff00, 0xffff00ffff00ffff,
    0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00, 0x00ff00ffffffffff,
    0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff,
    0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff,
    0xffffff00ff000000, 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff,
    0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, 0x00ffff00ffffffff,
    0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00, 0xffffffff0000ffff,
    0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff,
    0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
    0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
};
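// ksigns64[i] is the byte-broadcast form of ksigns_iq2xs[i]: byte j is 0xff when sign bit j is
// set and 0x00 otherwise. With per-byte subtraction, (x ^ m) - m negates exactly the lanes where
// m is 0xff. A sketch of that pattern (assumed usage; the real call sites are in vecdotq.cuh):
//
//   const uint64_t m64 = ksigns64[sign_idx & 127];
//   const int m = (int)(m64 & 0xffffffff);      // low four lanes
//   const int signed_q = __vsubss4(q ^ m, m);   // per-byte (q ^ m) - m == -q where m == 0xff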
static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
static
const
__device__
int8_t
kvalues_iq4nl
[
16
]
=
{
-
127
,
-
104
,
-
83
,
-
65
,
-
49
,
-
35
,
-
22
,
-
10
,
1
,
13
,
25
,
38
,
53
,
69
,
89
,
113
};
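A note on the two sign tables: each entry of ksigns_iq2xs carries the seven low sign bits of its index plus an odd-parity bit in bit 7, and ksigns64 is the same table with every sign bit expanded to a 0x00/0xff byte mask. A minimal host-side sketch that would regenerate both tables under that assumption (make_ksign/make_ksign64 are hypothetical names; __builtin_parity assumes a GCC/Clang-compatible host compiler):

    #include <cstdint>

    static uint8_t make_ksign(uint8_t i) {
      // bit 7 = odd parity of the 7 data bits, e.g. make_ksign(1) == 129, make_ksign(3) == 3
      return i | (__builtin_parity(i) << 7);
    }

    static uint64_t make_ksign64(uint8_t i) {
      const uint8_t s = make_ksign(i);
      uint64_t out = 0;
      for (int b = 0; b < 8; ++b) {
        out |= (uint64_t)(((s >> b) & 1) ? 0xff : 0x00) << (8 * b);  // expand bit b to byte b
      }
      return out;  // e.g. make_ksign64(1) == 0xff000000000000ff, matching ksigns64[1]
    }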
typedef half dfloat;  // dequantize float
typedef half2 dfloat2;

typedef void (*dequantize_kernel_t)(const void* vx, const int ib, const int iqs, dfloat2& v);

template <typename dst_t>
using to_cuda_ggml_t = void (*)(const void* __restrict__ x, dst_t* __restrict__ y, int k, cudaStream_t stream);

typedef float (*vec_dot_q_cuda_t)(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs);

typedef void (*allocate_tiles_cuda_t)(int** x_ql, half2** x_dm, int** x_qh, int** x_sc);

typedef void (*load_tiles_cuda_t)(
    const void* __restrict__ vx,
    int* __restrict__ x_ql,
    half2* __restrict__ x_dm,
    int* __restrict__ x_qh,
    int* __restrict__ x_sc,
    const int& i_offset,
    const int& i_max,
    const int& k,
    const int& blocks_per_row);

typedef float (*vec_dot_q_mul_mat_cuda_t)(
    const int* __restrict__ x_ql,
    const half2* __restrict__ x_dm,
    const int* __restrict__ x_qh,
    const int* __restrict__ x_sc,
    const int* __restrict__ y_qs,
    const half2* __restrict__ y_ms,
    const int& i,
    const int& j,
    const int& k);

// Utility function
template <typename dst_t>
static __device__ __forceinline__ dst_t convert_from_half(half val) {
  return val;
}

template <>
__device__ __forceinline__ c10::BFloat16 convert_from_half<c10::BFloat16>(half val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  return __float2bfloat16(__half2float(val));
#else
  return __half2float(val);
#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
}

template <>
__device__ __forceinline__ float convert_from_half<float>(half val) {
  return __half2float(val);
}
#if defined(USE_ROCM)

#ifndef __has_builtin
#define __has_builtin(x) 0
#endif

typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));

static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
  const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
  const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
#if __has_builtin(__builtin_elementwise_sub_sat)
  const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
  return reinterpret_cast<const int&>(c);
#else
  int8x4_t c;
  int16_t tmp;
#pragma unroll
  for (int i = 0; i < 4; i++) {
    tmp = va[i] - vb[i];
    if (tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
    if (tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
    c[i] = tmp;
  }
  return reinterpret_cast<int&>(c);
#endif  // __has_builtin(__builtin_elementwise_sub_sat)
}

static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#if __has_builtin(__builtin_amdgcn_sdot4)
  c = __builtin_amdgcn_sdot4(a, b, c, false);
#else
  const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
  const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
  c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
#endif
  return c;
}

static __device__ __forceinline__ uint32_t __vcmpeq4(const uint32_t a, const uint32_t b) {
  uint32_t neq = a ^ b;
  return !(neq & 0xff000000) * 0xff000000 | !(neq & 0x00ff0000) * 0x00ff0000 |
         !(neq & 0x0000ff00) * 0x0000ff00 | !(neq & 0x000000ff) * 0x000000ff;
}

static __device__ __forceinline__ uint32_t __vsub4(const uint32_t a, const uint32_t b) {
  return (static_cast<uint8_t>(((a & 0xff000000) >> 24) - ((b & 0xff000000) >> 24)) << 24) +
         (static_cast<uint8_t>(((a & 0x00ff0000) >> 16) - ((b & 0x00ff0000) >> 16)) << 16) +
         (static_cast<uint8_t>(((a & 0x0000ff00) >> 8) - ((b & 0x0000ff00) >> 8)) << 8) +
         (static_cast<uint8_t>(((a & 0x000000ff) >> 0) - ((b & 0x000000ff) >> 0)) << 0);
}

#endif  // defined(USE_ROCM)
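For readers unfamiliar with the intrinsic being emulated here: __dp4a treats each int as four packed signed bytes (low byte first) and accumulates the four pairwise products into c. A tiny device-side example, usable on both the CUDA builtin and the ROCm shim above (dp4a_example is a hypothetical helper, not part of the diff):

    static __device__ __forceinline__ int dp4a_example() {
      const int a = 0x0403FE01;  // packed bytes {1, -2, 3, 4}
      const int b = 0x08F90605;  // packed bytes {5, 6, -7, 8}
      // 10 + 1*5 + (-2)*6 + 3*(-7) + 4*8 = 14
      return __dp4a(a, b, 10);
    }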
sgl-kernel/csrc/quantization/gguf/gguf_kernel.cu
0 → 100644
// Adapted from
// https://github.com/vllm-project/vllm/blob/755ed7b05be4743237d3339c4ff8c22bcaae04f4/csrc/quantization/gguf/gguf_kernel.cu
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <torch/all.h>

// don't use clang-format here, it breaks the include order
// clang-format off
#include "utils.h"
#include "ggml-common.h"
#include "vecdotq.cuh"
#include "dequantize.cuh"
#include "mmvq.cuh"
#include "mmq.cuh"
#include "moe.cuh"
#include "moe_vec.cuh"
// clang-format on
// Q8 gemv
template <typename scalar_t>
static __global__ void
quantize_q8_1(const scalar_t* __restrict__ x, void* __restrict__ vy, const int kx, const int kx_padded) {
  const auto ix = blockDim.x * blockIdx.x + threadIdx.x;
  if (ix >= kx_padded) {
    return;
  }
  const auto iy = blockDim.y * blockIdx.y + threadIdx.y;
  const int i_padded = iy * kx_padded + ix;

  block_q8_1* y = (block_q8_1*)vy;

  const int ib = i_padded / QK8_1;   // block index
  const int iqs = i_padded % QK8_1;  // quant index

  const float xi = ix < kx ? static_cast<float>(x[iy * kx + ix]) : 0.0f;
  float amax = fabsf(xi);
  float sum = xi;

#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1) {
    amax = fmaxf(amax, SGLANG_SHFL_XOR_SYNC_WIDTH(uint32_t(-1), amax, mask, 32));
    sum += SGLANG_SHFL_XOR_SYNC_WIDTH(uint32_t(-1), sum, mask, 32);
  }

  const float d = amax / 127;
  const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);

  y[ib].qs[iqs] = q;

  if (iqs > 0) {
    return;
  }

  y[ib].ds.x = __float2half(d);
  y[ib].ds.y = __float2half(sum);
}
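Semantically, each 32-value block is quantized with a shared scale d = amax / 127 and the block's unquantized sum is stored alongside it (the sum is needed later to correct for the Q8_1 zero-point terms). A minimal CPU reference of the same per-block math, for illustration only (the kernel computes amax and sum with warp shuffles; BlockQ8_1Ref and quantize_q8_1_ref are hypothetical names):

    #include <cmath>
    #include <cstdint>

    struct BlockQ8_1Ref {
      float d;        // scale: amax / 127 (the kernel stores it as half in ds.x)
      float sum;      // sum of the raw inputs (stored as half in ds.y)
      int8_t qs[32];  // rounded quants
    };

    void quantize_q8_1_ref(const float* x, BlockQ8_1Ref& b) {
      float amax = 0.0f, sum = 0.0f;
      for (int i = 0; i < 32; ++i) {
        amax = std::fmax(amax, std::fabs(x[i]));
        sum += x[i];
      }
      const float d = amax / 127.0f;
      for (int i = 0; i < 32; ++i) {
        b.qs[i] = amax == 0.0f ? 0 : (int8_t)std::roundf(x[i] / d);
      }
      b.d = d;
      b.sum = sum;
    }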
template <typename scalar_t>
static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx, const int ky, cudaStream_t stream) {
  const int64_t kx_padded = (kx + 512 - 1) / 512 * 512;
  const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
  constexpr int MAX_BLOCK_SIZE = 65535;
  for (int off = 0; off < ky; off += MAX_BLOCK_SIZE) {
    const int num_blocks_y = std::min(ky, off + MAX_BLOCK_SIZE) - off;
    const dim3 num_blocks(block_num_x, num_blocks_y, 1);
    const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(
        &x[off * kx], (int32_t*)vy + off * (kx_padded / 32 * 9), kx, kx_padded);
  }
}
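The recurring "/ 32 * 9" stride (here and in the quant_X allocations below) comes from the block layout: one block_q8_1 covers QK8_1 = 32 inputs and occupies 32 int8 quants plus a half2 {d, sum}, i.e. 36 bytes = 9 int32 words. A compile-time sanity check one could add, assuming the usual block_q8_1 layout from ggml-common.h:

    static_assert(sizeof(block_q8_1) == 32 * sizeof(int8_t) + sizeof(half2), "36 bytes per Q8_1 block");
    static_assert(sizeof(block_q8_1) / sizeof(int32_t) == 9, "9 int32 words per Q8_1 block");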
torch::Tensor ggml_dequantize(
    torch::Tensor W,  // quant weight
    int64_t type,
    int64_t m,
    int64_t n,
    std::optional<at::ScalarType> const& dtype) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
  auto dtype_ = dtype.value_or(torch::kFloat16);
  auto options = torch::TensorOptions().dtype(dtype_).device(W.device());
  at::Tensor DW = torch::empty({m, n}, options);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  DISPATCH_FLOAT_TYPES(DW.scalar_type(), "ggml_dequantize", [&] {
    auto to_cuda = ggml_get_to_cuda<scalar_t>(type);
    to_cuda((void*)W.data_ptr(), (scalar_t*)DW.data_ptr(), m * n, stream);
  });
  return DW;
}
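A hypothetical caller-side sketch (not part of the diff) for the most common case, Q4_0, where each block of 32 weights is an fp16 scale plus 16 packed nibbles, i.e. 18 bytes per 32 weights. The uint8 dtype check is an assumption about how the raw GGUF blocks are staged:

    #include <torch/all.h>

    at::Tensor dequant_q4_0(at::Tensor W, int64_t m, int64_t n) {
      TORCH_CHECK(W.is_cuda() && W.scalar_type() == torch::kUInt8, "raw GGUF block bytes expected");
      TORCH_CHECK(W.numel() == m * (n / 32) * 18, "Q4_0: 18 bytes per 32 weights");
      return ggml_dequantize(W, /*type=*/2, m, n, std::nullopt);  // fp16 output by default
    }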
torch::Tensor ggml_mul_mat_vec_a8(
    torch::Tensor W,  // quant weight
    torch::Tensor X,  // input
    int64_t type,
    int64_t row) {
  int col = X.sizes()[1];
  int vecs = X.sizes()[0];
  const int padded = (col + 512 - 1) / 512 * 512;
  const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
  auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
  at::Tensor Y = torch::empty({vecs, row}, options);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
  at::Tensor quant_X = torch::empty({vecs, padded / 32 * 9}, options);
  DISPATCH_FLOAT_TYPES(X.scalar_type(), "ggml_mul_mat_vec_a8", [&] {
    quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, vecs, stream);
    switch (type) {
      case 2:
        mul_mat_vec_q4_0_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
      case 3:
        mul_mat_vec_q4_1_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
      case 6:
        mul_mat_vec_q5_0_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
      case 7:
        mul_mat_vec_q5_1_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
      case 8:
        mul_mat_vec_q8_0_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
      case 10:
        mul_mat_vec_q2_K_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
      case 11:
        mul_mat_vec_q3_K_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
      case 12:
        mul_mat_vec_q4_K_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
      case 13:
        mul_mat_vec_q5_K_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
      case 14:
        mul_mat_vec_q6_K_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
      case 16:
        mul_mat_vec_iq2_xxs_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
      case 17:
        mul_mat_vec_iq2_xs_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
      case 18:
        mul_mat_vec_iq3_xxs_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
      case 19:
        mul_mat_vec_iq1_s_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
      case 20:
        mul_mat_vec_iq4_nl_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
      case 21:
        mul_mat_vec_iq3_s_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
      case 22:
        mul_mat_vec_iq2_s_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
      case 23:
        mul_mat_vec_iq4_xs_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
      case 29:
        mul_mat_vec_iq1_m_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
        break;
    }
  });
  return Y;
}
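The bare integer case labels in this and the following switches follow ggml's quantization type ids. A readability sketch of the relevant subset (the GGML_TYPE_* names mirror ggml.h as of llama.cpp b2899; the entry points themselves take the raw int64_t):

    enum GGMLType : int64_t {
      GGML_TYPE_Q4_0 = 2,     GGML_TYPE_Q4_1 = 3,    GGML_TYPE_Q5_0 = 6,     GGML_TYPE_Q5_1 = 7,
      GGML_TYPE_Q8_0 = 8,     GGML_TYPE_Q2_K = 10,   GGML_TYPE_Q3_K = 11,    GGML_TYPE_Q4_K = 12,
      GGML_TYPE_Q5_K = 13,    GGML_TYPE_Q6_K = 14,   GGML_TYPE_IQ2_XXS = 16, GGML_TYPE_IQ2_XS = 17,
      GGML_TYPE_IQ3_XXS = 18, GGML_TYPE_IQ1_S = 19,  GGML_TYPE_IQ4_NL = 20,  GGML_TYPE_IQ3_S = 21,
      GGML_TYPE_IQ2_S = 22,   GGML_TYPE_IQ4_XS = 23, GGML_TYPE_IQ1_M = 29,
    };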
torch::Tensor ggml_mul_mat_a8(
    torch::Tensor W,  // quant weight
    torch::Tensor X,  // input
    int64_t type,
    int64_t row) {
  int col = X.sizes()[1];
  int padded = (col + 512 - 1) / 512 * 512;
  int batch = X.sizes()[0];
  const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
  auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
  at::Tensor Y = torch::empty({batch, row}, options);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
  at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options);
  DISPATCH_FLOAT_TYPES(X.scalar_type(), "ggml_mul_mat_a8", [&] {
    quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, batch, stream);
    switch (type) {
      case 2:
        ggml_mul_mat_q4_0_q8_1_cuda(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, batch, padded, row,
            stream);
        break;
      case 3:
        ggml_mul_mat_q4_1_q8_1_cuda(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, batch, padded, row,
            stream);
        break;
      case 6:
        ggml_mul_mat_q5_0_q8_1_cuda(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, batch, padded, row,
            stream);
        break;
      case 7:
        ggml_mul_mat_q5_1_q8_1_cuda(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, batch, padded, row,
            stream);
        break;
      case 8:
        ggml_mul_mat_q8_0_q8_1_cuda(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, batch, padded, row,
            stream);
        break;
      case 10:
        ggml_mul_mat_q2_K_q8_1_cuda(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, batch, padded, row,
            stream);
        break;
      case 11:
        ggml_mul_mat_q3_K_q8_1_cuda(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, batch, padded, row,
            stream);
        break;
      case 12:
        ggml_mul_mat_q4_K_q8_1_cuda(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, batch, padded, row,
            stream);
        break;
      case 13:
        ggml_mul_mat_q5_K_q8_1_cuda(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, batch, padded, row,
            stream);
        break;
      case 14:
        ggml_mul_mat_q6_K_q8_1_cuda(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, batch, padded, row,
            stream);
        break;
    }
  });
  return Y;
}
torch::Tensor ggml_moe_a8(
    torch::Tensor X,  // input
    torch::Tensor W,  // expert weights
    torch::Tensor sorted_token_ids,
    torch::Tensor expert_ids,
    torch::Tensor num_tokens_post_padded,
    int64_t type,
    int64_t row,
    int64_t top_k,
    int64_t tokens) {
  int col = X.sizes()[1];
  int padded = (col + 512 - 1) / 512 * 512;
  const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
  auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
  at::Tensor Y = torch::empty({tokens * top_k, row}, options);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
  at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options);
  DISPATCH_FLOAT_TYPES(X.scalar_type(), "ggml_moe_a8", [&] {
    quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, tokens, stream);
    switch (type) {
      case 2:
        ggml_moe_q4_0_q8_1_cuda(
            (void*)quant_X.data_ptr(), (void*)W.data_ptr(), (scalar_t*)Y.data_ptr(),
            (int*)sorted_token_ids.data_ptr(), (int*)expert_ids.data_ptr(),
            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, tokens, padded, row, top_k,
            sorted_token_ids.sizes()[0], stream);
        break;
      case 3:
        ggml_moe_q4_1_q8_1_cuda(
            (void*)quant_X.data_ptr(), (void*)W.data_ptr(), (scalar_t*)Y.data_ptr(),
            (int*)sorted_token_ids.data_ptr(), (int*)expert_ids.data_ptr(),
            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, tokens, padded, row, top_k,
            sorted_token_ids.sizes()[0], stream);
        break;
      case 6:
        ggml_moe_q5_0_q8_1_cuda(
            (void*)quant_X.data_ptr(), (void*)W.data_ptr(), (scalar_t*)Y.data_ptr(),
            (int*)sorted_token_ids.data_ptr(), (int*)expert_ids.data_ptr(),
            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, tokens, padded, row, top_k,
            sorted_token_ids.sizes()[0], stream);
        break;
      case 7:
        ggml_moe_q5_1_q8_1_cuda(
            (void*)quant_X.data_ptr(), (void*)W.data_ptr(), (scalar_t*)Y.data_ptr(),
            (int*)sorted_token_ids.data_ptr(), (int*)expert_ids.data_ptr(),
            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, tokens, padded, row, top_k,
            sorted_token_ids.sizes()[0], stream);
        break;
      case 8:
        ggml_moe_q8_0_q8_1_cuda(
            (void*)quant_X.data_ptr(), (void*)W.data_ptr(), (scalar_t*)Y.data_ptr(),
            (int*)sorted_token_ids.data_ptr(), (int*)expert_ids.data_ptr(),
            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, tokens, padded, row, top_k,
            sorted_token_ids.sizes()[0], stream);
        break;
      case 10:
        ggml_moe_q2_K_q8_1_cuda(
            (void*)quant_X.data_ptr(), (void*)W.data_ptr(), (scalar_t*)Y.data_ptr(),
            (int*)sorted_token_ids.data_ptr(), (int*)expert_ids.data_ptr(),
            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, tokens, padded, row, top_k,
            sorted_token_ids.sizes()[0], stream);
        break;
      case 11:
        ggml_moe_q3_K_q8_1_cuda(
            (void*)quant_X.data_ptr(), (void*)W.data_ptr(), (scalar_t*)Y.data_ptr(),
            (int*)sorted_token_ids.data_ptr(), (int*)expert_ids.data_ptr(),
            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, tokens, padded, row, top_k,
            sorted_token_ids.sizes()[0], stream);
        break;
      case 12:
        ggml_moe_q4_K_q8_1_cuda(
            (void*)quant_X.data_ptr(), (void*)W.data_ptr(), (scalar_t*)Y.data_ptr(),
            (int*)sorted_token_ids.data_ptr(), (int*)expert_ids.data_ptr(),
            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, tokens, padded, row, top_k,
            sorted_token_ids.sizes()[0], stream);
        break;
      case 13:
        ggml_moe_q5_K_q8_1_cuda(
            (void*)quant_X.data_ptr(), (void*)W.data_ptr(), (scalar_t*)Y.data_ptr(),
            (int*)sorted_token_ids.data_ptr(), (int*)expert_ids.data_ptr(),
            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, tokens, padded, row, top_k,
            sorted_token_ids.sizes()[0], stream);
        break;
      case 14:
        ggml_moe_q6_K_q8_1_cuda(
            (void*)quant_X.data_ptr(), (void*)W.data_ptr(), (scalar_t*)Y.data_ptr(),
            (int*)sorted_token_ids.data_ptr(), (int*)expert_ids.data_ptr(),
            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row, tokens, padded, row, top_k,
            sorted_token_ids.sizes()[0], stream);
        break;
    }
  });
  return Y;
}
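For context, this entry point consumes the usual moe-align routing metadata: token slots grouped by expert and padded to the MOE tile height, one expert id per tile, and pad slots pointing past the valid range. A compact CPU sketch of that contract (an illustration under those assumptions, not the actual alignment kernel; all names are hypothetical):

    #include <vector>

    void moe_align_ref(
        const std::vector<int>& topk_ids,  // flattened [tokens * top_k] expert assignments
        int num_experts, int block_m,
        std::vector<int>& sorted_token_ids, std::vector<int>& expert_ids, int& num_tokens_post_padded) {
      const int pad_id = (int)topk_ids.size();  // sentinel for padded, out-of-range slots
      std::vector<std::vector<int>> buckets(num_experts);
      for (int i = 0; i < (int)topk_ids.size(); ++i) buckets[topk_ids[i]].push_back(i);
      sorted_token_ids.clear();
      expert_ids.clear();
      for (int e = 0; e < num_experts; ++e) {
        const int padded = ((int)buckets[e].size() + block_m - 1) / block_m * block_m;
        for (int j = 0; j < padded; ++j) {
          if (j % block_m == 0) expert_ids.push_back(e);  // one expert per block_m-row tile
          sorted_token_ids.push_back(j < (int)buckets[e].size() ? buckets[e][j] : pad_id);
        }
      }
      num_tokens_post_padded = (int)sorted_token_ids.size();
    }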
torch::Tensor ggml_moe_a8_vec(
    torch::Tensor X,  // input
    torch::Tensor W,  // expert weights
    torch::Tensor topk_ids,
    int64_t top_k,
    int64_t type,
    int64_t row,
    int64_t tokens) {
  int col = X.sizes()[1];
  const int padded = (col + 512 - 1) / 512 * 512;
  const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
  auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
  at::Tensor Y = torch::zeros({tokens * top_k, row}, options);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
  at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options);
  DISPATCH_FLOAT_TYPES(X.scalar_type(), "ggml_moe_vec_a8", [&] {
    quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, tokens, stream);
    switch (type) {
      case 2:
        moe_vec_q4_0_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
      case 3:
        moe_vec_q4_1_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
      case 6:
        moe_vec_q5_0_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
      case 7:
        moe_vec_q5_1_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
      case 8:
        moe_vec_q8_0_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
      case 10:
        moe_vec_q2_K_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
      case 11:
        moe_vec_q3_K_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
      case 12:
        moe_vec_q4_K_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
      case 13:
        moe_vec_q5_K_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
      case 14:
        moe_vec_q6_K_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
      case 16:
        moe_vec_iq2_xxs_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
      case 17:
        moe_vec_iq2_xs_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
      case 18:
        moe_vec_iq3_xxs_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
      case 19:
        moe_vec_iq1_s_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
      case 20:
        moe_vec_iq4_nl_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
      case 21:
        moe_vec_iq3_s_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
      case 22:
        moe_vec_iq2_s_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
      case 23:
        moe_vec_iq4_xs_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
      case 29:
        moe_vec_iq1_m_q8_1_cuda<scalar_t>(
            (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(),
            top_k, tokens, col, row, quant_X.stride(0), stream);
        break;
    }
  });
  return Y;
}
int64_t ggml_moe_get_block_size(int64_t type) {
  switch (type) {
    case 2:
      return MOE_X_Q4_0;
    case 3:
      return MOE_X_Q4_1;
    case 6:
      return MOE_X_Q5_0;
    case 7:
      return MOE_X_Q5_1;
    case 8:
      return MOE_X_Q8_0;
    case 10:
      return MOE_X_Q2_K;
    case 11:
      return MOE_X_Q3_K;
    case 12:
      return MOE_X_Q4_K;
    case 13:
      return MOE_X_Q5_K;
    case 14:
      return MOE_X_Q6_K;
  }
  return 0;
}
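This accessor exists so callers can pad the routing buffers to the MOE tile height the gemm consumes before invoking ggml_moe_a8; a caller-side sketch (helper names hypothetical, mirroring the moe_align_ref sketch above):

    // int64_t block_m = ggml_moe_get_block_size(/*type=*/12);  // Q4_K -> MOE_X_Q4_K
    // moe_align_ref(topk_ids, num_experts, (int)block_m, sorted_token_ids, expert_ids, num_tokens_post_padded);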
sgl-kernel/csrc/quantization/gguf/mmq.cuh
0 → 100644
// copied from
// https://github.com/vllm-project/vllm/blob/4492e3a55428e161ca8db381edc28263e5da4c8d/csrc/quantization/gguf/mmq.cuh
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu
template <
    typename scalar_t,
    int qk,
    int qr,
    int qi,
    bool need_sum,
    typename block_q_t,
    int mmq_x,
    int mmq_y,
    int nwarps,
    allocate_tiles_cuda_t allocate_tiles,
    load_tiles_cuda_t load_tiles,
    int vdr,
    vec_dot_q_mul_mat_cuda_t vec_dot>
static __device__ __forceinline__ void mul_mat_q(
    const void* __restrict__ vx,
    const void* __restrict__ vy,
    scalar_t* __restrict__ dst,
    const int ncols_x,
    const int nrows_x,
    const int ncols_y,
    const int nrows_y,
    const int nrows_dst) {
  const block_q_t* x = (const block_q_t*)vx;
  const block_q8_1* y = (const block_q8_1*)vy;

  const int blocks_per_row_x = ncols_x / qk;
  const int blocks_per_col_y = nrows_y / QK8_1;
  const int blocks_per_warp = WARP_SIZE_GGUF / qi;

  const int& ncols_dst = ncols_y;

  const auto row_dst_0 = blockIdx.x * mmq_y;
  const int& row_x_0 = row_dst_0;

  const auto col_dst_0 = blockIdx.y * mmq_x;
  const int& col_y_0 = col_dst_0;

  int* tile_x_ql = nullptr;
  half2* tile_x_dm = nullptr;
  int* tile_x_qh = nullptr;
  int* tile_x_sc = nullptr;

  allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);

  __shared__ int tile_y_qs[mmq_x * WARP_SIZE_GGUF];
  __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE_GGUF / QI8_1];

  float sum[mmq_y / WARP_SIZE_GGUF][mmq_x / nwarps] = {{0.0f}};

  for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
    load_tiles(
        x + row_x_0 * blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, threadIdx.y,
        nrows_x - row_x_0 - 1, threadIdx.x, blocks_per_row_x);

#pragma unroll
    for (int ir = 0; ir < qr && ib0 + ir * blocks_per_warp / qr < blocks_per_row_x; ++ir) {
      const auto kqs = ir * WARP_SIZE_GGUF + threadIdx.x;
      const int kbxd = kqs / QI8_1;

#pragma unroll
      for (int i = 0; i < mmq_x; i += nwarps) {
        const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y - 1);  // to prevent out-of-bounds memory accesses
        const block_q8_1* by0 = &y[col_y_eff * blocks_per_col_y + ib0 * (qk / QK8_1) + kbxd];
        const int index_y = (threadIdx.y + i) * WARP_SIZE_GGUF + kqs % WARP_SIZE_GGUF;
        tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
      }

#pragma unroll
      for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
        const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE_GGUF / QI8_1)) % mmq_x;
        const auto kby = threadIdx.x % (WARP_SIZE_GGUF / QI8_1);
        const int col_y_eff = min(col_y_0 + ids, ncols_y - 1);

        // if the sum is not needed it's faster to transform the scale to f32 ahead of time
        const half2* dsi_src =
            &y[col_y_eff * blocks_per_col_y + ib0 * (qk / QK8_1) + ir * (WARP_SIZE_GGUF / QI8_1) + kby].ds;
        half2* dsi_dst = &tile_y_ds[ids * (WARP_SIZE_GGUF / QI8_1) + kby];
        if (need_sum) {
          *dsi_dst = *dsi_src;
        } else {
          float* dfi_dst = (float*)dsi_dst;
          *dfi_dst = __low2float(*dsi_src);
        }
      }

      __syncthreads();

      // #pragma unroll // unrolling this loop causes too much register pressure
      for (int k = ir * WARP_SIZE_GGUF / qr; k < (ir + 1) * WARP_SIZE_GGUF / qr; k += vdr) {
#pragma unroll
        for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
          for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
            sum[i / WARP_SIZE_GGUF][j / nwarps] += vec_dot(
                tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, threadIdx.x + i, threadIdx.y + j,
                k);
          }
        }
      }

      __syncthreads();
    }
  }

#pragma unroll
  for (int j = 0; j < mmq_x; j += nwarps) {
    const auto col_dst = col_dst_0 + j + threadIdx.y;
    if (col_dst >= ncols_dst) {
      return;
    }

#pragma unroll
    for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
      const auto row_dst = row_dst_0 + threadIdx.x + i;
      if (row_dst >= nrows_dst) {
        continue;
      }
      dst[col_dst * nrows_dst + row_dst] = sum[i / WARP_SIZE_GGUF][j / nwarps];
    }
  }
}
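To make the tiling concrete: each thread block of WARP_SIZE_GGUF x nwarps threads computes an mmq_y x mmq_x output tile, so the launchers below pick a grid of ceil(nrows_x / mmq_y) x ceil(ncols_y / mmq_x) blocks. A worked instance of that launch math on the CUDA path, where mmq_x = 4 and mmq_y = 32 (ceil_div is a hypothetical helper, and the 4096 x 7 problem shape is invented for illustration):

    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }
    // 4096 weight rows, 7 quantized input columns -> grid of 128 x 2 blocks
    static_assert(ceil_div(4096, 32) == 128 && ceil_div(7, 4) == 2, "grid = 128 x 2");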
#if defined(USE_ROCM)
#define MMQ_X_Q4_0 64
#define MMQ_Y_Q4_0 128
#define NWARPS_Q4_0 8
#else
#define MMQ_X_Q4_0 4
#define MMQ_Y_Q4_0 32
#define NWARPS_Q4_0 4
#endif

template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q4_0, 2)
#endif
    mul_mat_q4_0(
        const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst, const int ncols_x,
        const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
  const int mmq_x = MMQ_X_Q4_0;
  const int mmq_y = MMQ_Y_Q4_0;
  const int nwarps = NWARPS_Q4_0;

  mul_mat_q<
      scalar_t, QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
      load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>(
      vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

template <typename scalar_t>
static void ggml_mul_mat_q4_0_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols_x, const int nrows_x, const int ncols_y,
    const int nrows_y, const int nrows_dst, cudaStream_t stream) {
  int mmq_x = MMQ_X_Q4_0;
  int mmq_y = MMQ_Y_Q4_0;
  int nwarps = NWARPS_Q4_0;

  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

  if (nrows_x % mmq_y == 0) {
    const bool need_check = false;
    mul_mat_q4_0<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  } else {
    const bool need_check = true;
    mul_mat_q4_0<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  }
}
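One design note on the dispatch just above, which repeats for every type below: need_check is a compile-time template parameter rather than a runtime flag, so when nrows_x divides evenly by the tile height the bounds-checking variant of load_tiles is never instantiated and the common aligned case compiles without per-tile clamping.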
#if defined(USE_ROCM)
#define MMQ_X_Q4_1 64
#define MMQ_Y_Q4_1 128
#define NWARPS_Q4_1 8
#else
#define MMQ_X_Q4_1 4
#define MMQ_Y_Q4_1 32
#define NWARPS_Q4_1 4
#endif

template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q4_1, 2)
#endif
    mul_mat_q4_1(
        const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst, const int ncols_x,
        const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
  const int mmq_x = MMQ_X_Q4_1;
  const int mmq_y = MMQ_Y_Q4_1;
  const int nwarps = NWARPS_Q4_1;

  mul_mat_q<
      scalar_t, QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
      load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>(
      vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

template <typename scalar_t>
static void ggml_mul_mat_q4_1_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols_x, const int nrows_x, const int ncols_y,
    const int nrows_y, const int nrows_dst, cudaStream_t stream) {
  int mmq_x = MMQ_X_Q4_1;
  int mmq_y = MMQ_Y_Q4_1;
  int nwarps = NWARPS_Q4_1;

  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

  if (nrows_x % mmq_y == 0) {
    const bool need_check = false;
    mul_mat_q4_1<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  } else {
    const bool need_check = true;
    mul_mat_q4_1<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  }
}
#if defined(USE_ROCM)
#define MMQ_X_Q5_0 64
#define MMQ_Y_Q5_0 128
#define NWARPS_Q5_0 8
#else
#define MMQ_X_Q5_0 4
#define MMQ_Y_Q5_0 32
#define NWARPS_Q5_0 4
#endif

template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q5_0, 2)
#endif
    mul_mat_q5_0(
        const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst, const int ncols_x,
        const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
  const int mmq_x = MMQ_X_Q5_0;
  const int mmq_y = MMQ_Y_Q5_0;
  const int nwarps = NWARPS_Q5_0;

  mul_mat_q<
      scalar_t, QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
      load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>(
      vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

template <typename scalar_t>
static void ggml_mul_mat_q5_0_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols_x, const int nrows_x, const int ncols_y,
    const int nrows_y, const int nrows_dst, cudaStream_t stream) {
  const int mmq_x = MMQ_X_Q5_0;
  const int mmq_y = MMQ_Y_Q5_0;
  const int nwarps = NWARPS_Q5_0;

  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

  if (nrows_x % mmq_y == 0) {
    const bool need_check = false;
    mul_mat_q5_0<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  } else {
    const bool need_check = true;
    mul_mat_q5_0<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  }
}
#if defined(USE_ROCM)
#define MMQ_X_Q5_1 64
#define MMQ_Y_Q5_1 128
#define NWARPS_Q5_1 8
#else
#define MMQ_X_Q5_1 4
#define MMQ_Y_Q5_1 32
#define NWARPS_Q5_1 4
#endif

template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q5_1, 2)
#endif
    mul_mat_q5_1(
        const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst, const int ncols_x,
        const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
  const int mmq_x = MMQ_X_Q5_1;
  const int mmq_y = MMQ_Y_Q5_1;
  const int nwarps = NWARPS_Q5_1;

  mul_mat_q<
      scalar_t, QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
      load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>(
      vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

template <typename scalar_t>
static void ggml_mul_mat_q5_1_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols_x, const int nrows_x, const int ncols_y,
    const int nrows_y, const int nrows_dst, cudaStream_t stream) {
  const int mmq_x = MMQ_X_Q5_1;
  const int mmq_y = MMQ_Y_Q5_1;
  const int nwarps = NWARPS_Q5_1;

  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

  if (nrows_x % mmq_y == 0) {
    const bool need_check = false;
    mul_mat_q5_1<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  } else {
    const bool need_check = true;
    mul_mat_q5_1<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  }
}
#if defined(USE_ROCM)
#define MMQ_X_Q8_0 64
#define MMQ_Y_Q8_0 128
#define NWARPS_Q8_0 8
#else
#define MMQ_X_Q8_0 4
#define MMQ_Y_Q8_0 32
#define NWARPS_Q8_0 4
#endif

template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q8_0, 2)
#endif
    mul_mat_q8_0(
        const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst, const int ncols_x,
        const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
  const int mmq_x = MMQ_X_Q8_0;
  const int mmq_y = MMQ_Y_Q8_0;
  const int nwarps = NWARPS_Q8_0;

  mul_mat_q<
      scalar_t, QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
      load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>(
      vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

template <typename scalar_t>
static void ggml_mul_mat_q8_0_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols_x, const int nrows_x, const int ncols_y,
    const int nrows_y, const int nrows_dst, cudaStream_t stream) {
  const int mmq_x = MMQ_X_Q8_0;
  const int mmq_y = MMQ_Y_Q8_0;
  const int nwarps = NWARPS_Q8_0;

  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

  if (nrows_x % mmq_y == 0) {
    const bool need_check = false;
    mul_mat_q8_0<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  } else {
    const bool need_check = true;
    mul_mat_q8_0<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  }
}
#if defined(USE_ROCM)
#define MMQ_X_Q2_K 64
#define MMQ_Y_Q2_K 128
#define NWARPS_Q2_K 8
#else
#define MMQ_X_Q2_K 4
#define MMQ_Y_Q2_K 32
#define NWARPS_Q2_K 4
#endif

template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q2_K, 2)
#endif
    mul_mat_q2_K(
        const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst, const int ncols_x,
        const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
  const int mmq_x = MMQ_X_Q2_K;
  const int mmq_y = MMQ_Y_Q2_K;
  const int nwarps = NWARPS_Q2_K;

  mul_mat_q<
      scalar_t, QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
      load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>(
      vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

template <typename scalar_t>
static void ggml_mul_mat_q2_K_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols_x, const int nrows_x, const int ncols_y,
    const int nrows_y, const int nrows_dst, cudaStream_t stream) {
  const int mmq_x = MMQ_X_Q2_K;
  const int mmq_y = MMQ_Y_Q2_K;
  const int nwarps = NWARPS_Q2_K;

  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

  if (nrows_x % mmq_y == 0) {
    const bool need_check = false;
    mul_mat_q2_K<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  } else {
    const bool need_check = true;
    mul_mat_q2_K<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  }
}
#if defined(USE_ROCM)
#define MMQ_X_Q3_K 64
#define MMQ_Y_Q3_K 128
#define NWARPS_Q3_K 8
#else
#define MMQ_X_Q3_K 4
#define MMQ_Y_Q3_K 32
#define NWARPS_Q3_K 4
#endif

template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q3_K, 2)
#endif
    mul_mat_q3_K(
        const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst, const int ncols_x,
        const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
  const int mmq_x = MMQ_X_Q3_K;
  const int mmq_y = MMQ_Y_Q3_K;
  const int nwarps = NWARPS_Q3_K;

  mul_mat_q<
      scalar_t, QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
      load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>(
      vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

template <typename scalar_t>
static void ggml_mul_mat_q3_K_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols_x, const int nrows_x, const int ncols_y,
    const int nrows_y, const int nrows_dst, cudaStream_t stream) {
  const int mmq_x = MMQ_X_Q3_K;
  const int mmq_y = MMQ_Y_Q3_K;
  const int nwarps = NWARPS_Q3_K;

  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

  if (nrows_x % mmq_y == 0) {
    const bool need_check = false;
    mul_mat_q3_K<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  } else {
    const bool need_check = true;
    mul_mat_q3_K<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  }
}
#if defined(USE_ROCM)
#define MMQ_X_Q4_K 64
#define MMQ_Y_Q4_K 128
#define NWARPS_Q4_K 8
#else
#define MMQ_X_Q4_K 4
#define MMQ_Y_Q4_K 32
#define NWARPS_Q4_K 4
#endif

template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q4_K, 2)
#endif
    mul_mat_q4_K(
        const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst, const int ncols_x,
        const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
  const int mmq_x = MMQ_X_Q4_K;
  const int mmq_y = MMQ_Y_Q4_K;
  const int nwarps = NWARPS_Q4_K;

  mul_mat_q<
      scalar_t, QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
      load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>(
      vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

template <typename scalar_t>
static void ggml_mul_mat_q4_K_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols_x, const int nrows_x, const int ncols_y,
    const int nrows_y, const int nrows_dst, cudaStream_t stream) {
  const int mmq_x = MMQ_X_Q4_K;
  const int mmq_y = MMQ_Y_Q4_K;
  const int nwarps = NWARPS_Q4_K;

  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

  if (nrows_x % mmq_y == 0) {
    const bool need_check = false;
    mul_mat_q4_K<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  } else {
    const bool need_check = true;
    mul_mat_q4_K<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  }
}
#if defined(USE_ROCM)
#define MMQ_X_Q5_K 64
#define MMQ_Y_Q5_K 128
#define NWARPS_Q5_K 8
#else
#define MMQ_X_Q5_K 4
#define MMQ_Y_Q5_K 32
#define NWARPS_Q5_K 4
#endif

template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q5_K, 2)
#endif
    mul_mat_q5_K(
        const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst, const int ncols_x,
        const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
  const int mmq_x = MMQ_X_Q5_K;
  const int mmq_y = MMQ_Y_Q5_K;
  const int nwarps = NWARPS_Q5_K;

  mul_mat_q<
      scalar_t, QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
      load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>(
      vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

template <typename scalar_t>
static void ggml_mul_mat_q5_K_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols_x, const int nrows_x, const int ncols_y,
    const int nrows_y, const int nrows_dst, cudaStream_t stream) {
  const int mmq_x = MMQ_X_Q5_K;
  const int mmq_y = MMQ_Y_Q5_K;
  const int nwarps = NWARPS_Q5_K;

  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

  if (nrows_x % mmq_y == 0) {
    const bool need_check = false;
    mul_mat_q5_K<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  } else {
    const bool need_check = true;
    mul_mat_q5_K<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  }
}
#if defined(USE_ROCM)
#define MMQ_X_Q6_K 64
#define MMQ_Y_Q6_K 128
#define NWARPS_Q6_K 8
#else
#define MMQ_X_Q6_K 4
#define MMQ_Y_Q6_K 32
#define NWARPS_Q6_K 4
#endif

template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q6_K, 2)
#endif
    mul_mat_q6_K(
        const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst, const int ncols_x,
        const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
  const int mmq_x = MMQ_X_Q6_K;
  const int mmq_y = MMQ_Y_Q6_K;
  const int nwarps = NWARPS_Q6_K;

  mul_mat_q<
      scalar_t, QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
      load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>(
      vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

template <typename scalar_t>
static void ggml_mul_mat_q6_K_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols_x, const int nrows_x, const int ncols_y,
    const int nrows_y, const int nrows_dst, cudaStream_t stream) {
  const int mmq_x = MMQ_X_Q6_K;
  const int mmq_y = MMQ_Y_Q6_K;
  const int nwarps = NWARPS_Q6_K;

  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

  if (nrows_x % mmq_y == 0) {
    const bool need_check = false;
    mul_mat_q6_K<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  } else {
    const bool need_check = true;
    mul_mat_q6_K<scalar_t, need_check>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
  }
}
sgl-kernel/csrc/quantization/gguf/mmvq.cuh
0 → 100644
// copied from
// https://github.com/vllm-project/vllm/blob/4492e3a55428e161ca8db381edc28263e5da4c8d/csrc/quantization/gguf/mmvq.cuh
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void mul_mat_vec_q(
    const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst, const int ncols,
    const int nrows, const int nvecs) {
  const auto row = blockIdx.x * blockDim.y + threadIdx.y;
  const auto vec = blockIdx.y;
  if (row >= nrows || vec >= nvecs) {
    return;
  }

  const int blocks_per_row = ncols / qk;
  const int blocks_per_warp = vdr * WARP_SIZE / qi;
  const int nrows_y = (ncols + 512 - 1) / 512 * 512;

  // partial sum for each thread
  float tmp = 0.0f;

  const block_q_t* x = (const block_q_t*)vx;
  const block_q8_1* y = (const block_q8_1*)vy;

  for (auto i = threadIdx.x / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) {
    const int ibx = row * blocks_per_row + i;  // x block index
    const int iby = vec * (nrows_y / QK8_1) + i * (qk / QK8_1);  // y block index that aligns with ibx
    const int iqs = vdr * (threadIdx.x % (qi / vdr));  // x block quant index when casting the quants to int

    tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
  }

  // sum up partial sums and write back result
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
    tmp += SGLANG_SHFL_XOR_SYNC(uint32_t(-1), tmp, mask);
  }

  if (threadIdx.x == 0) {
    dst[vec * nrows + row] = tmp;
  }
}
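The reduction at the end is a standard butterfly sum over one warp. For reference, a plain-CUDA equivalent of what the SGLANG_SHFL_XOR_SYNC loop computes, assuming 32-wide warps and a full active mask (warp_sum is a hypothetical helper, not part of the diff):

    __device__ float warp_sum(float v) {
      for (int mask = 16; mask > 0; mask >>= 1) {
        v += __shfl_xor_sync(0xffffffffu, v, mask);  // exchange with lane ^ mask, halving each step
      }
      return v;  // every lane ends up holding the full 32-lane sum
    }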
template <typename scalar_t>
static void mul_mat_vec_q4_0_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols, const int nrows, const int nvecs,
    cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, nvecs, 1);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  mul_mat_vec_q<scalar_t, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}

template <typename scalar_t>
static void mul_mat_vec_q4_1_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols, const int nrows, const int nvecs,
    cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, nvecs, 1);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  mul_mat_vec_q<scalar_t, QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}

template <typename scalar_t>
static void mul_mat_vec_q5_0_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols, const int nrows, const int nvecs,
    cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, nvecs, 1);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  mul_mat_vec_q<scalar_t, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}

template <typename scalar_t>
static void mul_mat_vec_q5_1_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols, const int nrows, const int nvecs,
    cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, nvecs, 1);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  mul_mat_vec_q<scalar_t, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}

template <typename scalar_t>
static void mul_mat_vec_q8_0_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols, const int nrows, const int nvecs,
    cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, nvecs, 1);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  mul_mat_vec_q<scalar_t, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}

template <typename scalar_t>
static void mul_mat_vec_q2_K_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols, const int nrows, const int nvecs,
    cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, nvecs, 1);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  mul_mat_vec_q<scalar_t, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}

template <typename scalar_t>
static void mul_mat_vec_q3_K_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols, const int nrows, const int nvecs,
    cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, nvecs, 1);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  mul_mat_vec_q<scalar_t, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}

template <typename scalar_t>
static void mul_mat_vec_q4_K_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols, const int nrows, const int nvecs,
    cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, nvecs, 1);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  mul_mat_vec_q<scalar_t, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}

template <typename scalar_t>
static void mul_mat_vec_q5_K_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols, const int nrows, const int nvecs,
    cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, nvecs, 1);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  mul_mat_vec_q<scalar_t, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}

template <typename scalar_t>
static void mul_mat_vec_q6_K_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols, const int nrows, const int nvecs,
    cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, nvecs, 1);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  mul_mat_vec_q<scalar_t, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}

template <typename scalar_t>
static void mul_mat_vec_iq2_xxs_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols, const int nrows, const int nvecs,
    cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, nvecs, 1);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  mul_mat_vec_q<scalar_t, QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}

template <typename scalar_t>
static void mul_mat_vec_iq2_xs_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols, const int nrows, const int nvecs,
    cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, nvecs, 1);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  mul_mat_vec_q<scalar_t, QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}

template <typename scalar_t>
static void mul_mat_vec_iq2_s_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols, const int nrows, const int nvecs,
    cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, nvecs, 1);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  mul_mat_vec_q<scalar_t, QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}

template <typename scalar_t>
static void mul_mat_vec_iq3_xxs_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols, const int nrows, const int nvecs,
    cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, nvecs, 1);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  mul_mat_vec_q<scalar_t, QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}

template <typename scalar_t>
static void mul_mat_vec_iq1_s_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols, const int nrows, const int nvecs,
    cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, nvecs, 1);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  mul_mat_vec_q<scalar_t, QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}

template <typename scalar_t>
static void mul_mat_vec_iq1_m_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols, const int nrows, const int nvecs,
    cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, nvecs, 1);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  mul_mat_vec_q<scalar_t, QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}

template <typename scalar_t>
static void mul_mat_vec_iq4_nl_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols, const int nrows, const int nvecs,
    cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, nvecs, 1);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  mul_mat_vec_q<scalar_t, QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}

template <typename scalar_t>
static void mul_mat_vec_iq4_xs_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int ncols, const int nrows, const int nvecs,
    cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, nvecs, 1);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  mul_mat_vec_q<scalar_t, QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
);
}
template
<
typename
scalar_t
>
static
void
mul_mat_vec_iq3_s_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
scalar_t
*
dst
,
const
int
ncols
,
const
int
nrows
,
const
int
nvecs
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
nvecs
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
scalar_t
,
QK_K
,
QI3_XS
,
block_iq3_s
,
1
,
vec_dot_iq3_s_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
,
nvecs
);
}
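All of the mul_mat_vec_* launchers above share one launch geometry: grid.x covers the weight rows in chunks of GGML_CUDA_MMV_Y, grid.y covers the nvecs input vectors, and each block is one warp wide with GGML_CUDA_MMV_Y rows stacked in block.y. Below is a minimal standalone sketch of the same ceil-division arithmetic (host-side C++; the constants kWarpSize and kMmvY are hypothetical stand-ins for WARP_SIZE and GGML_CUDA_MMV_Y, and this snippet is not part of the commit):

#include <cstdio>

// Stand-ins for the constants the launchers read from the surrounding headers.
constexpr int kWarpSize = 32;  // WARP_SIZE
constexpr int kMmvY = 2;       // GGML_CUDA_MMV_Y: rows handled per thread block

int main() {
  const int nrows = 4097, nvecs = 3;
  // Same ceil-divide the launchers use, so a ragged final row chunk still gets a block.
  const int block_num_y = (nrows + kMmvY - 1) / kMmvY;
  std::printf("grid=(%d,%d,1) block=(%d,%d,1)\n", block_num_y, nvecs, kWarpSize, kMmvY);
  // Prints: grid=(2049,3,1) block=(32,2,1)
  return 0;
}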
sgl-kernel/csrc/quantization/gguf/moe.cuh  0 → 100644
// copied from
// https://github.com/vllm-project/vllm/blob/4492e3a55428e161ca8db381edc28263e5da4c8d/csrc/quantization/gguf/moe.cuh
#include <cstdint>
/* Adapted from ./csrc/quantization/gguf/mmq.cuh
*/
template <
    typename scalar_t, int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y,
    int nwarps, allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr,
    vec_dot_q_mul_mat_cuda_t vec_dot>
static __device__ __forceinline__ void moe_q(
    const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst,
    const int* __restrict__ sorted_token_ids, const int* __restrict__ expert_ids,
    const int* __restrict__ num_tokens_post_padded, const int exp_stride, const int ncols_x,
    const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k) {
  const int blocks_per_row_x = ncols_x / qk;
  const int blocks_per_col_y = nrows_y / QK8_1;
  const int blocks_per_warp = WARP_SIZE_GGUF / qi;
  const int ncols_dst = ncols_y * top_k;

  const auto row_dst_0 = blockIdx.x * mmq_y;
  const int& row_x_0 = row_dst_0;
  const auto col_dst_0 = blockIdx.y * mmq_x;

  int token_offs[mmq_x / nwarps];
  for (int i = 0; i < mmq_x; i += nwarps) {
    token_offs[i / nwarps] = sorted_token_ids[col_dst_0 + threadIdx.y + i];
  }

  const int exp_idx = expert_ids[blockIdx.y];
  if (exp_idx > 255 || exp_idx < 0) return;
  if (blockIdx.y * mmq_x > num_tokens_post_padded[0]) return;

  const block_q_t* x = (const block_q_t*)((char*)vx + exp_idx * exp_stride);
  const block_q8_1* y = (const block_q8_1*)(vy);

  int* tile_x_ql = nullptr;
  half2* tile_x_dm = nullptr;
  int* tile_x_qh = nullptr;
  int* tile_x_sc = nullptr;
  allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);

  __shared__ int tile_y_qs[mmq_x * WARP_SIZE_GGUF];
  __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE_GGUF / QI8_1];

  float sum[mmq_y / WARP_SIZE_GGUF][mmq_x / nwarps] = {{0.0f}};

  for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
    load_tiles(
        x + row_x_0 * blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, threadIdx.y,
        nrows_x - row_x_0 - 1, threadIdx.x, blocks_per_row_x);

    const int n_per_r = ((qk * blocks_per_warp) / qr);
#pragma unroll
    for (int ir = 0; ir < qr && ib0 * qk + ir * n_per_r < ncols_x; ++ir) {
      const auto kqs = ir * WARP_SIZE_GGUF + threadIdx.x;
      const int kbxd = kqs / QI8_1;

#pragma unroll
      for (int i = 0; i < mmq_x; i += nwarps) {
        const int col_y_eff = token_offs[i / nwarps] / top_k;
        const int block_x = ib0 * (qk / QK8_1) + kbxd;
        if (col_y_eff < ncols_y && block_x < blocks_per_col_y) {
          const block_q8_1* by0 = &y[col_y_eff * blocks_per_col_y + block_x];
          const int index_y = (threadIdx.y + i) * WARP_SIZE_GGUF + kqs % WARP_SIZE_GGUF;
          tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
        }
      }

      if (threadIdx.x < n_per_r / QK8_1) {
        const auto kby = threadIdx.x % (WARP_SIZE_GGUF / QI8_1);
        const int col_y_eff = token_offs[threadIdx.y] / top_k;
        const int block_x = ib0 * (qk / QK8_1) + ir * (WARP_SIZE_GGUF / QI8_1) + kby;
        if (col_y_eff < ncols_y && block_x < blocks_per_col_y) {
          const half2* dsi_src = &y[col_y_eff * blocks_per_col_y + block_x].ds;
          half2* dsi_dst = &tile_y_ds[threadIdx.y * (WARP_SIZE_GGUF / QI8_1) + kby];
          if (need_sum) {
            *dsi_dst = *dsi_src;
          } else {
            float* dfi_dst = (float*)dsi_dst;
            *dfi_dst = __low2float(*dsi_src);
          }
        }
      }

      __syncthreads();

      // #pragma unroll // unrolling this loop causes too much register pressure
      for (int k = ir * WARP_SIZE_GGUF / qr; k < (ir + 1) * WARP_SIZE_GGUF / qr; k += vdr) {
#pragma unroll
        for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
          for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
            sum[i / WARP_SIZE_GGUF][j / nwarps] += vec_dot(
                tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, threadIdx.x + i,
                threadIdx.y + j, k);
          }
        }
      }
      __syncthreads();
    }
  }

#pragma unroll
  for (int j = 0; j < mmq_x; j += nwarps) {
    const int col_dst = token_offs[j / nwarps];
    if (col_dst >= ncols_dst) {
      return;
    }
#pragma unroll
    for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
      const auto row_dst = row_dst_0 + threadIdx.x + i;
      if (row_dst >= nrows_dst) {
        continue;
      }
      dst[col_dst * nrows_dst + row_dst] = sum[i / WARP_SIZE_GGUF][j / nwarps];
    }
  }
}
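// Reader note (not part of the original source): a worked example of the index
// mapping above. With top_k = 2, sorted_token_ids holds flattened
// (token, expert-slot) entries, so sorted id 5 names token 5 / 2 = 2, slot
// 5 % 2 = 1. The tile therefore reads activation column col_y_eff = 5 / 2 = 2
// of y but writes output column col_dst = 5, i.e. one output column per
// (token, expert) pair, laid out as dst[col_dst * nrows_dst + row_dst]. The
// early returns appear to skip tiles whose expert id lies outside [0, 255] or
// that fall entirely inside the alignment padding counted by
// num_tokens_post_padded.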
#if defined(USE_ROCM)
#define MOE_X_Q4_0 8
#define MOE_Y_Q4_0 128
#define NWARPS_Q4_0 8
#else
#define MOE_X_Q4_0 4
#define MOE_Y_Q4_0 32
#define NWARPS_Q4_0 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q4_0, 2)
#endif
moe_q4_0(
    const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst,
    const int* sorted_token_ids, const int* expert_ids, const int* num_tokens_post_padded,
    const int exp_stride, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst, const int top_k) {
  const int mmq_x = MOE_X_Q4_0;
  const int mmq_y = MOE_Y_Q4_0;
  const int nwarps = NWARPS_Q4_0;
  moe_q<
      scalar_t, QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
      load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>(
      vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
      ncols_y, nrows_y, nrows_dst, top_k);
}

template <typename scalar_t>
static void ggml_moe_q4_0_q8_1_cuda(
    const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, const int* expert_ids,
    const int* num_tokens_post_padded, const int exp_stride, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
    const int tokens_post_padded, cudaStream_t stream) {
  int mmq_x = MOE_X_Q4_0;
  int mmq_y = MOE_Y_Q4_0;
  int nwarps = NWARPS_Q4_0;
  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (tokens_post_padded) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
  if (nrows_x % mmq_y == 0) {
    constexpr bool need_check = false;
    moe_q4_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  } else {
    constexpr bool need_check = true;
    moe_q4_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  }
}
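// Reader note: the two instantiations above differ only in the compile-time
// need_check flag. When nrows_x is a multiple of mmq_y every tile is full and
// the bounds checks inside load_tiles_q4_0 can be compiled out; otherwise the
// guarded variant handles the ragged last row tile. Every launcher below
// repeats this same dispatch.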
#if defined(USE_ROCM)
#define MOE_X_Q4_1 8
#define MOE_Y_Q4_1 128
#define NWARPS_Q4_1 8
#else
#define MOE_X_Q4_1 4
#define MOE_Y_Q4_1 32
#define NWARPS_Q4_1 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q4_1, 2)
#endif
moe_q4_1(
    const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst,
    const int* sorted_token_ids, const int* expert_ids, const int* num_tokens_post_padded,
    const int exp_stride, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst, const int top_k) {
  const int mmq_x = MOE_X_Q4_1;
  const int mmq_y = MOE_Y_Q4_1;
  const int nwarps = NWARPS_Q4_1;
  moe_q<
      scalar_t, QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
      load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>(
      vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
      ncols_y, nrows_y, nrows_dst, top_k);
}

template <typename scalar_t>
static void ggml_moe_q4_1_q8_1_cuda(
    const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, const int* expert_ids,
    const int* num_tokens_post_padded, const int exp_stride, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
    const int tokens_post_padded, cudaStream_t stream) {
  int mmq_x = MOE_X_Q4_1;
  int mmq_y = MOE_Y_Q4_1;
  int nwarps = NWARPS_Q4_1;
  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (tokens_post_padded) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
  if (nrows_x % mmq_y == 0) {
    constexpr bool need_check = false;
    moe_q4_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  } else {
    constexpr bool need_check = true;
    moe_q4_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  }
}
#if defined(USE_ROCM)
#define MOE_X_Q5_0 8
#define MOE_Y_Q5_0 128
#define NWARPS_Q5_0 8
#else
#define MOE_X_Q5_0 4
#define MOE_Y_Q5_0 32
#define NWARPS_Q5_0 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q5_0, 2)
#endif
moe_q5_0(
    const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst,
    const int* sorted_token_ids, const int* expert_ids, const int* num_tokens_post_padded,
    const int exp_stride, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst, const int top_k) {
  const int mmq_x = MOE_X_Q5_0;
  const int mmq_y = MOE_Y_Q5_0;
  const int nwarps = NWARPS_Q5_0;
  moe_q<
      scalar_t, QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
      load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>(
      vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
      ncols_y, nrows_y, nrows_dst, top_k);
}

template <typename scalar_t>
static void ggml_moe_q5_0_q8_1_cuda(
    const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, const int* expert_ids,
    const int* num_tokens_post_padded, const int exp_stride, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
    const int tokens_post_padded, cudaStream_t stream) {
  const int mmq_x = MOE_X_Q5_0;
  const int mmq_y = MOE_Y_Q5_0;
  const int nwarps = NWARPS_Q5_0;
  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (tokens_post_padded) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
  if (nrows_x % mmq_y == 0) {
    constexpr bool need_check = false;
    moe_q5_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  } else {
    constexpr bool need_check = true;
    moe_q5_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  }
}
#if defined(USE_ROCM)
#define MOE_X_Q5_1 8
#define MOE_Y_Q5_1 128
#define NWARPS_Q5_1 8
#else
#define MOE_X_Q5_1 4
#define MOE_Y_Q5_1 32
#define NWARPS_Q5_1 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q5_1, 2)
#endif
moe_q5_1(
    const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst,
    const int* sorted_token_ids, const int* expert_ids, const int* num_tokens_post_padded,
    const int exp_stride, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst, const int top_k) {
  const int mmq_x = MOE_X_Q5_1;
  const int mmq_y = MOE_Y_Q5_1;
  const int nwarps = NWARPS_Q5_1;
  moe_q<
      scalar_t, QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
      load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>(
      vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
      ncols_y, nrows_y, nrows_dst, top_k);
}

template <typename scalar_t>
static void ggml_moe_q5_1_q8_1_cuda(
    const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, const int* expert_ids,
    const int* num_tokens_post_padded, const int exp_stride, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
    const int tokens_post_padded, cudaStream_t stream) {
  const int mmq_x = MOE_X_Q5_1;
  const int mmq_y = MOE_Y_Q5_1;
  const int nwarps = NWARPS_Q5_1;
  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (tokens_post_padded) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
  if (nrows_x % mmq_y == 0) {
    constexpr bool need_check = false;
    moe_q5_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  } else {
    constexpr bool need_check = true;
    moe_q5_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  }
}
#if defined(USE_ROCM)
#define MOE_X_Q8_0 8
#define MOE_Y_Q8_0 128
#define NWARPS_Q8_0 8
#else
#define MOE_X_Q8_0 4
#define MOE_Y_Q8_0 32
#define NWARPS_Q8_0 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q8_0, 2)
#endif
moe_q8_0(
    const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst,
    const int* sorted_token_ids, const int* expert_ids, const int* num_tokens_post_padded,
    const int exp_stride, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst, const int top_k) {
  const int mmq_x = MOE_X_Q8_0;
  const int mmq_y = MOE_Y_Q8_0;
  const int nwarps = NWARPS_Q8_0;
  moe_q<
      scalar_t, QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
      load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>(
      vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
      ncols_y, nrows_y, nrows_dst, top_k);
}

template <typename scalar_t>
static void ggml_moe_q8_0_q8_1_cuda(
    const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, const int* expert_ids,
    const int* num_tokens_post_padded, const int exp_stride, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
    const int tokens_post_padded, cudaStream_t stream) {
  const int mmq_x = MOE_X_Q8_0;
  const int mmq_y = MOE_Y_Q8_0;
  const int nwarps = NWARPS_Q8_0;
  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (tokens_post_padded) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
  if (nrows_x % mmq_y == 0) {
    constexpr bool need_check = false;
    moe_q8_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  } else {
    constexpr bool need_check = true;
    moe_q8_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  }
}
#if defined(USE_ROCM)
#define MOE_X_Q2_K 8
#define MOE_Y_Q2_K 128
#define NWARPS_Q2_K 8
#else
#define MOE_X_Q2_K 4
#define MOE_Y_Q2_K 32
#define NWARPS_Q2_K 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q2_K, 2)
#endif
moe_q2_K(
    const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst,
    const int* sorted_token_ids, const int* expert_ids, const int* num_tokens_post_padded,
    const int exp_stride, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst, const int top_k) {
  const int mmq_x = MOE_X_Q2_K;
  const int mmq_y = MOE_Y_Q2_K;
  const int nwarps = NWARPS_Q2_K;
  moe_q<
      scalar_t, QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
      load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>(
      vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
      ncols_y, nrows_y, nrows_dst, top_k);
}

template <typename scalar_t>
static void ggml_moe_q2_K_q8_1_cuda(
    const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, const int* expert_ids,
    const int* num_tokens_post_padded, const int exp_stride, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
    const int tokens_post_padded, cudaStream_t stream) {
  const int mmq_x = MOE_X_Q2_K;
  const int mmq_y = MOE_Y_Q2_K;
  const int nwarps = NWARPS_Q2_K;
  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (tokens_post_padded) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
  if (nrows_x % mmq_y == 0) {
    constexpr bool need_check = false;
    moe_q2_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  } else {
    constexpr bool need_check = true;
    moe_q2_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  }
}
#if defined(USE_ROCM)
#define MOE_X_Q3_K 8
#define MOE_Y_Q3_K 128
#define NWARPS_Q3_K 8
#else
#define MOE_X_Q3_K 4
#define MOE_Y_Q3_K 32
#define NWARPS_Q3_K 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q3_K, 2)
#endif
moe_q3_K(
    const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst,
    const int* sorted_token_ids, const int* expert_ids, const int* num_tokens_post_padded,
    const int exp_stride, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst, const int top_k) {
  const int mmq_x = MOE_X_Q3_K;
  const int mmq_y = MOE_Y_Q3_K;
  const int nwarps = NWARPS_Q3_K;
  moe_q<
      scalar_t, QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
      load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>(
      vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
      ncols_y, nrows_y, nrows_dst, top_k);
}

template <typename scalar_t>
static void ggml_moe_q3_K_q8_1_cuda(
    const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, const int* expert_ids,
    const int* num_tokens_post_padded, const int exp_stride, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
    const int tokens_post_padded, cudaStream_t stream) {
  const int mmq_x = MOE_X_Q3_K;
  const int mmq_y = MOE_Y_Q3_K;
  const int nwarps = NWARPS_Q3_K;
  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (tokens_post_padded) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
  if (nrows_x % mmq_y == 0) {
    constexpr bool need_check = false;
    moe_q3_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  } else {
    constexpr bool need_check = true;
    moe_q3_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  }
}
#if defined(USE_ROCM)
#define MOE_X_Q4_K 8
#define MOE_Y_Q4_K 128
#define NWARPS_Q4_K 8
#else
#define MOE_X_Q4_K 4
#define MOE_Y_Q4_K 32
#define NWARPS_Q4_K 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q4_K, 2)
#endif
moe_q4_K(
    const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst,
    const int* sorted_token_ids, const int* expert_ids, const int* num_tokens_post_padded,
    const int exp_stride, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst, const int top_k) {
  const int mmq_x = MOE_X_Q4_K;
  const int mmq_y = MOE_Y_Q4_K;
  const int nwarps = NWARPS_Q4_K;
  moe_q<
      scalar_t, QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
      load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>(
      vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
      ncols_y, nrows_y, nrows_dst, top_k);
}

template <typename scalar_t>
static void ggml_moe_q4_K_q8_1_cuda(
    const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, const int* expert_ids,
    const int* num_tokens_post_padded, const int exp_stride, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
    const int tokens_post_padded, cudaStream_t stream) {
  const int mmq_x = MOE_X_Q4_K;
  const int mmq_y = MOE_Y_Q4_K;
  const int nwarps = NWARPS_Q4_K;
  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (tokens_post_padded) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
  if (nrows_x % mmq_y == 0) {
    constexpr bool need_check = false;
    moe_q4_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  } else {
    constexpr bool need_check = true;
    moe_q4_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  }
}
#if defined(USE_ROCM)
#define MOE_X_Q5_K 8
#define MOE_Y_Q5_K 128
#define NWARPS_Q5_K 8
#else
#define MOE_X_Q5_K 4
#define MOE_Y_Q5_K 32
#define NWARPS_Q5_K 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q5_K, 2)
#endif
moe_q5_K(
    const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst,
    const int* sorted_token_ids, const int* expert_ids, const int* num_tokens_post_padded,
    const int exp_stride, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst, const int top_k) {
  const int mmq_x = MOE_X_Q5_K;
  const int mmq_y = MOE_Y_Q5_K;
  const int nwarps = NWARPS_Q5_K;
  moe_q<
      scalar_t, QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
      load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>(
      vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
      ncols_y, nrows_y, nrows_dst, top_k);
}

template <typename scalar_t>
static void ggml_moe_q5_K_q8_1_cuda(
    const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, const int* expert_ids,
    const int* num_tokens_post_padded, const int exp_stride, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
    const int tokens_post_padded, cudaStream_t stream) {
  const int mmq_x = MOE_X_Q5_K;
  const int mmq_y = MOE_Y_Q5_K;
  const int nwarps = NWARPS_Q5_K;
  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (tokens_post_padded) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
  if (nrows_x % mmq_y == 0) {
    constexpr bool need_check = false;
    moe_q5_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  } else {
    constexpr bool need_check = true;
    moe_q5_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  }
}
#if defined(USE_ROCM)
#define MOE_X_Q6_K 8
#define MOE_Y_Q6_K 128
#define NWARPS_Q6_K 8
#else
#define MOE_X_Q6_K 4
#define MOE_Y_Q6_K 32
#define NWARPS_Q6_K 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF * NWARPS_Q6_K, 2)
#endif
moe_q6_K(
    const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst,
    const int* sorted_token_ids, const int* expert_ids, const int* num_tokens_post_padded,
    const int exp_stride, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst, const int top_k) {
  const int mmq_x = MOE_X_Q6_K;
  const int mmq_y = MOE_Y_Q6_K;
  const int nwarps = NWARPS_Q6_K;
  moe_q<
      scalar_t, QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
      load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>(
      vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
      ncols_y, nrows_y, nrows_dst, top_k);
}

template <typename scalar_t>
static void ggml_moe_q6_K_q8_1_cuda(
    const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids, const int* expert_ids,
    const int* num_tokens_post_padded, const int exp_stride, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
    const int tokens_post_padded, cudaStream_t stream) {
  const int mmq_x = MOE_X_Q6_K;
  const int mmq_y = MOE_Y_Q6_K;
  const int nwarps = NWARPS_Q6_K;
  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (tokens_post_padded) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
  if (nrows_x % mmq_y == 0) {
    constexpr bool need_check = false;
    moe_q6_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  } else {
    constexpr bool need_check = true;
    moe_q6_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded, exp_stride, ncols_x, nrows_x,
        ncols_y, nrows_y, nrows_dst, top_k);
  }
}
sgl-kernel/csrc/quantization/gguf/moe_vec.cuh  0 → 100644
// copied from
// https://github.com/vllm-project/vllm/blob/4492e3a55428e161ca8db381edc28263e5da4c8d/csrc/quantization/gguf/moe_vec.cuh
// copied and adapted from
// https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void moe_vec_q(
    const void* __restrict__ vx, const void* __restrict__ vy, scalar_t* __restrict__ dst,
    const int* topk_ids, const int topk, const int ncols, const int nrows, const int token_stride) {
  const auto row = blockIdx.x * blockDim.y + threadIdx.y;
  const auto token = blockIdx.z / topk;
  const auto expert = (topk_ids)[blockIdx.z];
  if (row >= nrows) {
    return;
  }
  const int blocks_per_row = ncols / qk;
  const int blocks_per_warp = vdr * WARP_SIZE / qi;

  // partial sum for each thread
  float tmp = 0.0f;

  const block_q_t* x = ((const block_q_t*)vx) + expert * nrows * blocks_per_row;
  const block_q8_1* y = (const block_q8_1*)(((const int*)vy) + token * token_stride);

  for (auto i = threadIdx.x / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) {
    const int ibx = row * blocks_per_row + i;  // x block index
    const int iby = i * (qk / QK8_1);          // y block index that aligns with ibx
    const int iqs = vdr * (threadIdx.x % (qi / vdr));  // x block quant index when casting the quants to int
    tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
  }

  // sum up partial sums and write back result
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
    tmp += SGLANG_SHFL_XOR_SYNC(uint32_t(-1), tmp, mask);
  }

  if (threadIdx.x == 0) {
    dst[blockIdx.z * nrows + row] = tmp;
  }
}
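// Reader note: blockIdx.z enumerates the tokens * top_k flattened
// (token, expert-slot) pairs, so token = blockIdx.z / topk while the expert
// comes straight from topk_ids[blockIdx.z]. The shuffle loop above is a
// butterfly reduction: with WARP_SIZE = 32 the XOR masks run 16, 8, 4, 2, 1,
// after which lane 0 holds the warp's full partial-sum total and is the only
// lane that writes dst.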
template <typename scalar_t>
static void moe_vec_q4_0_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}

template <typename scalar_t>
static void moe_vec_q4_1_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}

template <typename scalar_t>
static void moe_vec_q5_0_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}

template <typename scalar_t>
static void moe_vec_q5_1_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}

template <typename scalar_t>
static void moe_vec_q8_0_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}

template <typename scalar_t>
static void moe_vec_q2_K_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}

template <typename scalar_t>
static void moe_vec_q3_K_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}

template <typename scalar_t>
static void moe_vec_q4_K_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}

template <typename scalar_t>
static void moe_vec_q5_K_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}

template <typename scalar_t>
static void moe_vec_q6_K_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}

template <typename scalar_t>
static void moe_vec_iq2_xxs_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}

template <typename scalar_t>
static void moe_vec_iq2_xs_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}

template <typename scalar_t>
static void moe_vec_iq2_s_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}

template <typename scalar_t>
static void moe_vec_iq3_xxs_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}

template <typename scalar_t>
static void moe_vec_iq1_s_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}

template <typename scalar_t>
static void moe_vec_iq1_m_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}

template <typename scalar_t>
static void moe_vec_iq4_nl_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}

template <typename scalar_t>
static void moe_vec_iq4_xs_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}

template <typename scalar_t>
static void moe_vec_iq3_s_q8_1_cuda(
    const void* vx, const void* vy, scalar_t* dst, const int* topk_ids, const int top_k, const int tokens,
    const int ncols, const int nrows, const int token_stride, cudaStream_t stream) {
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, tokens * top_k);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
sgl-kernel/csrc/quantization/gguf/vecdotq.cuh  0 → 100644
// copied from
// https://github.com/vllm-project/vllm/blob/4492e3a55428e161ca8db381edc28263e5da4c8d/csrc/quantization/gguf/vecdotq.cuh
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/vecdotq.cuh
// and https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu
static __device__ __forceinline__ int get_int_b2(const void* x, const int& i32) {
  const uint16_t* x16 = (const uint16_t*)x;  // assume at least 2 byte alignment
  int x32 = x16[2 * i32 + 0] << 0;
  x32 |= x16[2 * i32 + 1] << 16;
  return x32;
}

static __device__ __forceinline__ int get_int_b4(const void* x, const int& i32) {
  return ((const int*)x)[i32];  // assume at least 4 byte alignment
}

static __device__ __forceinline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
  const uint16_t* x16 = (const uint16_t*)(x8 + sizeof(int) * i32);  // assume at least 2 byte alignment
  int x32 = 0;
  x32 |= x16[0] << 0;
  x32 |= x16[1] << 16;
  return x32;
}

static __device__ __forceinline__ int get_int_from_uint8(const uint8_t* x8, const int& i32) {
  const uint16_t* x16 = (const uint16_t*)(x8 + sizeof(int) * i32);  // assume at least 2 byte alignment
  int x32 = 0;
  x32 |= x16[0] << 0;
  x32 |= x16[1] << 16;
  return x32;
}

static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t* x8, const int& i32) {
  return *((const int*)(x8 + sizeof(int) * i32));  // assume at least 4 byte alignment
}

static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t* x8, const int& i32) {
  return *((const int*)(x8 + sizeof(int) * i32));  // assume at least 4 byte alignment
}
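// Reader note: a worked example of the unaligned loaders above. On a
// little-endian GPU, the bytes {0x01, 0x02, 0x03, 0x04} read through
// get_int_from_uint8 give x16[0] = 0x0201 and x16[1] = 0x0403, so
// x32 = 0x0201 | (0x0403 << 16) = 0x04030201 -- the same value a plain 4-byte
// load would return, but legal at 2-byte-aligned addresses. The *_aligned
// variants replace the two 16-bit loads with one 32-bit load where 4-byte
// alignment is guaranteed, as it is for the q8 quant arrays.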
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
#define VDR_Q4_0_Q8_1_MMVQ 2
#define VDR_Q4_0_Q8_1_MMQ 4
template <int vdr>
static __device__ __forceinline__ float
vec_dot_q4_0_q8_1_impl(const int* v, const int* u, const float& d4, const half2& ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  int sumi = 0;
#pragma unroll
  for (int i = 0; i < vdr; ++i) {
    const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
    const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
    // SIMD dot product of quantized values
    sumi = __dp4a(vi0, u[2 * i + 0], sumi);
    sumi = __dp4a(vi1, u[2 * i + 1], sumi);
  }
  const float2 ds8f = __half22float2(ds8);
  // second part effectively subtracts 8 from each quant value
  return d4 * (sumi * ds8f.x - (8 * vdr / QI4_0) * ds8f.y);
#endif
}
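// Reader note on the subtraction term, assuming the usual block_q8_1 layout
// where ds8 = {d8, d8 * sum(u)} over the 32-value block: a q4_0 nibble q
// decodes to d4 * (q - 8), so the dot product expands as
//   sum d4*(q - 8) * d8*u = d4 * (sum(q*u) * ds8f.x - 8 * ds8f.y)
// for a whole block. One call only covers vdr of the block's QI4_0 ints,
// hence the (8 * vdr / QI4_0) scaling: the QI4_0 / vdr threads cooperating on
// a block contribute the -8 * sum(u) correction exactly once in total.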
#define VDR_Q4_1_Q8_1_MMVQ 2
#define VDR_Q4_1_Q8_1_MMQ 4
template <int vdr>
static __device__ __forceinline__ float
vec_dot_q4_1_q8_1_impl(const int* v, const int* u, const half2& dm4, const half2& ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  int sumi = 0;
#pragma unroll
  for (int i = 0; i < vdr; ++i) {
    const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
    const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
    // SIMD dot product of quantized values
    sumi = __dp4a(vi0, u[2 * i + 0], sumi);
    sumi = __dp4a(vi1, u[2 * i + 1], sumi);
  }
  const float2 tmp = __half22float2(__hmul2(dm4, ds8));
  const float d4d8 = tmp.x;
  const float m4s8 = tmp.y;
  // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
  return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
#endif
}
#define VDR_Q5_0_Q8_1_MMVQ 2
#define VDR_Q5_0_Q8_1_MMQ 4
template <int vdr>
static __device__ __forceinline__ float
vec_dot_q5_0_q8_1_impl(const int* vl, const int* vh, const int* u, const float& d5, const half2& ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  int sumi = 0;
#pragma unroll
  for (int i = 0; i < vdr; ++i) {
    int vi0 = (vl[i] >> 0) & 0x0F0F0F0F;  // lower 4 qs bits, still need qh as 5th bits
    vi0 |= (vh[i] << 4) & 0x00000010;     // 0 -> 4
    vi0 |= (vh[i] << 11) & 0x00001000;    // 1 -> 12
    vi0 |= (vh[i] << 18) & 0x00100000;    // 2 -> 20
    vi0 |= (vh[i] << 25) & 0x10000000;    // 3 -> 28
    sumi = __dp4a(vi0, u[2 * i + 0], sumi);  // SIMD dot product of quantized values

    int vi1 = (vl[i] >> 4) & 0x0F0F0F0F;  // upper 4 qs bits, still need qh as 5th bits
    vi1 |= (vh[i] >> 12) & 0x00000010;    // 16 -> 4
    vi1 |= (vh[i] >> 5) & 0x00001000;     // 17 -> 12
    vi1 |= (vh[i] << 2) & 0x00100000;     // 18 -> 20
    vi1 |= (vh[i] << 9) & 0x10000000;     // 19 -> 28
    sumi = __dp4a(vi1, u[2 * i + 1], sumi);  // SIMD dot product of quantized values
  }
  const float2 ds8f = __half22float2(ds8);
  // second part effectively subtracts 16 from each quant value
  return d5 * (sumi * ds8f.x - (16 * vdr / QI5_0) * ds8f.y);
#endif
}
#define VDR_Q5_1_Q8_1_MMVQ 2
#define VDR_Q5_1_Q8_1_MMQ 4
template <int vdr>
static __device__ __forceinline__ float
vec_dot_q5_1_q8_1_impl(const int* vl, const int* vh, const int* u, const half2& dm5, const half2& ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  int sumi = 0;
#pragma unroll
  for (int i = 0; i < vdr; ++i) {
    int vi0 = (vl[i] >> 0) & 0x0F0F0F0F;  // lower 4 qs bits, still need qh as 5th bits
    vi0 |= (vh[i] << 4) & 0x00000010;     // 0 -> 4
    vi0 |= (vh[i] << 11) & 0x00001000;    // 1 -> 12
    vi0 |= (vh[i] << 18) & 0x00100000;    // 2 -> 20
    vi0 |= (vh[i] << 25) & 0x10000000;    // 3 -> 28
    sumi = __dp4a(vi0, u[2 * i + 0], sumi);  // SIMD dot product of quantized values

    int vi1 = (vl[i] >> 4) & 0x0F0F0F0F;  // upper 4 qs bits, still need qh as 5th bits
    vi1 |= (vh[i] >> 12) & 0x00000010;    // 16 -> 4
    vi1 |= (vh[i] >> 5) & 0x00001000;     // 17 -> 12
    vi1 |= (vh[i] << 2) & 0x00100000;     // 18 -> 20
    vi1 |= (vh[i] << 9) & 0x10000000;     // 19 -> 28
    sumi = __dp4a(vi1, u[2 * i + 1], sumi);  // SIMD dot product of quantized values
  }
  const float2 tmp = __half22float2(__hmul2(dm5, ds8));
  const float d5d8 = tmp.x;
  const float m5s8 = tmp.y;
  // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
  return sumi * d5d8 + m5s8 / (QI5_1 / vdr);
#endif
}
#define VDR_Q8_0_Q8_1_MMVQ 2
#define VDR_Q8_0_Q8_1_MMQ 8
template <int vdr>
static __device__ __forceinline__ float
vec_dot_q8_0_q8_1_impl(const int* v, const int* u, const float& d8_0, const float& d8_1) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  int sumi = 0;
#pragma unroll
  for (int i = 0; i < vdr; ++i) {
    // SIMD dot product of quantized values
    sumi = __dp4a(v[i], u[i], sumi);
  }
  return d8_0 * d8_1 * sumi;
#endif
}

template <int vdr>
static __device__ __forceinline__ float
vec_dot_q8_1_q8_1_impl(const int* v, const int* u, const half2& dm8, const half2& ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  int sumi = 0;
#pragma unroll
  for (int i = 0; i < vdr; ++i) {
    // SIMD dot product of quantized values
    sumi = __dp4a(v[i], u[i], sumi);
  }
  const float2 tmp = __half22float2(__hmul2(dm8, ds8));
  const float d8d8 = tmp.x;
  const float m8s8 = tmp.y;
  // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
  return sumi * d8d8 + m8s8 / (QI8_1 / vdr);
#endif
}
#define VDR_Q2_K_Q8_1_MMVQ 1
#define VDR_Q2_K_Q8_1_MMQ 2
// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
    const int& v, const int* __restrict__ u, const uint8_t* __restrict__ scales, const half2& dm2,
    const float* __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  float sumf_d = 0.0f;
  float sumf_m = 0.0f;
#pragma unroll
  for (int i = 0; i < QR2_K; ++i) {
    const int sc = scales[2 * i];
    const int vi = (v >> (2 * i)) & 0x03030303;
    sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF));  // SIMD dot product
    // fill int with 4x m
    int m = sc >> 4;
    m |= m << 8;
    m |= m << 16;
    sumf_m += d8[i] * __dp4a(m, u[i], 0);  // multiply constant q2_K part with sum of q8_1 values
  }
  const float2 dm2f = __half22float2(dm2);
  return dm2f.x * sumf_d - dm2f.y * sumf_m;
#endif
}
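// Reader note: each q2_K scales[] byte packs a 4-bit sub-block scale in its
// low nibble (sc & 0xF) and a 4-bit minimum in its high nibble (sc >> 4). The
// two OR-shifts above broadcast that minimum into all four bytes of an int,
// e.g. m = 0x05 -> 0x0505 -> 0x05050505, so one __dp4a(m, u[i], 0) multiplies
// the minimum by the sum of four packed q8 values in a single instruction.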
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
    const int* __restrict__ v, const int* __restrict__ u, const uint8_t* __restrict__ scales,
    const half2& dm2, const float& d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  int sumi_d = 0;
  int sumi_m = 0;
#pragma unroll
  for (int i0 = 0; i0 < QI8_1; i0 += QI8_1 / 2) {
    int sumi_d_sc = 0;
    const int sc = scales[i0 / (QI8_1 / 2)];
    // fill int with 4x m
    int m = sc >> 4;
    m |= m << 8;
    m |= m << 16;
#pragma unroll
    for (int i = i0; i < i0 + QI8_1 / 2; ++i) {
      sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc);  // SIMD dot product
      sumi_m = __dp4a(m, u[i], sumi_m);           // multiply sum of q8_1 values with m
    }
    sumi_d += sumi_d_sc * (sc & 0xF);
  }
  const float2 dm2f = __half22float2(dm2);
  return d8 * (dm2f.x * sumi_d - dm2f.y * sumi_m);
#endif
}
#define VDR_Q3_K_Q8_1_MMVQ 1
#define VDR_Q3_K_Q8_1_MMQ 2
// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
    const int& vl, const int& vh, const int* __restrict__ u, const uint8_t* __restrict__ scales,
    const int& scale_offset, const float& d3, const float* __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  float sumf = 0.0f;
#pragma unroll
  for (int i = 0; i < QR3_K; ++i) {
    const int isc = scale_offset + 2 * i;

    const int isc_low = isc % (QK_K / 32);
    const int sc_shift_low = 4 * (isc / (QK_K / 32));
    const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF;

    const int isc_high = isc % (QK_K / 64);
    const int sc_shift_high = 2 * (isc / (QK_K / 64));
    const int sc_high = ((scales[(QK_K / 32) + isc_high] >> sc_shift_high) & 3) << 4;

    const int sc = (sc_low | sc_high) - 32;

    const int vil = (vl >> (2 * i)) & 0x03030303;
    const int vih = ((vh >> i) << 2) & 0x04040404;
    const int vi = __vsubss4(vil, vih);

    sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc);  // SIMD dot product
  }
  return d3 * sumf;
#endif
}

static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
    const int* __restrict__ v, const int* __restrict__ u, const int8_t* __restrict__ scales,
    const float& d3, const float& d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  int sumi = 0;
#pragma unroll
  for (int i0 = 0; i0 < QR3_K * VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1 / 2) {
    int sumi_sc = 0;
    for (int i = i0; i < i0 + QI8_1 / 2; ++i) {
      sumi_sc = __dp4a(v[i], u[i], sumi_sc);  // SIMD dot product
    }
    sumi += sumi_sc * scales[i0 / (QI8_1 / 2)];
  }
  return d3 * d8 * sumi;
#endif
}
#define VDR_Q4_K_Q8_1_MMVQ 2
#define VDR_Q4_K_Q8_1_MMQ 8
// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
    const int* __restrict__ v, const int* __restrict__ u, const uint8_t* __restrict__ sc,
    const uint8_t* __restrict__ m, const half2& dm4, const float* __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  float sumf_d = 0.0f;
  float sumf_m = 0.0f;

#pragma unroll
  for (int i = 0; i < QR4_K; ++i) {
    const int v0i = (v[0] >> (4 * i)) & 0x0F0F0F0F;
    const int v1i = (v[1] >> (4 * i)) & 0x0F0F0F0F;

    const int dot1 = __dp4a(v1i, u[2 * i + 1], __dp4a(v0i, u[2 * i + 0], 0));              // SIMD dot product
    const int dot2 = __dp4a(0x01010101, u[2 * i + 1], __dp4a(0x01010101, u[2 * i + 0], 0));  // sum of u

    sumf_d += d8[i] * (dot1 * sc[i]);
    sumf_m += d8[i] * (dot2 * m[i]);  // multiply constant part of q4_K with sum of q8_1 values
  }

  const float2 dm4f = __half22float2(dm4);
  return dm4f.x * sumf_d - dm4f.y * sumf_m;
#endif
}

static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
    const int* __restrict__ v, const int* __restrict__ u, const uint8_t* __restrict__ sc,
    const uint8_t* __restrict__ m, const half2& dm4, const half2* __restrict__ ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  float sumf_d = 0.0f;
  float sumf_m = 0.0f;

#pragma unroll
  for (int i = 0; i < QR4_K * VDR_Q4_K_Q8_1_MMQ / QI8_1; ++i) {
    int sumi_d = 0;

#pragma unroll
    for (int j = 0; j < QI8_1; ++j) {
      sumi_d = __dp4a((v[j] >> (4 * i)) & 0x0F0F0F0F, u[i * QI8_1 + j], sumi_d);  // SIMD dot product
    }

    const float2 ds8f = __half22float2(ds8[i]);
    sumf_d += ds8f.x * (sc[i] * sumi_d);
    sumf_m += ds8f.y * m[i];  // sum of q8_1 block * q4_K min val
  }

  const float2 dm4f = __half22float2(dm4);
  return dm4f.x * sumf_d - dm4f.y * sumf_m;
#endif
}
#define VDR_Q5_K_Q8_1_MMVQ 2
#define VDR_Q5_K_Q8_1_MMQ 8
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
    const int* __restrict__ vl, const int* __restrict__ vh, const int* __restrict__ u,
    const uint8_t* __restrict__ sc, const uint8_t* __restrict__ m, const half2& dm5,
    const float* __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  float sumf_d = 0.0f;
  float sumf_m = 0.0f;

#pragma unroll
  for (int i = 0; i < QR5_K; ++i) {
    const int vl0i = (vl[0] >> (4 * i)) & 0x0F0F0F0F;
    const int vl1i = (vl[1] >> (4 * i)) & 0x0F0F0F0F;

    const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
    const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;

    const int v0i = vl0i | vh0i;
    const int v1i = vl1i | vh1i;

    const int dot1 = __dp4a(v0i, u[2 * i + 0], __dp4a(v1i, u[2 * i + 1], 0));              // SIMD dot product
    const int dot2 = __dp4a(0x01010101, u[2 * i + 0], __dp4a(0x01010101, u[2 * i + 1], 0));  // sum of u

    sumf_d += d8[i] * (dot1 * sc[i]);
    sumf_m += d8[i] * (dot2 * m[i]);
  }

  const float2 dm5f = __half22float2(dm5);
  return dm5f.x * sumf_d - dm5f.y * sumf_m;
#endif
}

static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
    const int* __restrict__ v, const int* __restrict__ u, const uint8_t* __restrict__ sc,
    const uint8_t* __restrict__ m, const half2& dm4, const half2* __restrict__ ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  float sumf_d = 0.0f;
  float sumf_m = 0.0f;

#pragma unroll
  for (int i = 0; i < QR5_K * VDR_Q5_K_Q8_1_MMQ / QI8_1; ++i) {
    int sumi_d = 0;

#pragma unroll
    for (int j = 0; j < QI8_1; ++j) {
      sumi_d = __dp4a(v[i * QI8_1 + j], u[i * QI8_1 + j], sumi_d);  // SIMD dot product
    }

    const float2 ds8f = __half22float2(ds8[i]);
    sumf_d += ds8f.x * (sc[i] * sumi_d);
    sumf_m += ds8f.y * m[i];  // sum of q8_1 block * q4_K min val
  }

  const float2 dm4f = __half22float2(dm4);
  return dm4f.x * sumf_d - dm4f.y * sumf_m;
#endif
}
#define VDR_Q6_K_Q8_1_MMVQ 1
#define VDR_Q6_K_Q8_1_MMQ 8
// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
    const int& vl, const int& vh, const int* __restrict__ u, const int8_t* __restrict__ scales,
    const float& d, const float* __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  float sumf = 0.0f;

#pragma unroll
  for (int i = 0; i < QR6_K; ++i) {
    const int sc = scales[4 * i];

    const int vil = (vl >> (4 * i)) & 0x0F0F0F0F;
    const int vih = ((vh >> (4 * i)) << 4) & 0x30303030;
    const int vi = __vsubss4((vil | vih), 0x20202020);  // vi = (vil | vih) - 32

    sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc);  // SIMD dot product
  }

  return d * sumf;
#endif
}
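The `__vsubss4((vil | vih), 0x20202020)` step above reconstructs the signed 6-bit q6_K weight per byte: the two high bits are merged above the low nibble, then 32 is subtracted. The per-weight identity as a scalar sketch (assuming the standard q6_K split into ql/qh; the helper is illustrative):

#include <cstdint>

// q6_K weight = (low4 | (high2 << 4)) - 32, scaled by d * sc.
static float q6k_value_ref(uint8_t low4, uint8_t high2, float d, int8_t sc) {
  const int q = int(low4 | (high2 << 4)) - 32;
  return d * sc * q;
}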
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
    const int* __restrict__ v, const int* __restrict__ u, const int8_t* __restrict__ sc,
    const float& d6, const float* __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  float sumf_d = 0.0f;

#pragma unroll
  for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
    int2 sumi_d = {0, 0};  // 2 q6_K scales per q8_1 scale

#pragma unroll
    for (int i = i0; i < i0 + 2; ++i) {
      sumi_d.x = __dp4a(v[2 * i + 0], u[2 * i + 0], sumi_d.x);  // SIMD dot product
      sumi_d.x = __dp4a(v[2 * i + 1], u[2 * i + 1], sumi_d.x);  // SIMD dot product

      sumi_d.y = __dp4a(v[2 * i + 4], u[2 * i + 4], sumi_d.y);  // SIMD dot product
      sumi_d.y = __dp4a(v[2 * i + 5], u[2 * i + 5], sumi_d.y);  // SIMD dot product
    }

    sumf_d += d8[i0 / 4] * (sc[i0 / 2 + 0] * sumi_d.x + sc[i0 / 2 + 1] * sumi_d.y);
  }

  return d6 * sumf_d;
#endif
}
static __device__ __forceinline__ float
vec_dot_q4_0_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
  const block_q4_0* bq4_0 = (const block_q4_0*)vbq;

  int v[VDR_Q4_0_Q8_1_MMVQ];
  int u[2 * VDR_Q4_0_Q8_1_MMVQ];

#pragma unroll
  for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
    v[i] = get_int_from_uint8(bq4_0->qs, iqs + i);
    u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
    u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
  }

  return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, __half2float(bq4_0->d), bq8_1->ds);
}
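As a sanity reference for the fused integer path above: a q4_0 block packs 32 weights as nibbles with one fp16 scale, each weight dequantizing to `d * (nibble - 8)`. A scalar dot product of one block against 32 int8 activations with activation scale `d8` (a hypothetical reference helper under that layout assumption, not part of this file):

#include <cstdint>

// Low nibbles hold weights 0..15, high nibbles weights 16..31.
static float q4_0_dot_ref(const uint8_t qs[16], float d, const int8_t u[32], float d8) {
  int sumi = 0;
  for (int i = 0; i < 16; ++i) {
    sumi += ((qs[i] & 0x0F) - 8) * u[i];     // low nibble
    sumi += ((qs[i] >> 4) - 8) * u[i + 16];  // high nibble
  }
  return d * d8 * sumi;
}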
template <int mmq_y>
static __device__ __forceinline__ void
allocate_tiles_q4_0(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
  __shared__ int tile_x_qs[mmq_y * (WARP_SIZE_GGUF) + mmq_y];
  __shared__ float tile_x_d[mmq_y * (WARP_SIZE_GGUF / QI4_0) + mmq_y / QI4_0];

  *x_ql = tile_x_qs;
  *x_dm = (half2*)tile_x_d;
}
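The trailing `+ mmq_y` in each tile size is one element of padding per tile row: the loaders index with a row stride of `WARP_SIZE_GGUF + 1`, which staggers rows across shared-memory banks. A sketch of that indexing assumption (constants and names illustrative):

// With 32 banks, stride 33 keeps x_ql[row * 33 + k] of consecutive rows in
// different banks, avoiding the conflicts an unpadded stride of 32 would cause.
constexpr int kWarpSketch = 32;  // stand-in for WARP_SIZE_GGUF
constexpr int kStrideSketch = kWarpSketch + 1;
__device__ __forceinline__ int tile_idx_sketch(int row, int k) {
  return row * kStrideSketch + k;
}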
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q4_0(
    const void* __restrict__ vx, int* __restrict__ x_ql, half2* __restrict__ x_dm,
    int* __restrict__ x_qh, int* __restrict__ x_sc, const int& i_offset, const int& i_max,
    const int& k, const int& blocks_per_row) {
  const int kbx = k / QI4_0;
  const int kqsx = k % QI4_0;

  const block_q4_0* bx0 = (const block_q4_0*)vx;
  float* x_dmf = (float*)x_dm;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
    int i = i0 + i_offset;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q4_0* bxi = bx0 + i * blocks_per_row + kbx;

    x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
    // x_dmf[i * (WARP_SIZE_GGUF/QI4_0) + i / QI4_0 + kbx] = bxi->d;
  }

  const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI4_0;
  const int kbxd = k % blocks_per_tile_x_row;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
    int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q4_0* bxi = bx0 + i * blocks_per_row + kbxd;

    x_dmf[i * (WARP_SIZE_GGUF / QI4_0) + i / QI4_0 + kbxd] = __half2float(bxi->d);
  }
}

static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm, const int* __restrict__ x_qh,
    const int* __restrict__ x_sc, const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
  (void)x_qh;
  (void)x_sc;

  const int kyqs = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2));
  const float* x_dmf = (const float*)x_dm;

  int u[2 * VDR_Q4_0_Q8_1_MMQ];

#pragma unroll
  for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
    u[2 * l + 0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF];
    u[2 * l + 1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI4_0) % WARP_SIZE_GGUF];
  }

  return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>(
      &x_ql[i * (WARP_SIZE_GGUF + 1) + k], u,
      x_dmf[i * (WARP_SIZE_GGUF / QI4_0) + i / QI4_0 + k / QI4_0],
      y_ds[j * (WARP_SIZE_GGUF / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE_GGUF / QI8_1)]);
}
static __device__ __forceinline__ float
vec_dot_q4_1_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
  const block_q4_1* bq4_1 = (const block_q4_1*)vbq;

  int v[VDR_Q4_1_Q8_1_MMVQ];
  int u[2 * VDR_Q4_1_Q8_1_MMVQ];

#pragma unroll
  for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
    v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i);
    u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
    u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1);
  }

  return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
}

template <int mmq_y>
static __device__ __forceinline__ void
allocate_tiles_q4_1(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
  __shared__ int tile_x_qs[mmq_y * (WARP_SIZE_GGUF) + mmq_y];
  __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF / QI4_1) + mmq_y / QI4_1];

  *x_ql = tile_x_qs;
  *x_dm = tile_x_dm;
}

template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q4_1(
    const void* __restrict__ vx, int* __restrict__ x_ql, half2* __restrict__ x_dm,
    int* __restrict__ x_qh, int* __restrict__ x_sc, const int& i_offset, const int& i_max,
    const int& k, const int& blocks_per_row) {
  const int kbx = k / QI4_1;
  const int kqsx = k % QI4_1;

  const block_q4_1* bx0 = (const block_q4_1*)vx;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
    int i = i0 + i_offset;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q4_1* bxi = bx0 + i * blocks_per_row + kbx;

    x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
  }

  const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI4_1;
  const int kbxd = k % blocks_per_tile_x_row;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
    int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q4_1* bxi = bx0 + i * blocks_per_row + kbxd;

    x_dm[i * (WARP_SIZE_GGUF / QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
  }
}

static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm, const int* __restrict__ x_qh,
    const int* __restrict__ x_sc, const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
  const int kyqs = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2));

  int u[2 * VDR_Q4_1_Q8_1_MMQ];

#pragma unroll
  for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
    u[2 * l + 0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF];
    u[2 * l + 1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI4_1) % WARP_SIZE_GGUF];
  }

  return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>(
      &x_ql[i * (WARP_SIZE_GGUF + 1) + k], u,
      x_dm[i * (WARP_SIZE_GGUF / QI4_1) + i / QI4_1 + k / QI4_1],
      y_ds[j * (WARP_SIZE_GGUF / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE_GGUF / QI8_1)]);
}
static __device__ __forceinline__ float
vec_dot_q5_0_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
  const block_q5_0* bq5_0 = (const block_q5_0*)vbq;

  int vl[VDR_Q5_0_Q8_1_MMVQ];
  int vh[VDR_Q5_0_Q8_1_MMVQ];
  int u[2 * VDR_Q5_0_Q8_1_MMVQ];

#pragma unroll
  for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) {
    vl[i] = get_int_from_uint8(bq5_0->qs, iqs + i);
    vh[i] = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i));
    u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
    u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0);
  }

  return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, __half2float(bq5_0->d), bq8_1->ds);
}

template <int mmq_y>
static __device__ __forceinline__ void
allocate_tiles_q5_0(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
  __shared__ int tile_x_ql[mmq_y * (2 * WARP_SIZE_GGUF) + mmq_y];
  __shared__ float tile_x_d[mmq_y * (WARP_SIZE_GGUF / QI5_0) + mmq_y / QI5_0];

  *x_ql = tile_x_ql;
  *x_dm = (half2*)tile_x_d;
}

template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q5_0(
    const void* __restrict__ vx, int* __restrict__ x_ql, half2* __restrict__ x_dm,
    int* __restrict__ x_qh, int* __restrict__ x_sc, const int& i_offset, const int& i_max,
    const int& k, const int& blocks_per_row) {
  const int kbx = k / QI5_0;
  const int kqsx = k % QI5_0;

  const block_q5_0* bx0 = (const block_q5_0*)vx;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
    int i = i0 + i_offset;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q5_0* bxi = bx0 + i * blocks_per_row + kbx;

    const int ql = get_int_from_uint8(bxi->qs, kqsx);
    const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0));

    int qs0 = (ql >> 0) & 0x0F0F0F0F;
    qs0 |= (qh << 4) & 0x00000010;     // 0 -> 4
    qs0 |= (qh << 11) & 0x00001000;    // 1 -> 12
    qs0 |= (qh << 18) & 0x00100000;    // 2 -> 20
    qs0 |= (qh << 25) & 0x10000000;    // 3 -> 28
    qs0 = __vsubss4(qs0, 0x10101010);  // subtract 16

    x_ql[i * (2 * WARP_SIZE_GGUF + 1) + 2 * k + 0] = qs0;

    int qs1 = (ql >> 4) & 0x0F0F0F0F;
    qs1 |= (qh >> 12) & 0x00000010;    // 16 -> 4
    qs1 |= (qh >> 5) & 0x00001000;     // 17 -> 12
    qs1 |= (qh << 2) & 0x00100000;     // 18 -> 20
    qs1 |= (qh << 9) & 0x10000000;     // 19 -> 28
    qs1 = __vsubss4(qs1, 0x10101010);  // subtract 16

    x_ql[i * (2 * WARP_SIZE_GGUF + 1) + 2 * k + 1] = qs1;
  }

  const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI5_0;
  const int kbxd = k % blocks_per_tile_x_row;
  float* x_dmf = (float*)x_dm;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
    int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q5_0* bxi = bx0 + i * blocks_per_row + kbxd;

    x_dmf[i * (WARP_SIZE_GGUF / QI5_0) + i / QI5_0 + kbxd] = __half2float(bxi->d);
  }
}
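The shift/mask chain above scatters the four q5 high bits into bit positions 4, 12, 20 and 28, i.e. into bit 4 of each packed byte; the `// 0 -> 4`, `// 1 -> 12`, ... comments track source bit -> destination bit. A scalar equivalent of that scatter, under the same layout (illustrative helper, not part of this file):

#include <cstdint>

// Place high bit b (b = 0..3) of qh into bit 4 of byte b of the packed word.
static uint32_t scatter_qh_ref(uint32_t qh) {
  uint32_t out = 0;
  for (int b = 0; b < 4; ++b) {
    out |= ((qh >> b) & 1u) << (8 * b + 4);
  }
  return out;
}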
static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm, const int* __restrict__ x_qh,
    const int* __restrict__ x_sc, const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
  const int kyqs = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2));
  const int index_bx = i * (WARP_SIZE_GGUF / QI5_0) + i / QI5_0 + k / QI5_0;
  const float* x_dmf = (const float*)x_dm;
  const float* y_df = (const float*)y_ds;

  int u[2 * VDR_Q5_0_Q8_1_MMQ];

#pragma unroll
  for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
    u[2 * l + 0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF];
    u[2 * l + 1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI5_0) % WARP_SIZE_GGUF];
  }

  return vec_dot_q8_0_q8_1_impl<QR5_0 * VDR_Q5_0_Q8_1_MMQ>(
      &x_ql[i * (2 * WARP_SIZE_GGUF + 1) + 2 * k], u, x_dmf[index_bx],
      y_df[j * (WARP_SIZE_GGUF / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE_GGUF / QI8_1)]);
}

static __device__ __forceinline__ float
vec_dot_q5_1_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
  const block_q5_1* bq5_1 = (const block_q5_1*)vbq;

  int vl[VDR_Q5_1_Q8_1_MMVQ];
  int vh[VDR_Q5_1_Q8_1_MMVQ];
  int u[2 * VDR_Q5_1_Q8_1_MMVQ];

#pragma unroll
  for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
    vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i);
    vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i));
    u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
    u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1);
  }

  return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
}

template <int mmq_y>
static __device__ __forceinline__ void
allocate_tiles_q5_1(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
  __shared__ int tile_x_ql[mmq_y * (2 * WARP_SIZE_GGUF) + mmq_y];
  __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF / QI5_1) + mmq_y / QI5_1];

  *x_ql = tile_x_ql;
  *x_dm = tile_x_dm;
}

template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q5_1(
    const void* __restrict__ vx, int* __restrict__ x_ql, half2* __restrict__ x_dm,
    int* __restrict__ x_qh, int* __restrict__ x_sc, const int& i_offset, const int& i_max,
    const int& k, const int& blocks_per_row) {
  const int kbx = k / QI5_1;
  const int kqsx = k % QI5_1;

  const block_q5_1* bx0 = (const block_q5_1*)vx;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
    int i = i0 + i_offset;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q5_1* bxi = bx0 + i * blocks_per_row + kbx;

    const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
    const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));

    int qs0 = (ql >> 0) & 0x0F0F0F0F;
    qs0 |= (qh << 4) & 0x00000010;   // 0 -> 4
    qs0 |= (qh << 11) & 0x00001000;  // 1 -> 12
    qs0 |= (qh << 18) & 0x00100000;  // 2 -> 20
    qs0 |= (qh << 25) & 0x10000000;  // 3 -> 28

    x_ql[i * (2 * WARP_SIZE_GGUF + 1) + 2 * k + 0] = qs0;

    int qs1 = (ql >> 4) & 0x0F0F0F0F;
    qs1 |= (qh >> 12) & 0x00000010;  // 16 -> 4
    qs1 |= (qh >> 5) & 0x00001000;   // 17 -> 12
    qs1 |= (qh << 2) & 0x00100000;   // 18 -> 20
    qs1 |= (qh << 9) & 0x10000000;   // 19 -> 28

    x_ql[i * (2 * WARP_SIZE_GGUF + 1) + 2 * k + 1] = qs1;
  }

  const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI5_1;
  const int kbxd = k % blocks_per_tile_x_row;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
    int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q5_1* bxi = bx0 + i * blocks_per_row + kbxd;

    x_dm[i * (WARP_SIZE_GGUF / QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
  }
}

static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm, const int* __restrict__ x_qh,
    const int* __restrict__ x_sc, const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
  const int kyqs = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2));
  const int index_bx = i * (WARP_SIZE_GGUF / QI5_1) + i / QI5_1 + k / QI5_1;

  int u[2 * VDR_Q5_1_Q8_1_MMQ];

#pragma unroll
  for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
    u[2 * l + 0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF];
    u[2 * l + 1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI5_1) % WARP_SIZE_GGUF];
  }

  return vec_dot_q8_1_q8_1_impl<QR5_1 * VDR_Q5_1_Q8_1_MMQ>(
      &x_ql[i * (2 * WARP_SIZE_GGUF + 1) + 2 * k], u, x_dm[index_bx],
      y_ds[j * (WARP_SIZE_GGUF / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE_GGUF / QI8_1)]);
}
static __device__ __forceinline__ float
vec_dot_q8_0_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
  const block_q8_0* bq8_0 = (const block_q8_0*)vbq;

  int v[VDR_Q8_0_Q8_1_MMVQ];
  int u[VDR_Q8_0_Q8_1_MMVQ];

#pragma unroll
  for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
    v[i] = get_int_from_int8(bq8_0->qs, iqs + i);
    u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
  }

  return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, __half2float(bq8_0->d), __low2float(bq8_1->ds));
}

template <int mmq_y>
static __device__ __forceinline__ void
allocate_tiles_q8_0(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
  __shared__ int tile_x_qs[mmq_y * (WARP_SIZE_GGUF) + mmq_y];
  __shared__ float tile_x_d[mmq_y * (WARP_SIZE_GGUF / QI8_0) + mmq_y / QI8_0];

  *x_ql = tile_x_qs;
  *x_dm = (half2*)tile_x_d;
}

template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q8_0(
    const void* __restrict__ vx, int* __restrict__ x_ql, half2* __restrict__ x_dm,
    int* __restrict__ x_qh, int* __restrict__ x_sc, const int& i_offset, const int& i_max,
    const int& k, const int& blocks_per_row) {
  const int kbx = k / QI8_0;
  const int kqsx = k % QI8_0;
  float* x_dmf = (float*)x_dm;

  const block_q8_0* bx0 = (const block_q8_0*)vx;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
    int i = i0 + i_offset;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q8_0* bxi = bx0 + i * blocks_per_row + kbx;

    x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
  }

  const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI8_0;
  const int kbxd = k % blocks_per_tile_x_row;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
    int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q8_0* bxi = bx0 + i * blocks_per_row + kbxd;

    x_dmf[i * (WARP_SIZE_GGUF / QI8_0) + i / QI8_0 + kbxd] = __half2float(bxi->d);
  }
}

static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm, const int* __restrict__ x_qh,
    const int* __restrict__ x_sc, const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
  const float* x_dmf = (const float*)x_dm;
  const float* y_df = (const float*)y_ds;

  return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>(
      &x_ql[i * (WARP_SIZE_GGUF + 1) + k], &y_qs[j * WARP_SIZE_GGUF + k],
      x_dmf[i * (WARP_SIZE_GGUF / QI8_0) + i / QI8_0 + k / QI8_0],
      y_df[j * (WARP_SIZE_GGUF / QI8_1) + k / QI8_1]);
}
static __device__ __forceinline__ float
vec_dot_q2_K_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
  const block_q2_K* bq2_K = (const block_q2_K*)vbq;

  const int bq8_offset = QR2_K * (iqs / QI8_1);
  const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1 / 2);

  const uint8_t* scales = bq2_K->scales + scale_offset;

  const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs);
  int u[QR2_K];
  float d8[QR2_K];

#pragma unroll
  for (int i = 0; i < QR2_K; ++i) {
    u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
    d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
  }

  return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
}

template <int mmq_y>
static __device__ __forceinline__ void
allocate_tiles_q2_K(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
  __shared__ int tile_x_ql[mmq_y * (WARP_SIZE_GGUF) + mmq_y];
  __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF / QI2_K) + mmq_y / QI2_K];
  __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF / 4) + mmq_y / 4];

  *x_ql = tile_x_ql;
  *x_dm = tile_x_dm;
  *x_sc = tile_x_sc;
}

template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q2_K(
    const void* __restrict__ vx, int* __restrict__ x_ql, half2* __restrict__ x_dm,
    int* __restrict__ x_qh, int* __restrict__ x_sc, const int& i_offset, const int& i_max,
    const int& k, const int& blocks_per_row) {
  const int kbx = k / QI2_K;
  const int kqsx = k % QI2_K;

  const block_q2_K* bx0 = (const block_q2_K*)vx;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
    int i = i0 + i_offset;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q2_K* bxi = bx0 + i * blocks_per_row + kbx;

    x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
  }

  const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI2_K;
  const int kbxd = k % blocks_per_tile_x_row;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
    int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q2_K* bxi = bx0 + i * blocks_per_row + kbxd;

    x_dm[i * (WARP_SIZE_GGUF / QI2_K) + i / QI2_K + kbxd] = bxi->dm;
  }

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
    int i = i0 + i_offset * 4 + k / (WARP_SIZE_GGUF / 4);
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q2_K* bxi = bx0 + i * blocks_per_row + (k % (WARP_SIZE_GGUF / 4)) / (QI2_K / 4);

    x_sc[i * (WARP_SIZE_GGUF / 4) + i / 4 + k % (WARP_SIZE_GGUF / 4)] =
        get_int_from_uint8_aligned(bxi->scales, k % (QI2_K / 4));
  }
}

static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm, const int* __restrict__ x_qh,
    const int* __restrict__ x_sc, const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
  const int kbx = k / QI2_K;
  const int ky = (k % QI2_K) * QR2_K;
  const float* y_df = (const float*)y_ds;

  int v[QR2_K * VDR_Q2_K_Q8_1_MMQ];

  const int kqsx = i * (WARP_SIZE_GGUF + 1) + kbx * QI2_K + (QI2_K / 2) * (ky / (2 * QI2_K)) + ky % (QI2_K / 2);
  const int shift = 2 * ((ky % (2 * QI2_K)) / (QI2_K / 2));

#pragma unroll
  for (int l = 0; l < QR2_K * VDR_Q2_K_Q8_1_MMQ; ++l) {
    v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
  }

  const uint8_t* scales = ((const uint8_t*)&x_sc[i * (WARP_SIZE_GGUF / 4) + i / 4 + kbx * 4]) + ky / 4;

  const int index_y = j * WARP_SIZE_GGUF + (QR2_K * k) % WARP_SIZE_GGUF;
  return vec_dot_q2_K_q8_1_impl_mmq(
      v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE_GGUF / QI2_K) + i / QI2_K + kbx],
      y_df[index_y / QI8_1]);
}
static __device__ __forceinline__ float
vec_dot_q3_K_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
  const block_q3_K* bq3_K = (const block_q3_K*)vbq;

  const int bq8_offset = QR3_K * (iqs / (QI3_K / 2));
  const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1 / 2);

  const float d = __half2float(bq3_K->d);
  const int vl = get_int_from_uint8(bq3_K->qs, iqs);

  // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
  const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K / 2)) >> bq8_offset;

  int u[QR3_K];
  float d8[QR3_K];

#pragma unroll
  for (int i = 0; i < QR3_K; ++i) {
    u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
    d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
  }

  return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
}
template <int mmq_y>
static __device__ __forceinline__ void
allocate_tiles_q3_K(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
  __shared__ int tile_x_ql[mmq_y * (WARP_SIZE_GGUF) + mmq_y];
  __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF / QI3_K) + mmq_y / QI3_K];
  __shared__ int tile_x_qh[mmq_y * (WARP_SIZE_GGUF / 2) + mmq_y / 2];
  __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF / 4) + mmq_y / 4];

  *x_ql = tile_x_ql;
  *x_dm = tile_x_dm;
  *x_qh = tile_x_qh;
  *x_sc = tile_x_sc;
}

template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q3_K(
    const void* __restrict__ vx, int* __restrict__ x_ql, half2* __restrict__ x_dm,
    int* __restrict__ x_qh, int* __restrict__ x_sc, const int& i_offset, const int& i_max,
    const int& k, const int& blocks_per_row) {
  const int kbx = k / QI3_K;
  const int kqsx = k % QI3_K;

  const block_q3_K* bx0 = (const block_q3_K*)vx;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
    int i = i0 + i_offset;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q3_K* bxi = bx0 + i * blocks_per_row + kbx;

    x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
  }

  const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI3_K;
  const int kbxd = k % blocks_per_tile_x_row;
  float* x_dmf = (float*)x_dm;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
    int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q3_K* bxi = bx0 + i * blocks_per_row + kbxd;

    x_dmf[i * (WARP_SIZE_GGUF / QI3_K) + i / QI3_K + kbxd] = __half2float(bxi->d);
  }

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
    int i = i0 + i_offset * 2 + k / (WARP_SIZE_GGUF / 2);
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q3_K* bxi = bx0 + i * blocks_per_row + (k % (WARP_SIZE_GGUF / 2)) / (QI3_K / 2);

    // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
    x_qh[i * (WARP_SIZE_GGUF / 2) + i / 2 + k % (WARP_SIZE_GGUF / 2)] =
        ~get_int_from_uint8(bxi->hmask, k % (QI3_K / 2));
  }

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
    int i = i0 + i_offset * 4 + k / (WARP_SIZE_GGUF / 4);
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q3_K* bxi = bx0 + i * blocks_per_row + (k % (WARP_SIZE_GGUF / 4)) / (QI3_K / 4);

    const int ksc = k % (QI3_K / 4);

    const int ksc_low = ksc % (QI3_K / 8);
    const int shift_low = 4 * (ksc / (QI3_K / 8));
    const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;

    const int ksc_high = QI3_K / 8;
    const int shift_high = 2 * ksc;
    const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;

    const int sc = __vsubss4(sc_low | sc_high, 0x20202020);

    x_sc[i * (WARP_SIZE_GGUF / 4) + i / 4 + k % (WARP_SIZE_GGUF / 4)] = sc;
  }
}

static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm, const int* __restrict__ x_qh,
    const int* __restrict__ x_sc, const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
  const int kbx = k / QI3_K;
  const int ky = (k % QI3_K) * QR3_K;
  const float* x_dmf = (const float*)x_dm;
  const float* y_df = (const float*)y_ds;

  const int8_t* scales = ((const int8_t*)(x_sc + i * (WARP_SIZE_GGUF / 4) + i / 4 + kbx * 4)) + ky / 4;

  int v[QR3_K * VDR_Q3_K_Q8_1_MMQ];

#pragma unroll
  for (int l = 0; l < QR3_K * VDR_Q3_K_Q8_1_MMQ; ++l) {
    const int kqsx = i * (WARP_SIZE_GGUF + 1) + kbx * QI3_K + (QI3_K / 2) * (ky / (2 * QI3_K)) + ky % (QI3_K / 2);
    const int shift = 2 * ((ky % 32) / 8);
    const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;

    const int vh = x_qh[i * (WARP_SIZE_GGUF / 2) + i / 2 + kbx * (QI3_K / 2) + (ky + l) % 8] >> ((ky + l) / 8);
    const int vlh = (vh << 2) & 0x04040404;

    v[l] = __vsubss4(vll, vlh);
  }

  const int index_y = j * WARP_SIZE_GGUF + (k * QR3_K) % WARP_SIZE_GGUF;
  return vec_dot_q3_K_q8_1_impl_mmq(
      v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE_GGUF / QI3_K) + i / QI3_K + kbx],
      y_df[index_y / QI8_1]);
}
static __device__ __forceinline__ float
vec_dot_q4_K_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
  const block_q4_K* bq4_K = (const block_q4_K*)vbq;

  int v[2];
  int u[2 * QR4_K];
  float d8[QR4_K];

  // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
  const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));

  // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
  // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
  // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
  // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
  const int* q4 = (const int*)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
  v[0] = q4[0];
  v[1] = q4[4];

  const uint16_t* scales = (const uint16_t*)bq4_K->scales;
  uint16_t aux[2];
  const int j = bq8_offset / 2;
  if (j < 2) {
    aux[0] = scales[j + 0] & 0x3f3f;
    aux[1] = scales[j + 2] & 0x3f3f;
  } else {
    aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
    aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
  }
  const uint8_t* sc = (const uint8_t*)aux;
  const uint8_t* m = sc + 2;

  for (int i = 0; i < QR4_K; ++i) {
    const block_q8_1* bq8i = bq8_1 + bq8_offset + i;
    d8[i] = __low2float(bq8i->ds);

    const int* q8 = (const int*)bq8i->qs + ((iqs / 2) % 4);
    u[2 * i + 0] = q8[0];
    u[2 * i + 1] = q8[4];
  }

  return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
}
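The `aux[]` juggling above undoes the q4_K scale packing: 12 bytes hold eight 6-bit scales and eight 6-bit mins, with the upper 2 bits of the later entries spilled into the top bits of the earlier bytes. A byte-wise sketch of the usual decode, following the llama.cpp layout (the helper name is illustrative and this is an assumption about the format, not this kernel's code path):

#include <cstdint>

// Decode scale/min pair j (0..7) from the 12-byte q4_K scale block s.
static void get_scale_min_k4_ref(int j, const uint8_t* s, uint8_t* d, uint8_t* m) {
  if (j < 4) {
    *d = s[j] & 63;
    *m = s[j + 4] & 63;
  } else {
    *d = (s[j + 4] & 0xF) | ((s[j - 4] >> 6) << 4);
    *m = (s[j + 4] >> 4) | ((s[j] >> 6) << 4);
  }
}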
template <int mmq_y>
static __device__ __forceinline__ void
allocate_tiles_q4_K(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
  __shared__ int tile_x_ql[mmq_y * (WARP_SIZE_GGUF) + mmq_y];
  __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF / QI4_K) + mmq_y / QI4_K];
  __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF / 8) + mmq_y / 8];

  *x_ql = tile_x_ql;
  *x_dm = tile_x_dm;
  *x_sc = tile_x_sc;
}

template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q4_K(
    const void* __restrict__ vx, int* __restrict__ x_ql, half2* __restrict__ x_dm,
    int* __restrict__ x_qh, int* __restrict__ x_sc, const int& i_offset, const int& i_max,
    const int& k, const int& blocks_per_row) {
  const int kbx = k / QI4_K;   // == 0 if QK_K == 256
  const int kqsx = k % QI4_K;  // == k if QK_K == 256

  const block_q4_K* bx0 = (const block_q4_K*)vx;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
    int i = i0 + i_offset;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q4_K* bxi = bx0 + i * blocks_per_row + kbx;

    x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
  }

  const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI4_K;  // == 1 if QK_K == 256
  const int kbxd = k % blocks_per_tile_x_row;                // == 0 if QK_K == 256

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
    int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q4_K* bxi = bx0 + i * blocks_per_row + kbxd;

    x_dm[i * (WARP_SIZE_GGUF / QI4_K) + i / QI4_K + kbxd] = bxi->dm;
  }

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
    int i = (i0 + i_offset * 8 + k / (WARP_SIZE_GGUF / 8)) % mmq_y;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q4_K* bxi = bx0 + i * blocks_per_row + (k % (WARP_SIZE_GGUF / 8)) / (QI4_K / 8);

    const int* scales = (const int*)bxi->scales;
    const int ksc = k % (WARP_SIZE_GGUF / 8);

    // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
    int scales8 = (scales[(ksc % 2) + (ksc != 0)] >> (4 * (ksc & (ksc / 2)))) & 0x0F0F0F0F;  // lower 4 bits
    scales8 |= (scales[ksc / 2] >> (2 * (ksc % 2))) & 0x30303030;                            // upper 2 bits

    x_sc[i * (WARP_SIZE_GGUF / 8) + i / 8 + ksc] = scales8;
  }
}

static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm, const int* __restrict__ x_qh,
    const int* __restrict__ x_sc, const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
  (void)x_qh;

  const uint8_t* sc = ((const uint8_t*)&x_sc[i * (WARP_SIZE_GGUF / 8) + i / 8 + k / 16]) + 2 * ((k % 16) / 8);

  const int index_y = j * WARP_SIZE_GGUF + (QR4_K * k) % WARP_SIZE_GGUF;
  return vec_dot_q4_K_q8_1_impl_mmq(
      &x_ql[i * (WARP_SIZE_GGUF + 1) + k], &y_qs[index_y], sc, sc + 8,
      x_dm[i * (WARP_SIZE_GGUF / QI4_K) + i / QI4_K], &y_ds[index_y / QI8_1]);
}
static __device__ __forceinline__ float
vec_dot_q5_K_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
  const block_q5_K* bq5_K = (const block_q5_K*)vbq;

  int vl[2];
  int vh[2];
  int u[2 * QR5_K];
  float d8[QR5_K];

  const int bq8_offset = QR5_K * ((iqs / 2) / (QI8_1 / 2));
  const int* ql = (const int*)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
  const int* qh = (const int*)(bq5_K->qh + 4 * ((iqs / 2) % 4));

  vl[0] = ql[0];
  vl[1] = ql[4];

  vh[0] = qh[0] >> bq8_offset;
  vh[1] = qh[4] >> bq8_offset;

  const uint16_t* scales = (const uint16_t*)bq5_K->scales;
  uint16_t aux[2];
  const int j = bq8_offset / 2;
  if (j < 2) {
    aux[0] = scales[j + 0] & 0x3f3f;
    aux[1] = scales[j + 2] & 0x3f3f;
  } else {
    aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
    aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
  }
  const uint8_t* sc = (const uint8_t*)aux;
  const uint8_t* m = sc + 2;

#pragma unroll
  for (int i = 0; i < QR5_K; ++i) {
    const block_q8_1* bq8i = bq8_1 + bq8_offset + i;
    d8[i] = __low2float(bq8i->ds);

    const int* q8 = (const int*)bq8i->qs + ((iqs / 2) % 4);
    u[2 * i + 0] = q8[0];
    u[2 * i + 1] = q8[4];
  }

  return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
}
template <int mmq_y>
static __device__ __forceinline__ void
allocate_tiles_q5_K(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
  __shared__ int tile_x_ql[mmq_y * (2 * WARP_SIZE_GGUF) + mmq_y];
  __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF / QI5_K) + mmq_y / QI5_K];
  __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF / 8) + mmq_y / 8];

  *x_ql = tile_x_ql;
  *x_dm = tile_x_dm;
  *x_sc = tile_x_sc;
}

template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q5_K(
    const void* __restrict__ vx, int* __restrict__ x_ql, half2* __restrict__ x_dm,
    int* __restrict__ x_qh, int* __restrict__ x_sc, const int& i_offset, const int& i_max,
    const int& k, const int& blocks_per_row) {
  const int kbx = k / QI5_K;   // == 0 if QK_K == 256
  const int kqsx = k % QI5_K;  // == k if QK_K == 256

  const block_q5_K* bx0 = (const block_q5_K*)vx;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
    int i = i0 + i_offset;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q5_K* bxi = bx0 + i * blocks_per_row + kbx;
    const int ky = QR5_K * kqsx;

    const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
    const int ql0 = (ql >> 0) & 0x0F0F0F0F;
    const int ql1 = (ql >> 4) & 0x0F0F0F0F;

    const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K / 4));
    const int qh0 = ((qh >> (2 * (kqsx / (QI5_K / 4)) + 0)) << 4) & 0x10101010;
    const int qh1 = ((qh >> (2 * (kqsx / (QI5_K / 4)) + 1)) << 4) & 0x10101010;

    const int kq0 = ky - ky % (QI5_K / 2) + k % (QI5_K / 4) + 0;
    const int kq1 = ky - ky % (QI5_K / 2) + k % (QI5_K / 4) + (QI5_K / 4);

    x_ql[i * (2 * WARP_SIZE_GGUF + 1) + kq0] = ql0 | qh0;
    x_ql[i * (2 * WARP_SIZE_GGUF + 1) + kq1] = ql1 | qh1;
  }

  const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI5_K;  // == 1 if QK_K == 256
  const int kbxd = k % blocks_per_tile_x_row;                // == 0 if QK_K == 256

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
    int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q5_K* bxi = bx0 + i * blocks_per_row + kbxd;

    x_dm[i * (WARP_SIZE_GGUF / QI5_K) + i / QI5_K + kbxd] = bxi->dm;
  }

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
    int i = (i0 + i_offset * 8 + k / (WARP_SIZE_GGUF / 8)) % mmq_y;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q5_K* bxi = bx0 + i * blocks_per_row + (k % (WARP_SIZE_GGUF / 8)) / (QI5_K / 8);

    const int* scales = (const int*)bxi->scales;
    const int ksc = k % (WARP_SIZE_GGUF / 8);

    // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
    int scales8 = (scales[(ksc % 2) + (ksc != 0)] >> (4 * (ksc & (ksc / 2)))) & 0x0F0F0F0F;  // lower 4 bits
    scales8 |= (scales[ksc / 2] >> (2 * (ksc % 2))) & 0x30303030;                            // upper 2 bits

    x_sc[i * (WARP_SIZE_GGUF / 8) + i / 8 + ksc] = scales8;
  }
}

static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm, const int* __restrict__ x_qh,
    const int* __restrict__ x_sc, const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
  const uint8_t* sc = ((const uint8_t*)&x_sc[i * (WARP_SIZE_GGUF / 8) + i / 8 + k / 16]) + 2 * ((k % 16) / 8);

  const int index_x = i * (QR5_K * WARP_SIZE_GGUF + 1) + QR5_K * k;
  const int index_y = j * WARP_SIZE_GGUF + (QR5_K * k) % WARP_SIZE_GGUF;
  return vec_dot_q5_K_q8_1_impl_mmq(
      &x_ql[index_x], &y_qs[index_y], sc, sc + 8,
      x_dm[i * (WARP_SIZE_GGUF / QI5_K) + i / QI5_K], &y_ds[index_y / QI8_1]);
}
static __device__ __forceinline__ float
vec_dot_q6_K_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
  const block_q6_K* bq6_K = (const block_q6_K*)vbq;

  const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 4);
  const int scale_offset = (QI6_K / 4) * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 8);
  const int vh_shift = 2 * ((iqs % (QI6_K / 2)) / (QI6_K / 4));

  const int vl = get_int_from_uint8(bq6_K->ql, iqs);
  const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K / 4) * (iqs / (QI6_K / 2)) + iqs % (QI6_K / 4)) >> vh_shift;

  const int8_t* scales = bq6_K->scales + scale_offset;

  int u[QR6_K];
  float d8[QR6_K];

#pragma unroll
  for (int i = 0; i < QR6_K; ++i) {
    u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2 * i].qs, iqs % QI8_1);
    d8[i] = __low2float(bq8_1[bq8_offset + 2 * i].ds);
  }

  return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, __half2float(bq6_K->d), d8);
}
template <int mmq_y>
static __device__ __forceinline__ void
allocate_tiles_q6_K(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
  __shared__ int tile_x_ql[mmq_y * (2 * WARP_SIZE_GGUF) + mmq_y];
  __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF / QI6_K) + mmq_y / QI6_K];
  __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF / 8) + mmq_y / 8];

  *x_ql = tile_x_ql;
  *x_dm = tile_x_dm;
  *x_sc = tile_x_sc;
}

template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q6_K(
    const void* __restrict__ vx, int* __restrict__ x_ql, half2* __restrict__ x_dm,
    int* __restrict__ x_qh, int* __restrict__ x_sc, const int& i_offset, const int& i_max,
    const int& k, const int& blocks_per_row) {
  const int kbx = k / QI6_K;   // == 0 if QK_K == 256
  const int kqsx = k % QI6_K;  // == k if QK_K == 256

  const block_q6_K* bx0 = (const block_q6_K*)vx;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
    int i = i0 + i_offset;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q6_K* bxi = bx0 + i * blocks_per_row + kbx;
    const int ky = QR6_K * kqsx;

    const int ql = get_int_from_uint8(bxi->ql, kqsx);
    const int ql0 = (ql >> 0) & 0x0F0F0F0F;
    const int ql1 = (ql >> 4) & 0x0F0F0F0F;

    const int qh = get_int_from_uint8(bxi->qh, (QI6_K / 4) * (kqsx / (QI6_K / 2)) + kqsx % (QI6_K / 4));
    const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K / 2)) / (QI6_K / 4)))) << 4) & 0x30303030;
    const int qh1 = (qh >> (2 * ((kqsx % (QI6_K / 2)) / (QI6_K / 4)))) & 0x30303030;

    const int kq0 = ky - ky % QI6_K + k % (QI6_K / 2) + 0;
    const int kq1 = ky - ky % QI6_K + k % (QI6_K / 2) + (QI6_K / 2);

    x_ql[i * (2 * WARP_SIZE_GGUF + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
    x_ql[i * (2 * WARP_SIZE_GGUF + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
  }

  const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI6_K;  // == 1 if QK_K == 256
  const int kbxd = k % blocks_per_tile_x_row;                // == 0 if QK_K == 256
  float* x_dmf = (float*)x_dm;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
    int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q6_K* bxi = bx0 + i * blocks_per_row + kbxd;

    x_dmf[i * (WARP_SIZE_GGUF / QI6_K) + i / QI6_K + kbxd] = __half2float(bxi->d);
  }

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
    int i = (i0 + i_offset * 8 + k / (WARP_SIZE_GGUF / 8)) % mmq_y;
    if (need_check) {
      i = min(i, i_max);
    }
    const block_q6_K* bxi = bx0 + i * blocks_per_row + (k % (WARP_SIZE_GGUF / 8)) / 4;

    x_sc[i * (WARP_SIZE_GGUF / 8) + i / 8 + k % (WARP_SIZE_GGUF / 8)] =
        get_int_from_int8(bxi->scales, k % (QI6_K / 8));
  }
}
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm, const int* __restrict__ x_qh,
    const int* __restrict__ x_sc, const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
  const float* x_dmf = (const float*)x_dm;
  const float* y_df = (const float*)y_ds;

  const int8_t* sc = ((const int8_t*)&x_sc[i * (WARP_SIZE_GGUF / 8) + i / 8 + k / 8]);

  const int index_x = i * (QR6_K * WARP_SIZE_GGUF + 1) + QR6_K * k;
  const int index_y = j * WARP_SIZE_GGUF + (QR6_K * k) % WARP_SIZE_GGUF;
  return vec_dot_q6_K_q8_1_impl_mmq(
      &x_ql[index_x], &y_qs[index_y], sc,
      x_dmf[i * (WARP_SIZE_GGUF / QI6_K) + i / QI6_K], &y_df[index_y / QI8_1]);
}
static __device__ __forceinline__ float
vec_dot_iq2_xxs_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
  const block_iq2_xxs* bq2 = (const block_iq2_xxs*)vbq;

  const int ib32 = iqs;
  const uint16_t* q2 = bq2->qs + 4 * ib32;
  const uint8_t* aux8 = (const uint8_t*)q2;
  const int8_t* q8 = bq8_1[ib32].qs;
  uint32_t aux32 = q2[2] | (q2[3] << 16);
  int sumi = 0;
  for (int l = 0; l < 4; ++l) {
    const uint8_t* grid = (const uint8_t*)(iq2xxs_grid + aux8[l]);
    const uint8_t signs = ksigns_iq2xs[aux32 & 127];
    for (int j = 0; j < 8; ++j) {
      sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
    }
    q8 += 8;
    aux32 >>= 7;
  }
  const float d = __half2float(bq2->d) * (0.5f + aux32) * __half2float(bq8_1[ib32].ds.x) * 0.25f;
  return d * sumi;
}

static __device__ __forceinline__ float
vec_dot_iq2_xs_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
  const block_iq2_xs* bq2 = (const block_iq2_xs*)vbq;

  const int ib32 = iqs;
  const uint16_t* q2 = bq2->qs + 4 * ib32;
  const int8_t* q8 = bq8_1[ib32].qs;
  const uint8_t ls1 = bq2->scales[ib32] & 0xf;
  const uint8_t ls2 = bq2->scales[ib32] >> 4;
  int sumi1 = 0;
  for (int l = 0; l < 2; ++l) {
    const uint8_t* grid = (const uint8_t*)(iq2xs_grid + (q2[l] & 511));
    const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
    for (int j = 0; j < 8; ++j) {
      sumi1 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
    }
    q8 += 8;
  }
  int sumi2 = 0;
  for (int l = 2; l < 4; ++l) {
    const uint8_t* grid = (const uint8_t*)(iq2xs_grid + (q2[l] & 511));
    const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
    for (int j = 0; j < 8; ++j) {
      sumi2 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
    }
    q8 += 8;
  }
  const float d = __half2float(bq2->d) * __half2float(bq8_1[ib32].ds.x) * 0.25f;
  return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
}
static __device__ __forceinline__ float
vec_dot_iq2_s_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  const block_iq2_s* bq2 = (const block_iq2_s*)vbq;

  const int ib32 = iqs;
  const int8_t* q8 = bq8_1[ib32].qs;
  const uint8_t* signs = bq2->qs + QK_K / 8 + 4 * ib32;
  const uint8_t ls1 = bq2->scales[ib32] & 0xf;
  const uint8_t ls2 = bq2->scales[ib32] >> 4;
  int sumi1 = 0;
  for (int l = 0; l < 2; ++l) {
    const uint32_t* grid = (const uint32_t*)(iq2s_grid + (bq2->qs[4 * ib32 + l] | ((bq2->qh[ib32] << (8 - 2 * l)) & 0x300)));
    const uint32_t signs0 = __vcmpeq4(((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
    const uint32_t signs1 = __vcmpeq4(((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
    const int grid_l = __vsub4(grid[0] ^ signs0, signs0);
    const int grid_h = __vsub4(grid[1] ^ signs1, signs1);
    sumi1 = __dp4a(grid_l, *((const int*)q8 + 0), sumi1);
    sumi1 = __dp4a(grid_h, *((const int*)q8 + 1), sumi1);
    q8 += 8;
  }
  int sumi2 = 0;
  for (int l = 2; l < 4; ++l) {
    const uint32_t* grid = (const uint32_t*)(iq2s_grid + (bq2->qs[4 * ib32 + l] | ((bq2->qh[ib32] << (8 - 2 * l)) & 0x300)));
    const uint32_t signs0 = __vcmpeq4(((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
    const uint32_t signs1 = __vcmpeq4(((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
    const int grid_l = __vsub4(grid[0] ^ signs0, signs0);
    const int grid_h = __vsub4(grid[1] ^ signs1, signs1);
    sumi2 = __dp4a(grid_l, *((const int*)q8 + 0), sumi2);
    sumi2 = __dp4a(grid_h, *((const int*)q8 + 1), sumi2);
    q8 += 8;
  }
  const float d = __half2float(bq2->d) * __low2float(bq8_1[ib32].ds) * 0.25f;
  return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
#endif
}
static __device__ __forceinline__ float
vec_dot_iq3_xxs_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  const block_iq3_xxs* bq2 = (const block_iq3_xxs*)vbq;

  const int ib32 = iqs;
  const uint8_t* q3 = bq2->qs + 8 * ib32;
  const uint16_t* gas = (const uint16_t*)(bq2->qs + QK_K / 4) + 2 * ib32;
  const int8_t* q8 = bq8_1[ib32].qs;
  uint32_t aux32 = gas[0] | (gas[1] << 16);
  int sumi = 0;
  for (int l = 0; l < 4; ++l) {
    const uint32_t* grid1 = iq3xxs_grid + q3[2 * l + 0];
    const uint32_t* grid2 = iq3xxs_grid + q3[2 * l + 1];
    const uint32_t* signs = (const uint32_t*)(ksigns64 + (aux32 & 127));
    const int grid_l = __vsub4(grid1[0] ^ signs[0], signs[0]);
    const int grid_h = __vsub4(grid2[0] ^ signs[1], signs[1]);
    sumi = __dp4a(grid_l, *((int*)q8 + 0), sumi);
    sumi = __dp4a(grid_h, *((int*)q8 + 1), sumi);
    q8 += 8;
    aux32 >>= 7;
  }
  const float d = __half2float(bq2->d) * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.5f;
  return d * sumi;
#endif
}

static __device__ __forceinline__ float
vec_dot_iq3_s_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  const block_iq3_s* bq2 = (const block_iq3_s*)vbq;

  const int ib32 = iqs;
  const uint8_t* qs = bq2->qs + 8 * ib32;
  const int8_t* q8 = bq8_1[ib32].qs;
  int sumi = 0;
  for (int l = 0; l < 4; ++l) {
    const uint32_t* grid1 = iq3xs_grid + (qs[2 * l + 0] | ((bq2->qh[ib32] << (8 - 2 * l)) & 256));
    const uint32_t* grid2 = iq3xs_grid + (qs[2 * l + 1] | ((bq2->qh[ib32] << (7 - 2 * l)) & 256));
    uint32_t signs0 = __vcmpeq4(((bq2->signs[4 * ib32 + l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
    uint32_t signs1 = __vcmpeq4(((bq2->signs[4 * ib32 + l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
    const int grid_l = __vsub4(grid1[0] ^ signs0, signs0);
    const int grid_h = __vsub4(grid2[0] ^ signs1, signs1);
    sumi = __dp4a(grid_l, *((int*)q8 + 0), sumi);
    sumi = __dp4a(grid_h, *((int*)q8 + 1), sumi);
    q8 += 8;
  }
  const float d = __half2float(bq2->d) *
                  (0.5f + ((bq2->scales[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf)) *
                  __low2float(bq8_1[ib32].ds) * 0.5f;
  return d * sumi;
#endif
}
static __device__ __forceinline__ float
vec_dot_iq1_s_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  const block_iq1_s* bq1 = (const block_iq1_s*)vbq;

  const int qs_packed = get_int_b2(bq1->qs, iqs);
  const uint8_t* qs = (const uint8_t*)&qs_packed;

  const int qh = bq1->qh[iqs];

  int sumi = 0;
#pragma unroll
  for (int l0 = 0; l0 < 8; l0 += 2) {
    const int grid = iq1s_grid_gpu[qs[l0 / 2] | (((qh >> 3 * (l0 / 2)) & 0x07) << 8)];

    const int grid0 = (grid >> 0) & 0x0F0F0F0F;
    const int grid1 = (grid >> 4) & 0x0F0F0F0F;

    const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
    const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);

    sumi = __dp4a(grid0, u0, sumi);
    sumi = __dp4a(grid1, u1, sumi);
  }

  const float d1q = __half2float(bq1->d) * (((qh >> 11) & 0x0E) + 1);
  const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f * IQ1S_DELTA / 0x8000);
  const float2 ds = __half22float2(bq8_1[iqs].ds);
  return d1q * (ds.x * sumi + ds.y * delta);
#endif
}

static __device__ __forceinline__ float
vec_dot_iq1_m_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  const block_iq1_m* bq1 = (const block_iq1_m*)vbq;

  const int qs_packed = get_int_b4(bq1->qs, iqs);
  const uint8_t* qs = (const uint8_t*)&qs_packed;

  int sumi[2] = {0};
  float sumf[2] = {0.0f};
#pragma unroll
  for (int l0 = 0; l0 < 8; l0 += 2) {
    const int qhl = bq1->qh[2 * iqs + l0 / 4] >> (4 * ((l0 / 2) % 2));

    const int grid = iq1s_grid_gpu[qs[l0 / 2] | ((qhl & 0x07) << 8)];

    const int grid0 = (grid >> 0) & 0x0F0F0F0F;
    const int grid1 = (grid >> 4) & 0x0F0F0F0F;

    const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
    const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);

    sumi[l0 / 4] = __dp4a(grid0, u0, sumi[l0 / 4]);
    sumi[l0 / 4] = __dp4a(grid1, u1, sumi[l0 / 4]);

    const float delta = -1.0f + IQ1M_DELTA - (qhl & 0x08) * (2.0f * IQ1M_DELTA / 0x08);
    int sumy = 0;
    sumy = __dp4a(u0, 0x01010101, sumy);
    sumy = __dp4a(u1, 0x01010101, sumy);
    sumf[l0 / 4] += delta * sumy;
  }

  const uint16_t* sc = (const uint16_t*)bq1->scales;

  iq1m_scale_t scale;
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00F0) | ((sc[2] >> 4) & 0x0F00) | (sc[3] & 0xF000);
  const float d = __half2float(scale.f16) * __low2float(bq8_1[iqs].ds);

  const int tmp = sc[iqs / 2] >> (6 * (iqs % 2));
  const int sc0 = 2 * ((tmp >> 0) & 0x07) + 1;
  const int sc1 = 2 * ((tmp >> 3) & 0x07) + 1;
  return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
#endif
}
static __device__ __forceinline__ void
get_int_from_table_16(const uint32_t& q4, const uint8_t* values, int& val1, int& val2) {
  uint32_t aux32;
  const uint8_t* q8 = (const uint8_t*)&aux32;
  aux32 = q4 & 0x0f0f0f0f;
  uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
  uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
  val1 = v1 | (v2 << 16);
  aux32 = (q4 >> 4) & 0x0f0f0f0f;
  v1 = values[q8[0]] | (values[q8[1]] << 8);
  v2 = values[q8[2]] | (values[q8[3]] << 8);
  val2 = v1 | (v2 << 16);
}
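A minimal Python model of this helper (an editorial sketch, not part of the file): each of the eight 4-bit codes packed in q4 is looked up in a 16-entry table, and the results are repacked into two 32-bit integers, low nibbles into val1 and high nibbles into val2, ready to feed __dp4a.

# Hedged sketch: Python equivalent of get_int_from_table_16. `values` is any
# 16-entry codebook (e.g. the IQ4_NL table); q4 packs eight 4-bit codes.
def get_int_from_table_16_py(q4: int, values: list[int]) -> tuple[int, int]:
    lo = [(q4 >> (8 * i)) & 0x0F for i in range(4)]      # low nibble of each byte
    hi = [(q4 >> (8 * i + 4)) & 0x0F for i in range(4)]  # high nibble of each byte
    val1 = sum((values[lo[i]] & 0xFF) << (8 * i) for i in range(4))
    val2 = sum((values[hi[i]] & 0xFF) << (8 * i) for i in range(4))
    return val1, val2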
static __device__ __forceinline__ float
vec_dot_iq4_nl_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  const block_iq4_nl* bq = (const block_iq4_nl*)vbq;
  const uint16_t* q4 = (const uint16_t*)bq->qs + 2 * iqs;
  const int32_t* q8 = (const int32_t*)bq8_1->qs + iqs;
  const uint8_t* values = (const uint8_t*)kvalues_iq4nl;

  int v1, v2;
  int sumi1 = 0, sumi2 = 0;
  for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
    const uint32_t aux = q4[2 * l] | (q4[2 * l + 1] << 16);
    get_int_from_table_16(aux, values, v1, v2);
    sumi1 = __dp4a(v1, q8[l + 0], sumi1);
    sumi2 = __dp4a(v2, q8[l + 4], sumi2);
  }
  const float d = __half2float(bq->d) * __low2float(bq8_1->ds);
  return d * (sumi1 + sumi2);
#endif
}
static __device__ __forceinline__ float
vec_dot_iq4_xs_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
  const block_iq4_xs* bq4 = (const block_iq4_xs*)vbq;
  const uint8_t* values = (const uint8_t*)kvalues_iq4nl;

  // iqs is 0...7
  const int ib32 = iqs;
  const int32_t* q8 = (const int*)bq8_1[ib32].qs;
  const uint32_t* q4 = (const uint32_t*)bq4->qs + 4 * ib32;
  const int8_t ls = ((bq4->scales_l[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf) | (((bq4->scales_h >> 2 * ib32) & 3) << 4);
  const float d = __half2float(bq4->d) * (ls - 32) * __low2float(bq8_1[ib32].ds);

  int v1, v2;
  int sumi1 = 0, sumi2 = 0;
  for (int j = 0; j < 4; ++j) {
    get_int_from_table_16(q4[j], values, v1, v2);
    sumi1 = __dp4a(v1, q8[j + 0], sumi1);
    sumi2 = __dp4a(v2, q8[j + 4], sumi2);
  }
  return d * (sumi1 + sumi2);
#endif
}
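The scale decode in vec_dot_iq4_xs_q8_1 stitches a signed 6-bit scale per 32-element group out of two packed fields: the low 4 bits live in scales_l and the high 2 bits in scales_h. A hedged Python rendering of just that step (an editorial sketch, not part of the file):

# Hedged sketch: reconstruct the signed 6-bit group scale of an IQ4_XS block.
def iq4_xs_group_scale(scales_l: bytes, scales_h: int, ib32: int) -> int:
    lo = (scales_l[ib32 // 2] >> (4 * (ib32 % 2))) & 0xF  # 4 low bits, two groups per byte
    hi = (scales_h >> (2 * ib32)) & 0x3                   # 2 high bits from a packed field
    return ((hi << 4) | lo) - 32                          # recenter 0..63 to -32..31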
sgl-kernel/include/sgl_kernel_ops.h
View file @
8fdcd98e
...
...
@@ -186,6 +186,32 @@ void fast_topk_transform_interface(
void gelu_quick(at::Tensor& out, const at::Tensor& input);
#endif

/*
 * From gguf quantization
 */
torch::Tensor
ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, int64_t n, std::optional<at::ScalarType> const& dtype);

torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);

torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);

torch::Tensor ggml_moe_a8(
    torch::Tensor X,
    torch::Tensor W,
    torch::Tensor sorted_token_ids,
    torch::Tensor expert_ids,
    torch::Tensor num_tokens_post_padded,
    int64_t type,
    int64_t row,
    int64_t top_k,
    int64_t tokens);

torch::Tensor ggml_moe_a8_vec(
    torch::Tensor X, torch::Tensor W, torch::Tensor topk_ids, int64_t top_k, int64_t type, int64_t row, int64_t tokens);

int64_t ggml_moe_get_block_size(int64_t type);

/*
 * From csrc/gemm
 */
...
...
@@ -306,6 +332,8 @@ void topk_softmax(
void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling_factor);

void moe_sum(torch::Tensor& input, torch::Tensor& output);

std::vector<at::Tensor> moe_fused_gate(
    at::Tensor& input,
    at::Tensor& bias,
...
...
sgl-kernel/include/utils.h
View file @
8fdcd98e
...
...
@@ -19,6 +19,10 @@ limitations under the License.
#include <cuda_runtime.h>
#include <torch/all.h>
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif
#ifdef USE_ROCM
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
#define _DISPATCH_CASE_F16(c_type, ...) \
...
...
@@ -326,6 +330,13 @@ inline bool getEnvEnablePDL() {
#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define DISPATCH_CASE_FLOAT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define DISPATCH_FLOAT_TYPES(TYPE, NAME, ...) AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_FLOAT_TYPES(__VA_ARGS__))
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#ifndef USE_ROCM
...
...
@@ -447,3 +458,12 @@ inline uint32_t next_pow2(uint32_t x) noexcept {
  if (x <= 1) return 1;
  return 1u << (32 - __builtin_clz(x - 1));
}
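A hedged Python equivalent of next_pow2, useful as a quick sanity check of the rounding behavior (editorial sketch, not part of the header):

def next_pow2_py(x: int) -> int:
    # Round x up to the nearest power of two; (x - 1).bit_length() plays the
    # role of 32 - __builtin_clz(x - 1) above.
    if x <= 1:
        return 1
    return 1 << (x - 1).bit_length()

assert [next_pow2_py(n) for n in (0, 1, 2, 3, 5, 8)] == [1, 1, 2, 4, 8, 8]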
/*
* LDG Support
*/
#ifndef USE_ROCM
#define SGLANG_LDG(arg) __ldg(arg)
#else
#define SGLANG_LDG(arg) *(arg)
#endif
sgl-kernel/python/sgl_kernel/__init__.py
View file @
8fdcd98e
...
...
@@ -288,10 +288,19 @@ from sgl_kernel.moe import (
    fp8_blockwise_scaled_grouped_mm,
    moe_align_block_size,
    moe_fused_gate,
    moe_sum,
    moe_sum_reduce,
    prepare_moe_input,
    topk_softmax,
)
from sgl_kernel.quantization import (
    ggml_dequantize,
    ggml_moe_a8,
    ggml_moe_a8_vec,
    ggml_moe_get_block_size,
    ggml_mul_mat_a8,
    ggml_mul_mat_vec_a8,
)
from sgl_kernel.sampling import (
    min_p_sampling_from_probs,
    top_k_mask_logits,
...
...
sgl-kernel/python/sgl_kernel/moe.py
View file @
8fdcd98e
...
...
@@ -48,6 +48,16 @@ def moe_sum_reduce(
)


def moe_sum(
    input_tensor: torch.Tensor,
    output_tensor: torch.Tensor,
):
    torch.ops.sgl_kernel.moe_sum.default(
        input_tensor,
        output_tensor,
    )


def moe_fused_gate(
    input_tensor,
    bias,
...
...
sgl-kernel/python/sgl_kernel/quantization/__init__.py
0 → 100644
View file @
8fdcd98e
from .gguf import (
    ggml_dequantize,
    ggml_moe_a8,
    ggml_moe_a8_vec,
    ggml_moe_get_block_size,
    ggml_mul_mat_a8,
    ggml_mul_mat_vec_a8,
)
sgl-kernel/python/sgl_kernel/quantization/gguf.py
0 → 100644
View file @
8fdcd98e
import torch


def ggml_dequantize(weight: torch.Tensor, quant_type: int, M: int, N: int, dtype: torch.dtype):
    assert M > 0 and N > 0, "GGUF weight input shape must have positive dimensions"
    return torch.ops.sgl_kernel.ggml_dequantize.default(weight, quant_type, M, N, dtype)


def ggml_mul_mat_vec_a8(weight: torch.Tensor, x: torch.Tensor, quant_type: int, row: int) -> torch.Tensor:
    return torch.ops.sgl_kernel.ggml_mul_mat_vec_a8.default(weight, x, quant_type, row)


def ggml_mul_mat_a8(weight: torch.Tensor, x: torch.Tensor, quant_type: int, row: int) -> torch.Tensor:
    return torch.ops.sgl_kernel.ggml_mul_mat_a8.default(weight, x, quant_type, row)


def ggml_moe_a8(
    input: torch.Tensor,
    weight: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_token_post_padded: torch.Tensor,
    type: int,
    row: int,
    topk: int,
    tokens: int,
) -> torch.Tensor:
    return torch.ops.sgl_kernel.ggml_moe_a8.default(
        input,
        weight,
        sorted_token_ids,
        expert_ids,
        num_token_post_padded,
        type,
        row,
        topk,
        tokens,
    )


def ggml_moe_a8_vec(
    input: torch.Tensor,
    weight: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    type: int,
    row: int,
    tokens: int,
) -> torch.Tensor:
    return torch.ops.sgl_kernel.ggml_moe_a8_vec.default(input, weight, topk_ids, top_k, type, row, tokens)


def ggml_moe_get_block_size(type: int) -> int:
    return torch.ops.sgl_kernel.ggml_moe_get_block_size.default(type)
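For orientation, a hedged usage sketch of these wrappers (not part of the commit; GGML_TYPE_Q8_0 == 8 is an assumption based on ggml's type enum, and the shapes and tolerances follow the tests below):

import torch
from sgl_kernel import ggml_dequantize, ggml_mul_mat_vec_a8

GGML_TYPE_Q8_0 = 8          # assumed ggml type id for Q8_0
out_features, in_features = 1024, 256
qweight = ...               # raw GGUF Q8_0 block data on CUDA, e.g. from a GGUFReader tensor

# Dequantize to bf16 and check the fused quantized matvec against a plain matmul.
w = ggml_dequantize(qweight, GGML_TYPE_Q8_0, out_features, in_features, torch.bfloat16)
x = torch.rand(1, in_features, dtype=torch.bfloat16, device="cuda")
y = ggml_mul_mat_vec_a8(qweight, x, GGML_TYPE_Q8_0, out_features).to(torch.bfloat16)
torch.testing.assert_close(y, x @ w.T, atol=1, rtol=1e-1)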
sgl-kernel/tests/test_gguf.py
0 → 100644
View file @
8fdcd98e
# SPDX-License-Identifier: Apache-2.0
import random
from pathlib import Path

import numpy as np
import pytest
import torch
from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
from huggingface_hub import snapshot_download

from sgl_kernel import (
    ggml_dequantize,
    ggml_moe_a8,
    ggml_moe_a8_vec,
    ggml_moe_get_block_size,
    ggml_mul_mat_a8,
    ggml_mul_mat_vec_a8,
)

GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample")


def get_gguf_sample_tensors(hidden_size: int, quant_type: GGMLQuantizationType) -> list[ReaderTensor]:
    sample_dir = GGUF_SAMPLE
    filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
    sample_file = Path(sample_dir) / filename
    return GGUFReader(sample_file).tensors


def get_gguf_MoE_tensors(hidden_size: int, quant_type: GGMLQuantizationType) -> list[ReaderTensor]:
    sample_dir = GGUF_SAMPLE_MOE
    filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
    sample_file = Path(sample_dir) / filename
    return GGUFReader(sample_file).tensors


DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
# Hidden sizes for testing; these must match the sample files in the HF repo,
# which currently provides hidden_size = 256 and 1024.
HIDDEN_SIZES = [256, 1024]
NUM_TOKENS = [7, 2050]  # Arbitrary values for testing
SEEDS = [0]
QUANT_TYPES = [
    # i-matrix
    GGMLQuantizationType.IQ1_M,
    GGMLQuantizationType.IQ1_S,
    GGMLQuantizationType.IQ2_S,
    GGMLQuantizationType.IQ2_XS,
    GGMLQuantizationType.IQ3_S,
    GGMLQuantizationType.IQ3_XXS,
    GGMLQuantizationType.IQ4_NL,
    GGMLQuantizationType.IQ4_XS,
    # k-quants
    GGMLQuantizationType.Q2_K,
    GGMLQuantizationType.Q3_K,
    GGMLQuantizationType.Q4_K,
    GGMLQuantizationType.Q5_K,
    GGMLQuantizationType.Q6_K,
    # standard quantization
    GGMLQuantizationType.Q4_0,
    GGMLQuantizationType.Q5_0,
    GGMLQuantizationType.Q8_0,
]


@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_dequantize(hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType):
    tensors = get_gguf_sample_tensors(hidden_size, quant_type)
    for tensor in tensors:
        shape_str = tensor.name.split("_")[-1]
        shape = map(int, shape_str.split("x"))
        ref_output = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to(dtype)
        output = ggml_dequantize(torch.tensor(tensor.data, device="cuda"), quant_type, *list(shape), dtype)
        torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2)


@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_mmvq(hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType):
    tensors = get_gguf_sample_tensors(hidden_size, quant_type)
    x = torch.rand((1, hidden_size), dtype=dtype, device="cuda")
    for tensor in tensors:
        weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to(dtype)
        ref_output = x @ weight.T
        qweight = torch.tensor(tensor.data, device="cuda")
        output = ggml_mul_mat_vec_a8(qweight, x, quant_type, qweight.shape[0]).to(dtype)
        torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
    "quant_type",
    [
        # k-quants
        GGMLQuantizationType.Q2_K,
        GGMLQuantizationType.Q3_K,
        GGMLQuantizationType.Q4_K,
        GGMLQuantizationType.Q5_K,
        GGMLQuantizationType.Q6_K,
        # standard quants
        GGMLQuantizationType.Q4_0,
        GGMLQuantizationType.Q5_0,
        GGMLQuantizationType.Q8_0,
    ],
)
@torch.inference_mode()
def test_mmq(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    quant_type: GGMLQuantizationType,
):
    tensors = get_gguf_sample_tensors(hidden_size, quant_type)
    x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda")
    for tensor in tensors:
        weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to(dtype)
        ref_output = x @ weight.T
        qweight = torch.tensor(tensor.data, device="cuda")
        output = ggml_mul_mat_a8(qweight, x, quant_type, qweight.shape[0])
        atols = {torch.half: 1, torch.bfloat16: 1.5, torch.float: 1.2}
        # The test matrix has inputs centered around 0, and the lower precision of
        # bfloat16 tends to accumulate error; since the outputs are also very close
        # to 0, this can greatly inflate rtol.
        rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1}
        torch.testing.assert_close(output, ref_output, atol=atols[dtype], rtol=rtols[dtype])


if __name__ == "__main__":
    pytest.main([__file__])
sgl-kernel/tests/test_moe_align.py
View file @
8fdcd98e
...
...
@@ -4,7 +4,14 @@ import pytest
import torch
import triton
import triton.language as tl

from sgl_kernel import moe_align_block_size, moe_sum


def is_hip() -> bool:
    return torch.version.hip is not None


_is_hip = is_hip()


def ceil_div(a, b):
...
...
@@ -246,5 +253,20 @@ def test_moe_align_block_size_compare_implementations(
)


@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.skipif(_is_hip, reason="Skip for AMD GPU")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
    input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
    actual = torch.empty((m, k), device="cuda", dtype=dtype)
    expected = input.sum(dim=1)
    moe_sum(input, actual)
    torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)


if __name__ == "__main__":
    pytest.main([__file__])