Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ad385667
Commit
ad385667
authored
Oct 23, 2024
by
zhuwenwen
Browse files
Merge branch 'v0.6.3.post1-dev'
parents
be0967c1
903593d3
Changes
364
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
6147 additions
and
238 deletions
+6147
-238
csrc/ops.h
csrc/ops.h
+94
-59
csrc/opt/layernorm_kernels_opt.cu
csrc/opt/layernorm_kernels_opt.cu
+20
-14
csrc/permute_cols.cu
csrc/permute_cols.cu
+88
-0
csrc/prepare_inputs/advance_step.cu
csrc/prepare_inputs/advance_step.cu
+213
-27
csrc/quantization/aqlm/gemm_kernels.cu
csrc/quantization/aqlm/gemm_kernels.cu
+13
-12
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+193
-22
csrc/quantization/cutlass_w8a8/Epilogues.md
csrc/quantization/cutlass_w8a8/Epilogues.md
+147
-0
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp
...quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp
+151
-1
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
+57
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
+217
-36
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+230
-28
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+142
-19
csrc/quantization/fp8/common.cu
csrc/quantization/fp8/common.cu
+46
-20
csrc/quantization/fp8/fp8_marlin.cu
csrc/quantization/fp8/fp8_marlin.cu
+6
-0
csrc/quantization/gguf/dequantize.cuh
csrc/quantization/gguf/dequantize.cuh
+568
-0
csrc/quantization/gguf/ggml-common.h
csrc/quantization/gguf/ggml-common.h
+1115
-0
csrc/quantization/gguf/gguf_kernel.cu
csrc/quantization/gguf/gguf_kernel.cu
+247
-0
csrc/quantization/gguf/mmq.cuh
csrc/quantization/gguf/mmq.cuh
+600
-0
csrc/quantization/gguf/mmvq.cuh
csrc/quantization/gguf/mmvq.cuh
+190
-0
csrc/quantization/gguf/vecdotq.cuh
csrc/quantization/gguf/vecdotq.cuh
+1810
-0
No files found.
Too many changes to show.
To preserve performance only
364 of 364+
files are displayed.
Plain diff
Email patch
csrc/ops.h
View file @
ad385667
...
...
@@ -47,6 +47,27 @@ void paged_attention_v2_opt(
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v1_opt_tc
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v2_opt_tc
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
double
epsilon
);
...
...
@@ -96,21 +117,32 @@ void gelu_quick(torch::Tensor& out, torch::Tensor& input);
void
trans_w16_gemm
(
torch
::
Tensor
dst
,
torch
::
Tensor
src
,
int64_t
row
,
int64_t
col
);
void
advance_step
(
int64_t
num_seqs
,
int64_t
num_queries
,
int64_t
block_size
,
torch
::
Tensor
&
input_tokens
,
torch
::
Tensor
&
sampled_token_ids
,
torch
::
Tensor
&
input_positions
,
torch
::
Tensor
&
seq_lens
,
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
block_tables
);
void
advance_step_flashattn
(
int64_t
num_seqs
,
int64_t
num_queries
,
int64_t
block_size
,
torch
::
Tensor
&
input_tokens
,
torch
::
Tensor
&
sampled_token_ids
,
torch
::
Tensor
&
input_positions
,
torch
::
Tensor
&
seq_lens
,
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
block_tables
);
void
advance_step_flashinfer
(
int64_t
num_seqs
,
int64_t
num_queries
,
int64_t
block_size
,
torch
::
Tensor
&
input_tokens
,
torch
::
Tensor
&
sampled_token_ids
,
torch
::
Tensor
&
input_positions
,
torch
::
Tensor
&
seq_lens
,
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
paged_kv_indices
,
torch
::
Tensor
&
paged_kv_indptr
,
torch
::
Tensor
&
paged_kv_last_page_len
,
torch
::
Tensor
&
block_table_bounds
);
#ifndef USE_ROCM
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
scales
,
const
torch
::
Tensor
&
codebook_partition_sizes
,
const
std
::
vector
<
int64_t
>
&
codebook_partition_sizes
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
);
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
codebook_partition_sizes
);
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
std
::
vector
<
int64_t
>
&
codebook_partition_sizes
);
torch
::
Tensor
awq_gemm
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
...
...
@@ -121,38 +153,16 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch
::
Tensor
_zeros
,
int64_t
split_k_iters
,
int64_t
thx
,
int64_t
thy
);
torch
::
Tensor
marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
torch
::
Tensor
gptq_marlin_24_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_meta
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
,
bool
use_fp32_reduce
);
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
torch
::
Tensor
awq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
torch
::
Tensor
fp8_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
torch
::
Tensor
permute_cols
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
perm
);
torch
::
Tensor
ggml_dequantize
(
torch
::
Tensor
W
,
int64_t
type
,
int64_t
m
,
int64_t
n
);
torch
::
Tensor
ggml_mul_mat_vec_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
int64_t
type
,
int64_t
row
);
torch
::
Tensor
ggml_mul_mat_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
int64_t
type
,
int64_t
row
);
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
...
...
@@ -161,30 +171,29 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
torch
::
Tensor
marlin_qqq_gemm
(
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
_q_weight
,
torch
::
Tensor
const
&
s_tok
,
torch
::
Tensor
const
&
s_ch
,
torch
::
Tensor
const
&
s_group
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
scale
);
torch
::
Tensor
const
&
scale
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
);
void
dynamic_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
scales
);
torch
::
Tensor
&
scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
);
void
squeezellm_gemm
(
torch
::
Tensor
vec
,
torch
::
Tensor
mat
,
torch
::
Tensor
mul
,
torch
::
Tensor
lookup_table
);
// torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
// torch::Tensor b_gptq_qzeros,
// torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
// bool use_exllama, int64_t bit);
torch
::
Tensor
gptq_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_gptq_qzeros
,
torch
::
Tensor
b_gptq_scales
,
torch
::
Tensor
b_g_idx
,
bool
use_exllama
,
int64_t
bit
);
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
);
// void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor const& scale);
...
...
@@ -201,14 +210,40 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
);
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
const
torch
::
Tensor
&
C
,
const
c10
::
optional
<
torch
::
Tensor
>&
D_
,
const
c10
::
optional
<
torch
::
Tensor
>&
z_
,
const
c10
::
optional
<
torch
::
Tensor
>&
delta_bias_
,
bool
delta_softplus
,
const
c10
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
torch
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
torch
::
Tensor
>&
has_initial_state
,
const
torch
::
Tensor
&
ssm_states
,
int64_t
pad_slot_id
);
void
causal_conv1d_update
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
conv_state
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
bool
silu_activation
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_seqlens_
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_state_indices_
,
int64_t
pad_slot_id
);
void
causal_conv1d_fwd
(
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
weight
,
const
c10
::
optional
<
at
::
Tensor
>&
bias_
,
const
c10
::
optional
<
at
::
Tensor
>&
conv_states
,
const
c10
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
c10
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
c10
::
optional
<
at
::
Tensor
>&
has_initial_state
,
bool
silu_activation
,
int64_t
pad_slot_id
);
#ifndef USE_ROCM
using
fptr_t
=
int64_t
;
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
int64_t
rank
,
bool
full_nvlink
);
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int64_t
max_size
,
int64_t
world_size
,
bool
full_nvlink
);
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
);
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
);
...
...
csrc/opt/layernorm_kernels_opt.cu
View file @
ad385667
...
...
@@ -6,13 +6,17 @@
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#include "../dispatch_utils.h"
#include "../reduction_utils.cuh"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat162
=
__hip_bfloat162
;
...
...
@@ -34,7 +38,11 @@ __global__ void rms_norm_kernel(
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
variance
+=
x
*
x
;
}
variance
=
blockReduceSum
<
float
>
(
variance
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
...
...
@@ -231,12 +239,11 @@ fused_add_rms_norm_kernel(
variance
+=
temp
.
sum_squares
();
residual_v
[
id
]
=
temp
;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if
(
num_tokens
<
256
)
{
variance
=
blockReduceSum
<
float
,
1024
>
(
variance
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
...
...
@@ -271,12 +278,11 @@ fused_add_rms_norm_kernel(
variance
+=
x
*
x
;
residual
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
z
;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if
(
num_tokens
<
256
)
{
variance
=
blockReduceSum
<
float
,
1024
>
(
variance
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
...
...
csrc/permute_cols.cu
0 → 100644
View file @
ad385667
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
static
constexpr
int
default_threads
=
256
;
// Integer ceiling division: returns a / b rounded up.
// Intended for non-negative `a` and positive `b`.
static constexpr int div_ceil(int a, int b) {
  const int biased = a + b - 1;
  return biased / b;
}
// Column permutation of a row-major [size_m, size_k] matrix:
//   out[row][k] = a[row][perm[k]]   for every row and every k in [0, size_k).
//
// Input and output are passed as int4 pointers and reinterpreted as `half`
// per row, so only 16-bit element types are supported.
//
// Work split: block `blockIdx.x` owns the rows
//   [block_rows * blockIdx.x, min(size_m, block_rows * (blockIdx.x + 1)))
// and, inside a row, thread t handles columns t, t + default_threads, ...
// (the kernel is expected to be launched with default_threads threads/block).
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr,
                                    int size_m, int size_k, int block_rows) {
  int const first_row = block_rows * blockIdx.x;

  // Clamp the end of this block's row range to the matrix height. When
  // first_row is already past size_m the range is empty and nothing is done.
  int last_row = first_row + block_rows;
  if (last_row > size_m) {
    last_row = size_m;
  }

  // Row pitch expressed in int4 elements (16 bytes each).
  int const row_stride = size_k * sizeof(half) / 16;

  for (int row = first_row; row < last_row; ++row) {
    int const offset = row * row_stride;
    half const* src_row = reinterpret_cast<half const*>(a_int4_ptr + offset);
    half* dst_row = reinterpret_cast<half*>(out_int4_ptr + offset);

    // Strided sweep over the columns; covers both the full passes and the
    // final partial pass (size_k % default_threads) in one loop.
    for (int k = threadIdx.x; k < size_k; k += default_threads) {
      dst_row[k] = src_row[perm_int_ptr[k]];
    }
  }
}
// More efficient version of A[..., perm]
// taken from gptq_marlin.cu
//
// Returns a new tensor D, shaped like A, with D[..., k] = A[..., perm[k]].
//
// Checked preconditions:
//   - A is a contiguous fp16 / bf16 tensor whose last dimension is a multiple
//     of 8 (rows are addressed in 128-bit units by the kernel).
//   - perm is a contiguous int32 tensor on the same device as A with exactly
//     one entry per column of A.
// Not checked: the values stored in perm. They must all lie in
// [0, A.size(-1)), otherwise the kernel reads outside the source row.
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
  auto dev = A.get_device();
  auto stream = at::cuda::getCurrentCUDAStream(dev);

  TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16,
              "Currently only 16bit types are supported");
  TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
  TORCH_CHECK(A.size(-1) % 8 == 0,
              "A columns must be a multiple of 8 (128bits)");

  // The kernel dereferences perm[k] for every k in [0, A.size(-1)) through a
  // raw device pointer, so a short, strided or host-resident perm would turn
  // into an out-of-bounds / illegal memory access instead of a clean error.
  TORCH_CHECK(perm.scalar_type() == at::kInt, "perm must be an int32 tensor");
  TORCH_CHECK(perm.device() == A.device(),
              "perm must be on the same device as A");
  TORCH_CHECK(perm.is_contiguous(), "perm must be contiguous");
  TORCH_CHECK(perm.numel() == A.size(-1),
              "perm must have one entry per column of A, got ", perm.numel(),
              " entries for ", A.size(-1), " columns");

  auto A_2d = A.view({-1, A.size(-1)});

  torch::Tensor D = torch::empty_like(A);

  // One block per SM; each block processes a contiguous slab of rows.
  int sms = 0;
  cudaError_t const attr_err =
      cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  TORCH_CHECK(attr_err == cudaSuccess && sms > 0,
              "failed to query the multiprocessor count: ",
              cudaGetErrorString(attr_err));
  int block_rows = div_ceil(A_2d.size(0), sms);

  permute_cols_kernel<<<sms, default_threads, 0, stream>>>(
      reinterpret_cast<int4 const*>(A_2d.const_data_ptr()),
      perm.const_data_ptr<int>(),
      reinterpret_cast<int4*>(D.mutable_data_ptr()), A_2d.size(0),
      A_2d.size(1), block_rows);
  return D;
}
\ No newline at end of file
csrc/prepare_inputs/advance_step.cu
View file @
ad385667
...
...
@@ -12,13 +12,22 @@ namespace prepare_inputs {
//
template
<
int
const
num_threads
>
__global__
void
advance_step_kernel
(
int
num_seqs
,
int
num_queries
,
int
block_size
,
long
*
input_tokens_ptr
,
long
const
*
sampled_token_ids_ptr
,
long
*
input_positions_ptr
,
int
*
seq_lens_ptr
,
long
*
slot_mapping_ptr
,
int
const
*
block_tables_ptr
,
int64_t
const
block_tables_stride
)
{
__global__
void
advance_step_flashattn_kernel
(
int
num_seqs
,
int
num_queries
,
int
block_size
,
long
*
input_tokens_ptr
,
long
const
*
sampled_token_ids_ptr
,
long
*
input_positions_ptr
,
int
*
seq_lens_ptr
,
long
*
slot_mapping_ptr
,
int
const
*
block_tables_ptr
,
int64_t
const
block_tables_stride
)
{
int
const
n_pad
=
num_seqs
-
num_queries
;
if
(
n_pad
&&
blockIdx
.
x
==
0
)
{
// Handle cuda graph padding
int
const
offset
=
num_queries
;
for
(
int
i
=
threadIdx
.
x
;
i
<
n_pad
;
i
+=
blockDim
.
x
)
{
input_tokens_ptr
[
offset
+
i
]
=
0
;
input_positions_ptr
[
offset
+
i
]
=
0
;
slot_mapping_ptr
[
offset
+
i
]
=
-
1
;
}
}
int
num_query_blocks
=
div_ceil
(
num_queries
,
num_threads
);
if
(
blockIdx
.
x
>=
num_query_blocks
)
{
...
...
@@ -54,7 +63,7 @@ __global__ void advance_step_kernel(int num_seqs, int num_queries,
slot_mapping_ptr
[
cur_query_id
]
=
slot_num
;
}
inline
void
verify_tensor
(
std
::
string
const
&
name
,
torch
::
Tensor
&
t
,
inline
void
verify_tensor
(
std
::
string
const
&
name
,
torch
::
Tensor
const
&
t
,
int64_t
const
size_0
,
int64_t
const
size_1
,
c10
::
ScalarType
const
type
)
{
bool
size_0_cond
=
true
;
...
...
@@ -79,16 +88,91 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t,
}
}
void
advance_step
(
int
num_seqs
,
int
num_queries
,
int
block_size
,
torch
::
Tensor
&
input_tokens
,
// type: long
torch
::
Tensor
&
sampled_token_ids
,
// type: long
torch
::
Tensor
&
input_positions
,
// type: long
torch
::
Tensor
&
seq_lens
,
// type: int
torch
::
Tensor
&
slot_mapping
,
// type: long
torch
::
Tensor
&
block_tables
)
{
// type: int
// Advances the decode state of each query by one token, one thread per query
// (query id = blockIdx.x * num_threads + threadIdx.x). For query q it writes:
//   input_tokens[q]           <- sampled_token_ids[q]
//   seq_lens[q]               <- seq_lens[q] + 1
//   input_positions[q]        <- old seq_lens[q]
//   paged_kv_last_page_len[q] <- (old seq_lens[q] % block_size) + 1
//   slot_mapping[q]           <- block_tables[q][pos / block_size] * block_size
//                                + pos % block_size, with pos = old seq_lens[q]
//   block_table_bound[q]      <- div_ceil(new seq_lens[q], block_size)
// `num_seqs` is accepted but not used by this kernel.
__global__ void advance_step_flashinfer_kernel(
    int num_threads, int num_seqs, int num_queries, int block_size,
    long* input_tokens_ptr, long const* sampled_token_ids_ptr,
    long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
    int const* block_tables_ptr, int64_t const block_tables_stride,
    int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
  // Blocks past the ones needed to cover all queries have no work.
  int const num_query_blocks = div_ceil(num_queries, num_threads);
  if (blockIdx.x >= num_query_blocks) {
    return;
  }

  int const qid = blockIdx.x * num_threads + threadIdx.x;
  if (qid >= num_queries) {
    return;
  }

  // Update input_tokens
  input_tokens_ptr[qid] = sampled_token_ids_ptr[qid];

  int const new_seq_len = seq_lens_ptr[qid] + 1;
  int const new_pos = new_seq_len - 1;

  // Update seq_lens
  seq_lens_ptr[qid] = new_seq_len;
  // Update input_positions
  input_positions_ptr[qid] = new_pos;

  // Row of the block table that belongs to this query.
  int const* table_row = block_tables_ptr + block_tables_stride * qid;

  int const page_idx = new_pos / block_size;
  int const page_off = new_pos % block_size;

  // Update paged_kv_last_page_len
  paged_kv_last_page_len_ptr[qid] = page_off + 1;

  // Update slot_mapping
  int const slot = table_row[page_idx] * block_size + page_off;
  slot_mapping_ptr[qid] = slot;

  block_table_bound_ptr[qid] = div_ceil(new_seq_len, block_size);
}
// Fills paged_kv_indptr[1 .. num_queries] with the inclusive running sum of
// block_table_bound: paged_kv_indptr[q + 1] = sum(block_table_bound[0..q]).
// Each thread recomputes its own prefix independently. Entry 0 of
// paged_kv_indptr is not written here. `num_seqs` is not used.
__global__ void advance_step_flashinfer_indptr_kernel(
    int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
    int* block_table_bound_ptr) {
  int const qid = blockIdx.x * num_threads + threadIdx.x;
  if (qid >= num_queries) {
    return;
  }

  // Update paged_kv_indptr
  int running_total = 0;
  int const* bound = block_table_bound_ptr;
  int const* const bound_end = block_table_bound_ptr + qid + 1;
  while (bound != bound_end) {
    running_total += *bound;
    ++bound;
  }
  paged_kv_indptr_ptr[qid + 1] = running_total;
}
// Copies the used prefix of every query's block-table row into the flat
// paged_kv_indices array. The global thread id is decomposed into
// (row, col) = (id / block_tables_stride, id % block_tables_stride); the
// thread copies block_tables[row][col] to
// paged_kv_indices[paged_kv_indptr[row] + col] when row is a real query and
// col is below block_table_bound[row].
// Threads whose row falls in (num_queries, num_seqs] additionally overwrite
// paged_kv_indptr[row] with paged_kv_indptr[num_queries].
__global__ void advance_step_flashinfer_indices_kernel(
    int num_threads, int num_seqs, int num_queries,
    int const* block_tables_ptr, int64_t const block_tables_stride,
    int* paged_kv_indices_ptr, int* paged_kv_indptr_ptr,
    int* block_table_bound_ptr) {
  int const flat_id = blockIdx.x * num_threads + threadIdx.x;
  int const row = flat_id / block_tables_stride;
  int const col = flat_id % block_tables_stride;

  // The bound is only read for rows that are real queries (short-circuit).
  bool const in_used_prefix =
      row < num_queries && col < block_table_bound_ptr[row];
  if (in_used_prefix) {
    int const dst = paged_kv_indptr_ptr[row] + col;
    paged_kv_indices_ptr[dst] =
        block_tables_ptr[row * block_tables_stride + col];
  }

  // if cudagraph, fill padded seqs with the last valid seq's indptr
  bool const is_padded_row = num_queries < row && row <= num_seqs;
  if (is_padded_row) {
    paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries];
  }
}
void
advance_step_flashattn
(
int
num_seqs
,
int
num_queries
,
int
block_size
,
torch
::
Tensor
&
input_tokens
,
// type: long
torch
::
Tensor
&
sampled_token_ids
,
// type: long
torch
::
Tensor
&
input_positions
,
// type: long
torch
::
Tensor
&
seq_lens
,
// type: int
torch
::
Tensor
&
slot_mapping
,
// type: long
torch
::
Tensor
&
block_tables
)
{
// type: int
if
(
logging
)
{
printf
(
"advance_step:
\n
"
);
printf
(
"advance_step
_flashattn
:
\n
"
);
printf
(
" num_seqs = %d
\n
"
,
num_seqs
);
printf
(
" num_queries = %d
\n
"
,
num_queries
);
printf
(
" block_size = %d
\n
"
,
block_size
);
...
...
@@ -108,24 +192,126 @@ void advance_step(int num_seqs, int num_queries, int block_size,
int
blocks
;
cudaDeviceGetAttribute
(
&
blocks
,
cudaDevAttrMultiProcessorCount
,
dev
);
advance_step_kernel
<
max_threads
><<<
blocks
,
max_threads
,
0
,
stream
>>>
(
num_seqs
,
num_queries
,
block_size
,
advance_step_flashattn_kernel
<
max_threads
>
<<<
blocks
,
max_threads
,
0
,
stream
>>>
(
num_seqs
,
num_queries
,
block_size
,
reinterpret_cast
<
long
*>
(
input_tokens
.
data_ptr
()),
reinterpret_cast
<
long
const
*>
(
sampled_token_ids
.
data_ptr
()),
reinterpret_cast
<
long
*>
(
input_positions
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
seq_lens
.
data_ptr
()),
reinterpret_cast
<
long
*>
(
slot_mapping
.
data_ptr
()),
reinterpret_cast
<
int
const
*>
(
block_tables
.
data_ptr
()),
block_tables
.
stride
(
0
));
}
void
advance_step_flashinfer
(
int
num_seqs
,
int
num_queries
,
int
block_size
,
torch
::
Tensor
&
input_tokens
,
// type: long
torch
::
Tensor
&
sampled_token_ids
,
// type: long
torch
::
Tensor
&
input_positions
,
// type: long
torch
::
Tensor
&
seq_lens
,
// type: int
torch
::
Tensor
&
slot_mapping
,
// type: long
torch
::
Tensor
&
block_tables
,
// type: int
torch
::
Tensor
&
paged_kv_indices
,
// type: int
torch
::
Tensor
&
paged_kv_indptr
,
// type: int
torch
::
Tensor
&
paged_kv_last_page_len
,
// type: int
torch
::
Tensor
&
block_table_bound
)
{
// type: int
if
(
logging
)
{
printf
(
"advance_step_flashinfer:
\n
"
);
printf
(
" num_seqs = %d
\n
"
,
num_seqs
);
printf
(
" num_queries = %d
\n
"
,
num_queries
);
printf
(
" block_size = %d
\n
"
,
block_size
);
printf
(
" block_tables.stride(0) = %zu
\n
"
,
block_tables
.
stride
(
0
));
}
// Verify all tensors
verify_tensor
(
"input_tokens"
,
input_tokens
,
num_seqs
,
-
1
,
at
::
kLong
);
// verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
// at::kLong);
verify_tensor
(
"input_positions"
,
input_positions
,
num_seqs
,
-
1
,
at
::
kLong
);
verify_tensor
(
"seq_lens"
,
seq_lens
,
num_seqs
,
-
1
,
at
::
kInt
);
verify_tensor
(
"slot_mapping"
,
slot_mapping
,
num_seqs
,
-
1
,
at
::
kLong
);
verify_tensor
(
"block_tables"
,
block_tables
,
num_seqs
,
-
1
,
at
::
kInt
);
verify_tensor
(
"paged_kv_indices"
,
paged_kv_indices
,
-
1
,
-
1
,
at
::
kInt
);
verify_tensor
(
"paged_kv_indptr"
,
paged_kv_indptr
,
num_seqs
+
1
,
-
1
,
at
::
kInt
);
verify_tensor
(
"paged_kv_last_page_len"
,
paged_kv_last_page_len
,
num_seqs
,
-
1
,
at
::
kInt
);
verify_tensor
(
"block_table_bound"
,
block_table_bound
,
num_seqs
,
-
1
,
at
::
kInt
);
int
dev
=
sampled_token_ids
.
get_device
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
dev
);
int
blocks
;
int
threads
;
cudaDeviceGetAttribute
(
&
blocks
,
cudaDevAttrMultiProcessorCount
,
dev
);
cudaDeviceGetAttribute
(
&
threads
,
cudaDevAttrMaxThreadsPerBlock
,
dev
);
if
(
logging
)
{
printf
(
"launching kernel with %d blocks
\n
"
,
blocks
);
}
// TODO(will): support arbitrary block_tables stride
if
((
blocks
*
threads
)
/
block_tables
.
stride
(
0
)
<
num_queries
)
{
TORCH_CHECK
(
false
,
"multi-step: not enough threads to map block_table to"
"FlashInfer's paged_kv_indices on GPU. Try reducing the number "
"of seqs,"
,
" increasing the block size or take smaller steps."
,
" num_queries = "
,
num_queries
,
" block_tables.stride(0) = "
,
block_tables
.
stride
(
0
),
" blocks = "
,
blocks
,
" max_threads = "
,
threads
);
}
advance_step_flashinfer_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
threads
,
num_seqs
,
num_queries
,
block_size
,
reinterpret_cast
<
long
*>
(
input_tokens
.
data_ptr
()),
reinterpret_cast
<
long
const
*>
(
sampled_token_ids
.
data_ptr
()),
reinterpret_cast
<
long
*>
(
input_positions
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
seq_lens
.
data_ptr
()),
reinterpret_cast
<
long
*>
(
slot_mapping
.
data_ptr
()),
reinterpret_cast
<
int
const
*>
(
block_tables
.
data_ptr
()),
block_tables
.
stride
(
0
));
block_tables
.
stride
(
0
),
reinterpret_cast
<
int
*>
(
paged_kv_last_page_len
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
block_table_bound
.
data_ptr
()));
advance_step_flashinfer_indptr_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
threads
,
num_seqs
,
num_queries
,
reinterpret_cast
<
int
*>
(
paged_kv_indptr
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
block_table_bound
.
data_ptr
()));
advance_step_flashinfer_indices_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
threads
,
num_seqs
,
num_queries
,
reinterpret_cast
<
int
const
*>
(
block_tables
.
data_ptr
()),
block_tables
.
stride
(
0
),
reinterpret_cast
<
int
*>
(
paged_kv_indices
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
paged_kv_indptr
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
block_table_bound
.
data_ptr
()));
}
}
// namespace prepare_inputs
void
advance_step
(
int64_t
num_seqs
,
int64_t
num_queries
,
int64_t
block_size
,
torch
::
Tensor
&
input_tokens
,
torch
::
Tensor
&
sampled_token_ids
,
torch
::
Tensor
&
input_positions
,
torch
::
Tensor
&
seq_lens
,
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
block_tables
)
{
prepare_inputs
::
advance_step
(
num_seqs
,
num_queries
,
block_size
,
input_tokens
,
sampled_token_ids
,
input_positions
,
seq_lens
,
slot_mapping
,
block_tables
);
}
\ No newline at end of file
// Global-namespace entry point for the FlashAttention multi-step advance.
// Forwards every argument to prepare_inputs::advance_step_flashattn; the
// three int64_t scalars are converted to the `int` parameters that function
// takes.
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
                            int64_t block_size, torch::Tensor& input_tokens,
                            torch::Tensor& sampled_token_ids,
                            torch::Tensor& input_positions,
                            torch::Tensor& seq_lens,
                            torch::Tensor& slot_mapping,
                            torch::Tensor& block_tables) {
  prepare_inputs::advance_step_flashattn(
      static_cast<int>(num_seqs), static_cast<int>(num_queries),
      static_cast<int>(block_size), input_tokens, sampled_token_ids,
      input_positions, seq_lens, slot_mapping, block_tables);
}
// Global-namespace entry point for the FlashInfer multi-step advance.
// Forwards every argument to prepare_inputs::advance_step_flashinfer; the
// three int64_t scalars are converted to the `int` parameters that function
// takes.
void advance_step_flashinfer(
    int64_t num_seqs, int64_t num_queries, int64_t block_size,
    torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
    torch::Tensor& input_positions, torch::Tensor& seq_lens,
    torch::Tensor& slot_mapping, torch::Tensor& block_tables,
    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) {
  prepare_inputs::advance_step_flashinfer(
      static_cast<int>(num_seqs), static_cast<int>(num_queries),
      static_cast<int>(block_size), input_tokens, sampled_token_ids,
      input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
      paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
}
csrc/quantization/aqlm/gemm_kernels.cu
View file @
ad385667
...
...
@@ -496,14 +496,14 @@ torch::Tensor code2x8_matmat(const torch::Tensor& input,
}
// Accumulate the partition sizes.
int4
accumulate_sizes
(
const
torch
::
Tensor
&
codebook_partition_sizes
)
{
int4
accumulate_sizes
(
const
std
::
vector
<
int64_t
>
&
codebook_partition_sizes
)
{
int4
cumulative_sizes
;
auto
cumulative_size
=
&
cumulative_sizes
.
x
;
in
t
i
=
0
;
size_
t
i
=
0
;
int
last
=
0
;
assert
(
codebook_partition_sizes
.
size
(
0
)
<=
4
);
for
(;
i
<
codebook_partition_sizes
.
size
(
0
);
++
i
,
++
cumulative_size
)
{
*
cumulative_size
=
codebook_partition_sizes
[
i
]
.
item
<
int
>
()
+
last
;
assert
(
codebook_partition_sizes
.
size
()
<=
4
);
for
(;
i
<
codebook_partition_sizes
.
size
();
++
i
,
++
cumulative_size
)
{
*
cumulative_size
=
codebook_partition_sizes
[
i
]
+
last
;
last
=
*
cumulative_size
;
}
// fill in the rest with unreachable.
...
...
@@ -519,12 +519,12 @@ int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
scales
,
const
torch
::
Tensor
&
codebook_partition_sizes
,
const
std
::
vector
<
int64_t
>
&
codebook_partition_sizes
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
)
{
int4
cumulative_sizes
=
vllm
::
aqlm
::
accumulate_sizes
(
codebook_partition_sizes
);
int
const
nbooks
=
codebooks
.
size
(
0
)
/
codebook_partition_sizes
.
size
(
0
);
int
const
nbooks
=
codebooks
.
size
(
0
)
/
codebook_partition_sizes
.
size
();
int
const
entries
=
codebooks
.
size
(
1
);
if
(
nbooks
==
1
&&
entries
==
(
1
<<
16
))
{
...
...
@@ -541,13 +541,13 @@ torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
return
{};
}
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
codebook_partition_sizes
)
{
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
std
::
vector
<
int64_t
>
&
codebook_partition_sizes
)
{
int4
cumulative_sizes
=
vllm
::
aqlm
::
accumulate_sizes
(
codebook_partition_sizes
);
int
const
nbooks
=
codebooks
.
size
(
0
)
/
codebook_partition_sizes
.
size
(
0
);
int
const
nbooks
=
codebooks
.
size
(
0
)
/
codebook_partition_sizes
.
size
();
int
const
entries
=
codebooks
.
size
(
1
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
codes
));
...
...
@@ -557,7 +557,8 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes,
auto
in_features
=
codes
.
size
(
1
)
*
8
;
auto
out_features
=
codes
.
size
(
0
);
assert
(
out_features
=
codebook_partition_sizes
.
sum
().
item
<
int
>
());
assert
(
out_features
==
std
::
accumulate
(
codebook_partition_sizes
.
begin
(),
codebook_partition_sizes
.
end
(),
0
));
auto
weights
=
torch
::
empty
({
out_features
,
in_features
},
torch
::
TensorOptions
()
...
...
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
View file @
ad385667
...
...
@@ -3,16 +3,28 @@
#include <cmath>
#include "../../dispatch_utils.h"
#include "../../reduction_utils.cuh"
#ifndef USE_ROCM
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#endif
static
inline
__device__
int8_t
float_to_int8_rn
(
float
x
)
{
#ifdef USE_ROCM
static
const
float
i8_min
=
static
const
expr
auto
i8_min
=
static_cast
<
float
>
(
std
::
numeric_limits
<
int8_t
>::
min
());
static
const
float
i8_max
=
static
const
expr
auto
i8_max
=
static_cast
<
float
>
(
std
::
numeric_limits
<
int8_t
>::
max
());
// round
// To match the rounding mode of CUDA, we use nearbyint.
// It uses the current rounding mode, which is always FE_TONEAREST on HIP.
// If that changes in the future, we may need to set the rounding mode
// explicitly, either at runtime or compile time.
float
dst
=
std
::
nearbyint
(
x
);
// saturate
dst
=
std
::
clamp
(
dst
,
i8_min
,
i8_max
);
return
static_cast
<
int8_t
>
(
dst
);
...
...
@@ -24,6 +36,59 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
#endif
}
// Converts a float to int32_t, rounding to the nearest integer and saturating
// at the int32_t limits instead of overflowing.
static inline __device__ int32_t float_to_int32_rn(float x) {
#ifdef USE_ROCM
  // int32_max has no exact float representation, so overflow is detected by
  // comparing against its float image and the integer limit is returned
  // directly. int32_min is exactly representable, but is treated the same way
  // for symmetry.
  static constexpr auto kInt32Lo = std::numeric_limits<int32_t>::min();
  static constexpr auto kInt32Hi = std::numeric_limits<int32_t>::max();
  static constexpr auto kInt32LoF = static_cast<float>(kInt32Lo);
  static constexpr auto kInt32HiF = static_cast<float>(kInt32Hi);

  // nearbyint is used to match the rounding mode of CUDA. It follows the
  // current rounding mode, which is always FE_TONEAREST on HIP. If that
  // changes in the future, the rounding mode may need to be set explicitly,
  // either at runtime or compile time.
  float const rounded = std::nearbyint(x);

  // Saturate on the higher end.
  if (rounded >= kInt32HiF) {
    return kInt32Hi;
  }
  // Saturate on the lower end.
  if (rounded <= kInt32LoF) {
    return kInt32Lo;
  }
  return static_cast<int32_t>(rounded);
#else
  // CUDA path: a single PTX convert with the .rni and .sat modifiers.
  uint32_t bits;
  asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(bits) : "f"(x));
  return reinterpret_cast<const int32_t&>(bits);
#endif
}
// Narrows an int32_t to int8_t, saturating values that fall outside the
// int8_t range.
static inline __device__ int8_t int32_to_int8(int32_t x) {
#ifdef USE_ROCM
  static constexpr int32_t kInt8Lo = std::numeric_limits<int8_t>::min();
  static constexpr int32_t kInt8Hi = std::numeric_limits<int8_t>::max();

  // Saturate, then narrow.
  return static_cast<int8_t>(std::clamp(x, kInt8Lo, kInt8Hi));
#else
  // CUDA path: a single PTX convert with the .sat modifier. The result is
  // written to a 32-bit register and reinterpreted as int8_t.
  uint32_t bits;
  asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(bits) : "r"(x));
  return reinterpret_cast<const int8_t&>(bits);
#endif
}
namespace
vllm
{
template
<
typename
scalar_t
,
typename
scale_type
>
...
...
@@ -31,12 +96,36 @@ __global__ void static_scaled_int8_quant_kernel(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
const
*
scale_ptr
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
int64_t
const
token_idx
=
blockIdx
.
x
;
scale_type
const
scale
=
*
scale_ptr
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
out
[
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
i
])
/
scale
);
}
}
// Static (precomputed scale / zero point) asymmetric int8 quantization.
//
// Each block handles the row selected by blockIdx.x; the threads of the block
// stride over the hidden_size elements of that row.
//
//   input       source values; hidden_size elements are read per row
//   out         int8 destination; hidden_size elements are written per row
//   scale_ptr   points to the single scale value used for every element
//   azp_ptr     points to the single zero point added after rounding
//   hidden_size number of elements per row
template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void static_scaled_int8_azp_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type const* scale_ptr, azp_type const* azp_ptr,
    const int hidden_size) {
  int const tid = threadIdx.x;
  int64_t const token_idx = blockIdx.x;
  scale_type const scale = *scale_ptr;
  azp_type const azp = *azp_ptr;

  // Must be performed using 64-bit math to avoid integer overflow.
  out += token_idx * hidden_size;
  input += token_idx * hidden_size;

  // From here on `input` and `out` already point at this row, so elements are
  // addressed with the in-row index only. Adding token_idx * hidden_size again
  // would apply the row offset twice and access memory past the row.
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    auto const val = static_cast<float>(input[i]);
    // Round val / scale to int32, shift by the zero point, then saturate to
    // the int8 range.
    auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
    out[i] = quant_val;
  }
}
...
...
@@ -45,17 +134,24 @@ __global__ void dynamic_scaled_int8_quant_kernel(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
*
scale
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
int
64_t
const
token_idx
=
blockIdx
.
x
;
float
absmax_val
=
0.0
f
;
float
const
zero
=
0.0
f
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
float
val
=
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
]);
float
val
=
static_cast
<
float
>
(
input
[
i
]);
val
=
val
>
zero
?
val
:
-
val
;
absmax_val
=
val
>
absmax_val
?
val
:
absmax_val
;
}
float
const
block_absmax_val_maybe
=
blockReduceMax
(
absmax_val
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStorage
;
float
const
block_absmax_val_maybe
=
BlockReduce
(
reduceStorage
).
Reduce
(
absmax_val
,
cub
::
Max
{},
blockDim
.
x
);
__shared__
float
block_absmax_val
;
if
(
tid
==
0
)
{
block_absmax_val
=
block_absmax_val_maybe
;
...
...
@@ -65,8 +161,63 @@ __global__ void dynamic_scaled_int8_quant_kernel(
float
const
tmp_scale
=
127.0
f
/
block_absmax_val
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
out
[
token_idx
*
hidden_size
+
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
token_idx
*
hidden_size
+
i
])
*
tmp_scale
);
out
[
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
i
])
*
tmp_scale
);
}
}
// Dynamic asymmetric int8 quantization: the scale and zero point are computed
// per row from that row's min and max.
//
// Each block handles the row selected by blockIdx.x.
//
//   input       source values; hidden_size elements are read per row
//   out         int8 destination; hidden_size elements are written per row
//   scale       output, one scale written at index blockIdx.x
//   azp         output, one zero point written at index blockIdx.x
//   hidden_size number of elements per row
template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void dynamic_scaled_int8_azp_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type* scale, azp_type* azp, const int hidden_size) {
  int64_t const token_idx = blockIdx.x;

  // Must be performed using 64-bit math to avoid integer overflow.
  out += token_idx * hidden_size;
  input += token_idx * hidden_size;

  // Scan for the min and max value for this token.
  // The running maximum must start at lowest() (the most negative finite
  // float). numeric_limits<float>::min() is the smallest *positive* normal
  // value, which would be reported as the maximum of an all-negative row.
  float max_val = std::numeric_limits<float>::lowest();
  float min_val = std::numeric_limits<float>::max();
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    auto val = static_cast<float>(input[i]);
    max_val = std::max(max_val, val);
    min_val = std::min(min_val, val);
  }

  // Reduce the max and min values across the block. Only thread 0 consumes
  // the reduced results below.
  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStorage;
  max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
  __syncthreads();  // Make sure min doesn't mess with max shared memory
  min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);

  __shared__ scale_type scale_sh;
  __shared__ azp_type azp_sh;

  // Compute the scale and zero point and store them, only on the first thread
  if (threadIdx.x == 0) {
    // NOTE(review): a row whose elements are all equal gives scale_val == 0
    // and a division by zero in the zero-point computation below - confirm
    // whether callers can produce such rows.
    float const scale_val = (max_val - min_val) / 255.0f;
    // Use rounding to even (same as torch.round)
    auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val);
    auto const azp_val = static_cast<azp_type>(azp_float);

    // Store the scale and azp into shared and global
    scale[token_idx] = scale_sh = scale_val;
    azp[token_idx] = azp_sh = azp_val;
  }

  // Wait for the scale and azp to be computed
  __syncthreads();

  float const scale_val = scale_sh;
  azp_type const azp_val = azp_sh;

  // Quantize the values
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    auto const val = static_cast<float>(input[i]);
    auto const quant_val =
        int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
    out[i] = quant_val;
  }
}
...
...
@@ -74,10 +225,12 @@ __global__ void dynamic_scaled_int8_quant_kernel(
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
const
&
input
,
// [..., hidden_size]
torch
::
Tensor
const
&
scale
)
{
torch
::
Tensor
const
&
scale
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
)
{
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
scale
.
numel
()
==
1
);
TORCH_CHECK
(
!
azp
||
azp
->
numel
()
==
1
);
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
...
...
@@ -86,19 +239,29 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"static_scaled_int8_quant_kernel"
,
[
&
]
{
vllm
::
static_scaled_int8_quant_kernel
<
scalar_t
,
float
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scale
.
data_ptr
<
float
>
(),
hidden_size
);
if
(
!
azp
)
{
vllm
::
static_scaled_int8_quant_kernel
<
scalar_t
,
float
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scale
.
data_ptr
<
float
>
(),
hidden_size
);
}
else
{
vllm
::
static_scaled_int8_azp_quant_kernel
<
scalar_t
,
float
,
int32_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scale
.
data_ptr
<
float
>
(),
azp
->
data_ptr
<
int32_t
>
(),
hidden_size
);
}
});
}
void
dynamic_scaled_int8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
const
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
scales
)
{
torch
::
Tensor
&
scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
)
{
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
scales
.
is_contiguous
());
TORCH_CHECK
(
!
azp
||
azp
->
is_contiguous
());
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
...
...
@@ -107,9 +270,17 @@ void dynamic_scaled_int8_quant(
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"dynamic_scaled_int8_quant_kernel"
,
[
&
]
{
vllm
::
dynamic_scaled_int8_quant_kernel
<
scalar_t
,
float
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scales
.
data_ptr
<
float
>
(),
hidden_size
);
if
(
!
azp
)
{
vllm
::
dynamic_scaled_int8_quant_kernel
<
scalar_t
,
float
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scales
.
data_ptr
<
float
>
(),
hidden_size
);
}
else
{
vllm
::
dynamic_scaled_int8_azp_quant_kernel
<
scalar_t
,
float
,
int32_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scales
.
data_ptr
<
float
>
(),
azp
->
data_ptr
<
int32_t
>
(),
hidden_size
);
}
});
}
csrc/quantization/cutlass_w8a8/Epilogues.md
0 → 100644
View file @
ad385667
# CUTLASS Epilogues
## Introduction
This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs.
Currently, we only support symmetric quantization for weights,
and symmetric and asymmetric quantization for activations.
Both can be quantized per-tensor or per-channel (weights) / per-token (activations).
There are 4 epilogues:
1. ScaledEpilogue: symmetric quantization for activations, no bias.
1. ScaledEpilogueBias: symmetric quantization for activations, supports bias.
1. ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias.
1. ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias.
We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size.
Instead, if no bias is passed, the epilogue will use 0 as the bias.
That induces a redundant addition operation (and runtime check), but the performance impact is minor.
## Underlying Linear Algebra
More details are available in the [Activation Quantization RFC](https://github.com/vllm-project/vllm/issues/3975).
If $` \widehat X `$ is the quantized $` X `$, our matrices become the following:
```math
A = s_a (\widehat A - J_a z_a)
```
```math
B = s_b \widehat B
```
```math
D = A B + C
```
```math
D = s_a s_b \widehat D + C
```
Here, D is the output of the GEMM, and C is the bias.
A is the activations and supports asymmetric quantization,
and B is the weights and only supports symmetric quantization.
$ s_a $ and $s_b$ are the scales for activations and weights, respectively.
$ z_a $ is the zero-point for activations, and $ J_a $ is the matrix of all ones with dimensions of A.
Additional epilogues would be required to support asymmetric quantization for weights.
Expanding further, we can calculate $` \widehat D `$ as follows:
```math
A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B
```
```math
A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right)
```
```math
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
```
Note that $` \widehat A \widehat B `$ is the raw output of the GEMM,
and $` J_a \widehat B `$ is known ahead of time.
Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of column sums of $` \widehat B `$.
## Epilogues
### ScaledEpilogue
This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$.
The output of the GEMM is:
```math
\widehat D = \widehat A \widehat B
```
```math
D = s_a s_b \widehat D
```
```math
D = s_a s_b \widehat A \widehat B
```
Epilogue parameters:
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
### ScaledEpilogueBias
This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$.
The output of the GEMM is:
```math
\widehat D = \widehat A \widehat B
```
```math
D = s_a s_b \widehat D + C
```
```math
D = s_a s_b \widehat A \widehat B + C
```
Epilogue parameters:
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
- `bias` is the bias, is always per-channel (row-vector).
### ScaledEpilogueAzp
This epilogue computes the asymmetric per-tensor quantization for activations with bias.
The output of the GEMM is:
```math
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
```
```math
D = s_a s_b \widehat D + C
```
```math
D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C
```
Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 \widehat B `$.
That is precomputed and stored in `azp_with_adj` as a row-vector.
Epilogue parameters:
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
  - Generally this will be per-tensor as the zero-points are per-tensor.
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
- `azp_with_adj` is the precomputed zero-point term ($` z_a J_a \widehat B `$), is per-channel (row-vector).
- `bias` is the bias, is always per-channel (row-vector).
To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel.
### ScaledEpilogueAzpPerToken
This epilogue computes the asymmetric per-token quantization for activations with bias.
The output of the GEMM is the same as above, but the $` z_a `$ is a column-vector.
That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$.
Epilogue parameters:
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
  - Generally this will be per-token as the zero-points are per-token.
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
- `azp_adj` is the precomputed zero-point adjustment term ($` \mathbf 1 \widehat B `$), is per-channel (row-vector).
- `azp` is the zero-point (`z_a`), is per-token (column-vector).
- `bias` is the bias, is always per-channel (row-vector).
To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel.
The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM):
```
out = scale_a * scale_b * (Dq - azp_adj * azp) + bias
```
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp
View file @
ad385667
...
...
@@ -207,6 +207,156 @@ struct VisitorRowOrScalarBroadcast {
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Row-vector broadcast visitor, modified from RowBroadcast: when ptr_row is
// null it broadcasts 0 instead of loading the row from memory.
template<
  class ThreadMap,
  class Element,
  class StrideMNL
>
struct VisitorRowOrZeroBroadcast {

  // There is no null_default member here, because the fallback value is
  // always 0.
  struct Arguments {
    Element const* ptr_row = nullptr;
    StrideMNL dRow = {};
  };

  using Params = Arguments;

  // Params and Arguments are the same type, so the arguments pass through.
  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  // This visitor needs no workspace.
  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  struct SharedStorage {};

  // Global load type: one access worth of elements, capped at 128 bits.
  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
  using VecType = uint_bit_t<cute::min(128, vec_bits)>;
  static int constexpr VecLength = sizeof(VecType) / sizeof(Element);

  CUTLASS_HOST_DEVICE
  VisitorRowOrZeroBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gRow,
      RTensor&& tC_rRow,
      CTensor&& tC_cRow,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gRow(cute::forward<GTensor>(tC_gRow)),
      tC_rRow(cute::forward<RTensor>(tC_rRow)),
      tC_cRow(cute::forward<CTensor>(tC_cRow)),
      params_ptr(params_ptr),
      n(get<1>(problem_shape)) { }

    GTensor tC_gRow;
    RTensor tC_rRow;
    CTensor tC_cRow;
    Params const* params_ptr;
    int n;

    // This function is modified from VisitorRowBroadcast.
    // It fills tC_rRow once per epilogue, either from the row vector or with
    // zeros. Only entries whose column coordinate is below n are written.
    CUTLASS_DEVICE void
    begin_epilogue() {
      clear(tC_rRow);
      auto gmem_vecs = filter(tC_gRow);
      auto coords = filter(tC_cRow);
      auto local_vecs = filter(tC_rRow);

      if (params_ptr->ptr_row == nullptr) {
        // No row vector was supplied: broadcast 0.
        VecType zero_vec;
        CUTLASS_PRAGMA_UNROLL
        for (int e = 0; e < VecLength; e++) {
          reinterpret_cast<Element*>(&zero_vec)[e] = Element{0};
        }

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < size(gmem_vecs); ++v) {
          if (get<1>(coords(v)) < n) {
            local_vecs(v) = zero_vec;
          }
        }
        return;
      }

      // A row vector was supplied: load it (guarded) and broadcast.
      CUTLASS_PRAGMA_UNROLL
      for (int v = 0; v < size(gmem_vecs); ++v) {
        bool in_bounds = get<1>(coords(v)) < n;
        cutlass::arch::global_load<VecType, sizeof(VecType)>(
            local_vecs(v), (void const*)&gmem_vecs(v), in_bounds);
      }
    }

    // Returns the fragment of the broadcast row selected by column_idx. The
    // accumulator fragment and the other indices are not used.
    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto // returns an Array
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
      return rRow_frg(column_idx);
    }
  };

  // Builds the per-thread Callbacks: the partitioned view of the row vector,
  // a same-shaped local tensor to hold it, and the coordinate tensor used for
  // bounds predication.
  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    Tensor mRow = make_tensor(
      make_gmem_ptr(params_ptr->ptr_row),
      problem_shape,
      params_ptr->dRow);

    // VECTOR, FRAGMENT_COLUMN
    Tensor tC_gRow = recast<VecType>(
      ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
    )(_,_,_0{},_0{},_0{},_0{});
    Tensor tC_rRow = make_tensor_like(tC_gRow);

    // Generate the pred tensor
    Tensor cRow = make_identity_tensor(mRow.shape());
    Tensor tC_cRow = outer_partition(
      ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
      Shape<Int<VecLength>>{},
      (_0{})
    );

    return Callbacks<
      decltype(tC_gRow), decltype(tC_rRow),
      decltype(tC_cRow), ProblemShape>(
      cute::move(tC_gRow),
      cute::move(tC_rRow),
      cute::move(tC_cRow),
      problem_shape,
      params_ptr
    );
  }

};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Column vector broadcast
...
...
@@ -217,7 +367,7 @@ template<
>
struct
VisitorColOrScalarBroadcast
{
// This struct has been modified to have a bool indicating that ptr_col is a
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast.
struct
Arguments
{
Element
const
*
ptr_col
=
nullptr
;
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
View file @
ad385667
...
...
@@ -50,6 +50,25 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
}
}
// Validates the scale dtypes and forwards to the SM75 epilogue launcher.
// With an `azp` tensor the ScaledEpilogueBiasAzpToken epilogue is used and
// `*azp` is passed along; without one, ScaledEpilogueBiasAzp is used.
void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (!azp) {
    // No separate zero-point tensor: only azp_adj is forwarded.
    return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
  return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
template
<
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm80_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
...
...
@@ -87,6 +106,25 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
}
}
// Validates the scale dtypes and forwards to the SM80 epilogue launcher.
// With an `azp` tensor the ScaledEpilogueBiasAzpToken epilogue is used and
// `*azp` is passed along; without one, ScaledEpilogueBiasAzp is used.
void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (!azp) {
    // No separate zero-point tensor: only azp_adj is forwarded.
    return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
  return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
template
<
template
<
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_mm_sm89_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
...
...
@@ -139,3 +177,22 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
// Validates the scale dtypes and forwards to the SM89 epilogue launcher.
// With an `azp` tensor the ScaledEpilogueBiasAzpToken epilogue is used and
// `*azp` is passed along; without one, ScaledEpilogueBiasAzp is used.
void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (!azp) {
    // No separate zero-point tensor: only azp_adj is forwarded.
    return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
  return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
View file @
ad385667
...
...
@@ -73,19 +73,63 @@ struct enable_sm89_to_sm90 : Kernel {
};
/*
* This class provides the common
ScaleA and ScaleB
descriptors for the
* ScaledEpilogue
and ScaledEpilogueBias
classes
.
* This class provides the common
load
descriptors for the
* ScaledEpilogue
[...]
classes
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBase
{
protected:
using
Accum
=
cutlass
::
epilogue
::
threadblock
::
VisitorAccFetch
;
using
ScaleA
=
cutlass
::
epilogue
::
threadblock
::
VisitorColOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
ColOrScalarLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorColOrScalarBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
RowOrScalarLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowOrScalarBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
ColLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorColBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
RowLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
RowOrZeroLoad
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowOrZeroBroadcast
<
OutputTileThreadMap
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
torch
::
Tensor
const
&
tensor
)
{
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
static_cast
<
T
*>
(
tensor
.
data_ptr
());
if
constexpr
(
std
::
is_same_v
<
Descriptor
,
ColOrScalarLoad
<
T
>>
||
std
::
is_same_v
<
Descriptor
,
RowOrScalarLoad
<
T
>>
)
{
return
Arguments
{
data_ptr
,
tensor
.
numel
()
!=
1
};
}
else
{
// it would technically work but no use case as data_ptr is never nullptr
static_assert
(
!
std
::
is_same_v
<
Descriptor
,
RowOrZeroLoad
<
T
>>
);
return
Arguments
{
data_ptr
};
}
}
using
ScaleB
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
c10
::
optional
<
torch
::
Tensor
>
const
&
tensor
)
{
static_assert
(
std
::
is_same_v
<
Descriptor
,
RowOrZeroLoad
<
T
>>
);
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
tensor
?
static_cast
<
T
*>
(
tensor
->
data_ptr
())
:
nullptr
;
return
Arguments
{
data_ptr
};
}
};
/*
...
...
@@ -110,8 +154,8 @@ struct ScaledEpilogue
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>
;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>
;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
...
...
@@ -131,28 +175,32 @@ struct ScaledEpilogue
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
ScaleAArgs
=
typename
ScaleA
::
Arguments
;
using
ScaleBArgs
=
typename
ScaleB
::
Arguments
;
ScaleBArgs
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
ScaleAArgs
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
typename
EVTCompute0
::
Arguments
evt0_compute_args
{
b_args
};
typename
EVTCompute
::
Arguments
evt_compute_args
{
a_args
,
evt0_compute_args
};
return
evt_compute_args
;
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template
<
typename
ElementD
,
typename
OutputTileThreadMap
>
struct
ScaledEpilogueBias
:
pr
ivate
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
pr
ivate
:
:
pr
otected
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
{
pr
otected
:
using
SUPER
=
ScaledEpilogueBase
<
ElementD
,
OutputTileThreadMap
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>
;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>
;
using
Bias
=
typename
SUPER
::
template
RowLoad
<
ElementD
>;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
...
...
@@ -164,30 +212,163 @@ struct ScaledEpilogueBias
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
Bias
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowBroadcast
<
OutputTileThreadMap
,
ElementD
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
bias
)
{
using
ScaleAArgs
=
typename
ScaleA
::
Arguments
;
using
ScaleBArgs
=
typename
ScaleB
::
Arguments
;
using
BiasArgs
=
typename
Bias
::
Arguments
;
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;

  // Leaf nodes of the visitor tree: accumulator plus the broadcast loads.
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // Pre-multiplied zero-point correction, azp * J @ B, shape (1,n).
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Innermost node: float(accum - azp_adj); both inputs are int32_t.
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;

  // Middle node: multiplies the ScaleB load with the corrected accumulator.
  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAzp>;

  // Root node: multiply_add over (ScaleA, scaled accumulator, Bias),
  // producing ElementD.
  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Assembles the nested argument packs, innermost node first. Each pack
  // lists its members in the same order as the children of its EVT node.
  static ArgumentType prepare_args(
      torch::Tensor const& a_scales, torch::Tensor const& b_scales,
      torch::Tensor const& azp_adj,
      c10::optional<torch::Tensor> const& bias) {
    // Innermost node: the accumulator takes no arguments, hence `{}`.
    auto adj_load_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
    typename EVTComputeAzp::Arguments azp_node_args{{}, adj_load_args};

    // Middle node: ScaleB load + corrected accumulator.
    auto scale_b_load_args =
        SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    typename EVTComputeScaleB::Arguments scale_b_node_args{scale_b_load_args,
                                                           azp_node_args};

    // Root node: ScaleA load, middle node, (optional) bias load.
    auto scale_a_load_args =
        SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto bias_load_args =
        SUPER::template args_from_tensor<Bias, ElementD>(bias);
    return ArgumentType{scale_a_load_args, scale_b_node_args, bias_load_args};
  }
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // Per-token azp term, shape (m,1)
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // This is the AZP adjustment term, J @ B, shape (1,n)
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute azp * azp_adj
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;

  // Compute float(accum - azp*azp_adj), all operands are int32_t
  using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAcc =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum,
                                              EVTComputeAzp>;

  // Multiplies the ScaleB load with the corrected accumulator.
  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAcc>;

  // Root node: multiply_add over (ScaleA, scaled accumulator, Bias),
  // producing ElementD.
  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Builds the nested argument packs for EVTCompute. Each pack lists its
  // members in the same order as the children of the corresponding EVT node;
  // the accumulator node takes no arguments and is passed as `{}`.
  static ArgumentType prepare_args(
      torch::Tensor const& a_scales, torch::Tensor const& b_scales,
      torch::Tensor const& azp_adj, torch::Tensor const& azp,
      c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args,
                                                          evt_acc_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
View file @
ad385667
...
...
@@ -58,21 +58,63 @@ struct enable_sm90_or_later : Kernel {
};
/*
 * This class provides the common load descriptors (including the ScaleA and
 * ScaleB loads) for the ScaledEpilogue[...] classes.
 */
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
struct
ScaledEpilogueBase
{
protected:
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
using
ScaleA
=
cutlass
::
epilogue
::
fusion
::
Sm90ColOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
float
,
template
<
typename
T
>
using
ColOrScalarLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
using
ScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90RowOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
float
,
template
<
typename
T
>
using
RowOrScalarLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
// Don't want to support nullptr by default
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
using
ColLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
// Don't want to support nullptr by default
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
using
RowLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
torch
::
Tensor
const
&
tensor
)
{
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
static_cast
<
T
*>
(
tensor
.
data_ptr
());
if
constexpr
(
std
::
is_same_v
<
Descriptor
,
ColOrScalarLoad
<
T
>>
||
std
::
is_same_v
<
Descriptor
,
RowOrScalarLoad
<
T
>>
)
{
return
Arguments
{
data_ptr
,
tensor
.
numel
()
!=
1
};
}
else
{
static_assert
(
!
std
::
is_same_v
<
Descriptor
,
ColLoad
<
T
,
true
>>
&&
!
std
::
is_same_v
<
Descriptor
,
RowLoad
<
T
,
true
>>
);
return
Arguments
{
data_ptr
};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
c10
::
optional
<
torch
::
Tensor
>
const
&
tensor
)
{
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
tensor
?
static_cast
<
T
*>
(
tensor
->
data_ptr
())
:
nullptr
;
static_assert
(
std
::
is_same_v
<
Descriptor
,
ColLoad
<
T
,
true
>>
||
std
::
is_same_v
<
Descriptor
,
RowLoad
<
T
,
true
>>
);
return
Arguments
{
data_ptr
};
}
};
/*
...
...
@@ -97,8 +139,8 @@ struct ScaledEpilogue
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>
;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>
;
using
Compute0
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
float
,
float
,
...
...
@@ -118,24 +160,32 @@ struct ScaledEpilogue
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
ScaleA_Args
=
typename
ScaleA
::
Arguments
;
using
ScaleB_Args
=
typename
ScaleB
::
Arguments
;
ScaleA_Args
a_args
{
a_scales
.
data_ptr
<
float
>
(),
a_scales
.
numel
()
!=
1
,
{}};
ScaleB_Args
b_args
{
b_scales
.
data_ptr
<
float
>
(),
b_scales
.
numel
()
!=
1
,
{}};
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
return
ArgumentType
{
a_args
,
{
b_args
}};
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
struct
ScaledEpilogueBias
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
ScaleA
;
using
ScaleB
=
typename
SUPER
::
ScaleB
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
Bias
=
typename
SUPER
::
template
RowLoad
<
ElementD
>;
using
Compute0
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
float
,
float
,
...
...
@@ -148,27 +198,160 @@ struct ScaledEpilogueBias
cutlass
::
multiply_add
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
Bias
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
ElementD
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
ElementD
>
,
false
>
;
public:
using
EVTCompute
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute1
,
ScaleA
,
EVTCompute0
,
Bias
>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
// Builds the argument pack for EVTCompute from the scale and bias tensors.
//
// The returned aggregate lists its members in the same order as the children
// of EVTCompute (ScaleA, EVTCompute0, Bias); the inner EVTCompute0 argument
// pack is built from the ScaleB arguments.
static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 torch::Tensor const& bias) {
  auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
  auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
  auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

  // Arguments of the nested (inner) visitor node.
  typename EVTCompute0::Arguments evt0_args{b_args};
  return ArgumentType{a_args, evt0_args, bias_args};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzp
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;

  // Leaf nodes of the visitor tree: accumulator plus the broadcast loads.
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  // Bias load instantiated with the second RowLoad parameter set to true,
  // because the bias tensor is optional.
  using Bias = typename SUPER::template RowLoad<ElementD, true>;

  // Pre-multiplied zero-point correction, azp * J @ B, shape (1,n).
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Innermost node: float(accum - azp_adj); both inputs are int32_t.
  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeAzp =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;

  // Middle node: multiplies the ScaleB load with the corrected accumulator.
  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeScaleB =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB,
                                         EVTComputeAzp>;

  // Root node: multiply_add over (ScaleA, scaled accumulator, Bias),
  // producing ElementD.
  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
                                         EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Assembles the nested argument packs, innermost node first. Each pack
  // lists its members in the same order as the children of its EVT node.
  static ArgumentType prepare_args(
      torch::Tensor const& a_scales, torch::Tensor const& b_scales,
      torch::Tensor const& azp_adj,
      c10::optional<torch::Tensor> const& bias) {
    // Innermost node: the accumulator takes no arguments, hence `{}`.
    auto adj_load_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
    typename EVTComputeAzp::Arguments azp_node_args{{}, adj_load_args};

    // Middle node: ScaleB load + corrected accumulator.
    auto scale_b_load_args =
        SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    typename EVTComputeScaleB::Arguments scale_b_node_args{scale_b_load_args,
                                                           azp_node_args};

    // Root node: ScaleA load, middle node, (optional) bias load.
    auto scale_a_load_args =
        SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto bias_load_args =
        SUPER::template args_from_tensor<Bias, ElementD>(bias);
    return ArgumentType{scale_a_load_args, scale_b_node_args, bias_load_args};
  }
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzpToken
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  // Bias load instantiated with the second RowLoad parameter set to true,
  // because the bias tensor is optional.
  using Bias = typename SUPER::template RowLoad<ElementD, true>;

  // Per-token azp term, shape (m,1)
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // This is the AZP adjustment term, J @ B, shape (1,n)
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute azp * azp_adj
  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;

  // Compute float(accum - azp*azp_adj), all operands are int32_t
  using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAcc =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;

  // Multiplies the ScaleB load with the corrected accumulator.
  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB,
                                         EVTComputeAcc>;

  // Root node: multiply_add over (ScaleA, scaled accumulator, Bias),
  // producing ElementD.
  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
                                         EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Builds the nested argument packs for EVTCompute. Each pack lists its
  // members in the same order as the children of the corresponding EVT node;
  // the accumulator node takes no arguments and is passed as `{}`.
  static ArgumentType prepare_args(
      torch::Tensor const& a_scales, torch::Tensor const& b_scales,
      torch::Tensor const& azp_adj, torch::Tensor const& azp,
      c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args,
                                                          evt_acc_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};
...
...
@@ -546,4 +729,23 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
}
}
// SM90 entry point for the scaled matmul with an activation zero point (azp).
// Picks the epilogue by whether `azp` is supplied: ScaledEpilogueBiasAzpToken
// when it is, ScaledEpilogueBiasAzp (azp_adj only) when it is not.
// Both scale tensors must be float32.
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // No azp tensor supplied: only the azp_adj term is passed on.
  if (!azp.has_value()) {
    cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  // azp tensor supplied: forward it together with azp_adj.
  cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
#endif
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
ad385667
...
...
@@ -21,7 +21,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#if defined
CUDA_VERSION && CUDA_VERSION >= 12000
#if defined
ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
...
...
@@ -29,6 +29,40 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
void
cutlass_scaled_mm_azp_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_azp_sm80
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_azp_sm89
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
// Forward declaration of the SM90 azp kernel entry point. It is guarded by the
// same macro as its call site in cutlass_scaled_mm_azp (and as the
// cutlass_scaled_mm_sm90 declaration), so the declaration is visible whenever
// the call is compiled.
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias);
#endif
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
)
{
// CUTLASS FP8 kernels need at least
// CUDA 12.0 on SM90 systems (Hopper)
...
...
@@ -45,18 +79,20 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
return
false
;
}
void
cutlass_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
int32_t
major_capability
;
int32_t
minor_capability
;
int32_t
get_sm_version_num
()
{
int32_t
major_capability
,
minor_capability
;
cudaDeviceGetAttribute
(
&
major_capability
,
cudaDevAttrComputeCapabilityMajor
,
0
);
cudaDeviceGetAttribute
(
&
minor_capability
,
cudaDevAttrComputeCapabilityMinor
,
0
);
int32_t
version_num
=
major_capability
*
10
+
minor_capability
;
return
version_num
;
}
void
cutlass_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
// Checks for conformality
TORCH_CHECK
(
a
.
dim
()
==
2
&&
b
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
c
.
size
(
0
)
==
a
.
size
(
0
)
&&
a
.
size
(
1
)
==
b
.
size
(
0
)
&&
...
...
@@ -77,25 +113,112 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
}
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
int32_t
version_num
=
get_sm_version_num
();
// Hopper
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if
(
version_num
>=
90
)
{
// Hopper
// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
cutlass_scaled_mm_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
#else
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
return
;
}
#endif
}
else
if
(
version_num
==
89
)
{
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if
(
version_num
==
89
)
{
// Ada Lovelace
cutlass_scaled_mm_sm89
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
if
(
version_num
>=
80
)
{
return
;
}
if
(
version_num
>=
80
)
{
// Ampere
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
// Turing
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
return
;
}
// Turing
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass_scaled_mm for a compute capability less than "
"CUDA device capability: "
,
version_num
);
}
// Validates the operands of the azp scaled matmul and forwards them to the
// kernel entry point for the SM version reported by get_sm_version_num().
// Reports "not implemented" when no matching kernel family was compiled in.
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           c10::optional<torch::Tensor> const& azp,
                           c10::optional<torch::Tensor> const& bias) {
  // --- Shape conformality: all operands are 2-D, c is (m,n) = a(m,k) b(k,n).
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  // Scales are either a single value or one value per row of a / column of b.
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // --- Layout: strides and alignment.
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  // --- 1-D side inputs: bias and azp_adj carry n elements, azp carries m.
  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
  }
  if (azp) {
    TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
  }
  TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());

  // --- Element types of the zero-point terms and the bias.
  TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
  TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
  TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
              "currently bias dtype must match output dtype ", c.dtype());

  // Device guard for a's device is installed before the SM version is read.
  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

  int32_t version_num = get_sm_version_num();

#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
  if (version_num >= 90) {
    cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp,
                               bias);
    return;
  }
#endif

#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
  if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp,
                               bias);
  } else if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp,
                               bias);
  } else {
    // Turing
    TORCH_CHECK(version_num >= 75);
    cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp,
                               bias);
  }
  return;
#endif

  // Reached only when no kernel family compiled into this build handled the
  // device above.
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_mm_azp for a compute capability less than "
      "CUDA device capability: ",
      version_num);
}
\ No newline at end of file
csrc/quantization/fp8/common.cu
View file @
ad385667
...
...
@@ -7,7 +7,25 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "../../reduction_utils.cuh"
#ifndef USE_ROCM
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#endif
#ifndef USE_ROCM
using
FP8_TYPE
=
c10
::
Float8_e4m3fn
;
C10_HOST_DEVICE
constexpr
auto
FP8_E4M3_MAX
=
std
::
numeric_limits
<
FP8_TYPE
>::
max
();
#else
#include "amd/hip_float8.h"
using
FP8_TYPE
=
c10
::
Float8_e4m3fnuz
;
// Using the default max value from pytorch (240.0) will cause accuracy
// issues when running dynamic quantization. Use 224.0f for ROCm instead.
constexpr
auto
FP8_E4M3_MAX
=
224.0
f
;
#endif
namespace
vllm
{
...
...
@@ -21,11 +39,9 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
return
old
;
}
#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
template
<
bool
is_scale_inverted
>
__device__
__forceinline__
c10
::
Float8_e4m3fn
scaled_fp8_conversion
(
float
const
val
,
float
const
scale
)
{
__device__
__forceinline__
FP8_TYPE
scaled_fp8_conversion
(
float
const
val
,
float
const
scale
)
{
float
x
=
0.0
f
;
if
constexpr
(
is_scale_inverted
)
{
x
=
val
*
scale
;
...
...
@@ -34,7 +50,13 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
}
float
r
=
fmax
(
-
FP8_E4M3_MAX
,
fmin
(
x
,
FP8_E4M3_MAX
));
#ifndef USE_ROCM
return
static_cast
<
c10
::
Float8_e4m3fn
>
(
r
);
#else
// Use hardware cvt instruction for fp8 on rocm
return
c10
::
Float8_e4m3fnuz
(
hip_fp8
(
r
).
data
,
c10
::
Float8_e4m3fnuz
::
from_bits
());
#endif
}
// Compute the absolute maximum m of the input tensor and store
...
...
@@ -74,8 +96,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
if
(
threadIdx
.
x
==
0
)
{
atomicMaxFloat
(
scale
,
cache
[
0
]
/
std
::
numeric_limits
<
c10
::
Float8_e4m3fn
>::
max
());
atomicMaxFloat
(
scale
,
cache
[
0
]
/
FP8_E4M3_MAX
);
}
}
...
...
@@ -88,10 +109,10 @@ struct __align__(8) vec4_t {
};
typedef
struct
__align__
(
4
)
{
c10
::
Float8_e4m3fn
x
;
c10
::
Float8_e4m3fn
y
;
c10
::
Float8_e4m3fn
z
;
c10
::
Float8_e4m3fn
w
;
FP8_TYPE
x
;
FP8_TYPE
y
;
FP8_TYPE
z
;
FP8_TYPE
w
;
}
float8x4_t
;
...
...
@@ -124,7 +145,7 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
}
template
<
typename
scalar_t
,
bool
is_scale_inverted
>
__device__
void
scaled_fp8_conversion_vec
(
c10
::
Float8_e4m3fn
*
__restrict__
out
,
__device__
void
scaled_fp8_conversion_vec
(
FP8_TYPE
*
__restrict__
out
,
scalar_t
const
*
__restrict__
input
,
float
const
scale
,
int64_t
const
num_elems
,
...
...
@@ -160,7 +181,7 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
}
template
<
typename
scalar_t
>
__global__
void
scaled_fp8_quant_kernel
(
c10
::
Float8_e4m3fn
*
__restrict__
out
,
__global__
void
scaled_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
const
scalar_t
*
__restrict__
input
,
const
float
*
__restrict__
scale
,
int64_t
num_elems
)
{
...
...
@@ -175,7 +196,7 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
template
<
typename
scalar_t
>
__global__
void
dynamic_per_token_scaled_fp8_quant_kernel
(
c10
::
Float8_e4m3fn
*
__restrict__
out
,
float
*
__restrict__
scale
,
FP8_TYPE
*
__restrict__
out
,
float
*
__restrict__
scale
,
scalar_t
const
*
__restrict__
input
,
float
const
*
__restrict__
scale_ub
,
const
int
hidden_size
)
{
float
const
min_scaling_factor
=
1.0
f
/
(
FP8_E4M3_MAX
*
512.
f
);
...
...
@@ -183,8 +204,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
int
const
tid
=
threadIdx
.
x
;
int
const
token_idx
=
blockIdx
.
x
;
scalar_t
const
*
__restrict__
token_input
=
&
input
[
token_idx
*
hidden_size
];
c10
::
Float8_e4m3fn
*
__restrict__
token_output
=
&
out
[
token_idx
*
hidden_size
];
// Use int64 to avoid overflowing an int32 when calculating this offset
int64_t
offset
=
static_cast
<
int64_t
>
(
token_idx
)
*
hidden_size
;
scalar_t
const
*
__restrict__
token_input
=
&
input
[
offset
];
FP8_TYPE
*
__restrict__
token_output
=
&
out
[
offset
];
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
...
...
@@ -200,7 +223,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
}
}
float
const
block_absmax_val_maybe
=
blockReduceMax
(
absmax_val
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStorage
;
float
const
block_absmax_val_maybe
=
BlockReduce
(
reduceStorage
).
Reduce
(
absmax_val
,
cub
::
Max
{},
blockDim
.
x
);
__shared__
float
token_scale
;
if
(
tid
==
0
)
{
if
(
scale_ub
)
{
...
...
@@ -241,7 +267,7 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"scaled_fp8_quant_kernel"
,
[
&
]
{
vllm
::
scaled_fp8_quant_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
c10
::
Float8_e4m3fn
>
(),
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
FP8_TYPE
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
num_elems
);
});
}
...
...
@@ -261,7 +287,7 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
vllm
::
segmented_max_reduction
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
scale
.
data_ptr
<
float
>
(),
input
.
data_ptr
<
scalar_t
>
(),
num_elems
);
vllm
::
scaled_fp8_quant_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
c10
::
Float8_e4m3fn
>
(),
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
FP8_TYPE
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
num_elems
);
});
}
...
...
@@ -284,7 +310,7 @@ void dynamic_per_token_scaled_fp8_quant(
input
.
scalar_type
(),
"dynamic_per_token_scaled_fp8_quant_kernel"
,
[
&
]
{
vllm
::
dynamic_per_token_scaled_fp8_quant_kernel
<
scalar_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
c10
::
Float8_e4m3fn
>
(),
scales
.
data_ptr
<
float
>
(),
out
.
data_ptr
<
FP8_TYPE
>
(),
scales
.
data_ptr
<
float
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale_ub
.
has_value
()
?
scale_ub
->
data_ptr
<
float
>
()
:
nullptr
,
hidden_size
);
...
...
csrc/quantization/fp8/fp8_marlin.cu
View file @
ad385667
...
...
@@ -22,6 +22,8 @@
#include "../gptq_marlin/marlin.cuh"
#include "../gptq_marlin/marlin_dtypes.cuh"
#include "core/registration.h"
using
namespace
marlin
;
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
...
...
@@ -1303,3 +1305,7 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
#endif
// Registers fp8_marlin_gemm as the implementation of the "fp8_marlin_gemm" op
// for the CUDA dispatch key of the extension library.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
    m.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
}
\ No newline at end of file
csrc/quantization/gguf/dequantize.cuh
0 → 100644
View file @
ad385667
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/convert.cu
// Dequant functions
// Dequantizes one packed byte of a Q4_0 block into two values.
//   vx  - base pointer to an array of block_q4_0
//   ib  - index of the block within that array
//   iqs - index of the packed byte within the block's qs[]
//   v   - output pair: v.x comes from the low nibble, v.y from the high nibble
// Each nibble has 8 subtracted and is then multiplied by the block delta d.
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v) {
    const block_q4_0 & blk = static_cast<const block_q4_0 *>(vx)[ib];

    const dfloat scale  = blk.d;
    const int    packed = blk.qs[iqs];

    v.x = __int2half_rn(packed & 0xF);
    v.y = __int2half_rn(packed >> 4);

    v = __hsub2(v, __floats2half2_rn(8.0f, 8.0f));
    v = __hmul2(v, {scale, scale});
}
// Dequantizes one packed byte of a Q4_1 block into two values.
//   vx  - base pointer to an array of block_q4_1
//   ib  - index of the block within that array
//   iqs - index of the packed byte within the block's qs[]
//   v   - output pair: v.x comes from the low nibble, v.y from the high nibble
// Each nibble is multiplied by the delta (low half of dm) and then offset by
// the minimum (high half of dm).
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v) {
    const block_q4_1 & blk = static_cast<const block_q4_1 *>(vx)[ib];

    const dfloat scale  = __low2half(blk.dm);
    const dfloat offset = __high2half(blk.dm);
    const int    packed = blk.qs[iqs];

    v.x = __int2half_rn(packed & 0xF);
    v.y = __int2half_rn(packed >> 4);

    v = __hmul2(v, {scale, scale});
    v = __hadd2(v, {offset, offset});
}
// Dequantizes one packed byte of a Q5_0 block into two values.
//   vx  - base pointer to an array of block_q5_0
//   ib  - index of the block within that array
//   iqs - index of the packed byte within the block's qs[]
//   v   - output pair: v.x from the low nibble, v.y from the high nibble
// The fifth bit of each value is taken from the 32-bit qh mask (bit iqs for
// v.x, bit iqs + 16 for v.y). The 5-bit result has 16 subtracted and is then
// multiplied by the block delta d.
static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v) {
    const block_q5_0 & blk = static_cast<const block_q5_0 *>(vx)[ib];

    const dfloat scale = blk.d;

    // qh is stored as four bytes; copy it into a 32-bit word.
    uint32_t high_bits;
    memcpy(&high_bits, blk.qh, sizeof(high_bits));

    // Move the relevant high bit into bit position 4 (value 0x10).
    const int hi_x = ((high_bits >> (iqs +  0)) << 4) & 0x10;
    const int hi_y = ((high_bits >> (iqs + 12))     ) & 0x10;

    v.x = __int2half_rn((blk.qs[iqs] & 0xf) | hi_x);
    v.y = __int2half_rn((blk.qs[iqs] >>  4) | hi_y);

    v = __hsub2(v, __floats2half2_rn(16.0f, 16.0f));
    v = __hmul2(v, {scale, scale});
}
// Dequantizes one packed byte of a Q5_1 block into two values.
//   vx  - base pointer to an array of block_q5_1
//   ib  - index of the block within that array
//   iqs - index of the packed byte within the block's qs[]
//   v   - output pair: v.x from the low nibble, v.y from the high nibble
// The fifth bit of each value is taken from the 32-bit qh mask (bit iqs for
// v.x, bit iqs + 16 for v.y). The 5-bit result is multiplied by the delta
// (low half of dm) and offset by the minimum (high half of dm).
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v) {
    const block_q5_1 & blk = static_cast<const block_q5_1 *>(vx)[ib];

    const dfloat scale  = __low2half(blk.dm);
    const dfloat offset = __high2half(blk.dm);

    // qh is stored as four bytes; copy it into a 32-bit word.
    uint32_t high_bits;
    memcpy(&high_bits, blk.qh, sizeof(high_bits));

    // Move the relevant high bit into bit position 4 (value 0x10).
    const int hi_x = ((high_bits >> (iqs +  0)) << 4) & 0x10;
    const int hi_y = ((high_bits >> (iqs + 12))     ) & 0x10;

    v.x = __int2half_rn((blk.qs[iqs] & 0xf) | hi_x);
    v.y = __int2half_rn((blk.qs[iqs] >>  4) | hi_y);

    v = __hmul2(v, {scale, scale});
    v = __hadd2(v, {offset, offset});
}
// Dequantizes two consecutive 8-bit quants of a Q8_0 block.
//   vx  - base pointer to an array of block_q8_0
//   ib  - index of the block within that array
//   iqs - index of the first of the two quants within the block's qs[]
//   v   - output pair: v.x = qs[iqs] * d, v.y = qs[iqs + 1] * d
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v) {
    const block_q8_0 & blk = static_cast<const block_q8_0 *>(vx)[ib];

    const dfloat scale = blk.d;

    v.x = __int2half_rn(blk.qs[iqs + 0]);
    v.y = __int2half_rn(blk.qs[iqs + 1]);

    v = __hmul2(v, {scale, scale});
}
// Generic dequantization kernel: every thread produces two output values.
//   qk                - number of dequantized values per quantized block
//   qr                - ratio used to map an output index to a quant index
//   dequantize_kernel - per-format device function that fills a dfloat2
//   vx                - quantized input blocks
//   y                 - dequantized output
//   k                 - number of output values; threads whose first index
//                       is >= k do nothing
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
    // Index of the first of the two values handled by this thread.
    const int elem = 2*(blockDim.x*blockIdx.x + threadIdx.x);

    if (elem >= k) {
        return;
    }

    const int within_block = elem%qk;

    const int block_idx   = elem/qk;           // block index
    const int quant_idx   = within_block/qr;   // quant index
    const int block_start = elem - within_block; // y block start index
    // With qr == 1 the two results are adjacent, otherwise the second one
    // lands half a block further on.
    const int second_offset = qr == 1 ? 1 : qk/2;

    // dequantize
    dfloat2 pair;
    dequantize_kernel(vx, block_idx, quant_idx, pair);

    y[block_start + quant_idx + 0]             = pair.x;
    y[block_start + quant_idx + second_offset] = pair.y;
}
// Dequantizes Q2_K super-blocks; CUDA block blockIdx.x handles super-block
// blockIdx.x and writes QK_K outputs starting at yy + blockIdx.x * QK_K.
// The index math covers the whole super-block with 64 threads (the launcher
// in this file uses 64); every thread decodes the four 2-bit fields of one
// qs byte and stores them 32 outputs apart.
//   value = dall * (scale_low_nibble * quant) - dmin * scale_high_nibble
template<typename dst_t>
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int sb  = blockIdx.x;
    const int tid = threadIdx.x;

    const block_q2_K & blk = static_cast<const block_q2_K *>(vx)[sb];

    const int n  = tid/32;        // which 128-value half of the super-block
    const int l  = tid - 32*n;    // position inside that half
    const int is = 8*n + l/16;    // first scale index used by this thread

    const uint8_t q = blk.qs[32*n + l];
    dst_t * y = yy + sb*QK_K + 128*n;

    const half dall = __low2half(blk.dm);
    const half dmin = __high2half(blk.dm);

    // Field g of the byte uses scale entry is + 2*g and goes to y[l + 32*g].
    for (int g = 0; g < 4; ++g) {
        const uint8_t sc = blk.scales[is + 2*g];
        y[l + 32*g] = __hsub(__hmul(dall, __int2half_rn((sc & 0xF) * ((q >> (2*g)) & 3))),
                             __hmul(dmin, __int2half_rn(sc >> 4)));
    }
}
// Dequantizes Q3_K super-blocks; CUDA block blockIdx.x handles super-block
// blockIdx.x. Every thread writes four consecutive outputs. Each value is
// built from a 2-bit field of qs and one bit of hmask (4 is subtracted when
// the mask bit is clear), then multiplied by d * (scale - 32), where the
// 6-bit scale is reassembled from the packed scales[] array.
template<typename dst_t>
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int sb = blockIdx.x;
    const block_q3_K & blk = static_cast<const block_q3_K *>(vx)[sb];

    const int r   = threadIdx.x/4;
    const int tid = r/2;
    const int is0 = r%2;
    const int l0  = 16*is0 + 4*(threadIdx.x%4);   // first of this thread's 4 values
    const int n   = tid / 4;
    const int j   = tid - 4*n;

    const uint8_t mask  = 1 << (4*n + j);         // hmask bit for this (n, j) group
    const int     is    = 8*n + 2*j + is0;        // scale index
    const int     shift = 2*j;                    // bit offset of the 2-bit field

    // Low 4 bits and high 2 bits of the scale live in different bytes
    // depending on the scale index.
    const int8_t us = is <  4 ? (blk.scales[is-0] & 0xF) | (((blk.scales[is+8] >> 0) & 3) << 4) :
                      is <  8 ? (blk.scales[is-0] & 0xF) | (((blk.scales[is+4] >> 2) & 3) << 4) :
                      is < 12 ? (blk.scales[is-8] >>  4) | (((blk.scales[is+0] >> 4) & 3) << 4) :
                                (blk.scales[is-8] >>  4) | (((blk.scales[is-4] >> 6) & 3) << 4);

    const half d_all = blk.d;
    const half dl    = __hmul(d_all, __int2half_rn(us - 32));

    dst_t * y = yy + sb*QK_K + 128*n + 32*j;
    const uint8_t * q  = blk.qs + 32*n;
    const uint8_t * hm = blk.hmask;

    for (int l = l0; l < l0 + 4; ++l) {
        y[l] = __hmul(dl, __int2half_rn(static_cast<int8_t>((q[l] >> shift) & 3) - ((hm[l] & mask) ? 0 : 4)));
    }
}
// Extracts the j-th 6-bit scale (d) and 6-bit min (m) from the packed scale
// array q. For j < 4 both are the low 6 bits of q[j] and q[j + 4]. For larger
// j the low 4 bits come from the two nibbles of q[j + 4] and the top 2 bits
// from the upper bits of q[j - 4] (scale) and q[j] (min).
static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
    if (j < 4) {
        d = q[j]   & 63;
        m = q[j+4] & 63;
        return;
    }

    const uint8_t packed = q[j+4];
    d = (packed & 0xF) | ((q[j-4] >> 6) << 4);
    m = (packed >>  4) | ((q[j-0] >> 6) << 4);
}
// Dequantizes Q4_K super-blocks; CUDA block blockIdx.x handles super-block
// blockIdx.x. Written for 32 threads per block: each thread reads 4 qs bytes
// and writes 8 outputs, the low nibbles to y[0..3] and the high nibbles to
// y[32..35].
//   value = dall * scale * nibble - dmin * min
// with (scale, min) obtained from get_scale_min_k4 for sub-blocks is and is+1.
template<typename dst_t>
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int sb = blockIdx.x;
    const block_q4_K & blk = static_cast<const block_q4_K *>(vx)[sb];

    // assume 32 threads
    const int tid = threadIdx.x;
    const int il  = tid/8;
    const int ir  = tid%8;
    const int is  = 2*il;
    const int n   = 4;          // bytes handled per thread

    dst_t * y = yy + sb*QK_K + 64*il + n*ir;

    const half dall = __low2half(blk.dm);
    const half dmin = __high2half(blk.dm);

    const uint8_t * q = blk.qs + 32*il + n*ir;

    uint8_t sc, m;

    get_scale_min_k4(is + 0, blk.scales, sc, m);
    const half d1 = __hmul(dall, __int2half_rn(sc));
    const half m1 = __hmul(dmin, __int2half_rn(m));

    get_scale_min_k4(is + 1, blk.scales, sc, m);
    const half d2 = __hmul(dall, __int2half_rn(sc));
    const half m2 = __hmul(dmin, __int2half_rn(m));

    for (int l = 0; l < n; ++l) {
        y[l +  0] = __hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1);
        y[l + 32] = __hsub(__hmul(d2, __int2half_rn(q[l] >>  4)), m2);
    }
}
// Dequantizes Q5_K super-blocks; CUDA block blockIdx.x handles super-block
// blockIdx.x. Written for 64 threads per block: each thread reads 2 bytes of
// qs plus 2 bytes of qh and writes 4 outputs (y[0], y[1], y[32], y[33]).
//   value = dall * scale * (nibble + 16 * high_bit) - dmin * min
// with (scale, min) obtained from get_scale_min_k4 for sub-blocks is and is+1.
template<typename dst_t>
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int sb = blockIdx.x;
    const block_q5_K & blk = static_cast<const block_q5_K *>(vx)[sb];

    // assume 64 threads
    const int tid = threadIdx.x;
    const int il  = tid/16;   // il is in 0...3
    const int ir  = tid%16;   // ir is in 0...15
    const int is  = 2*il;     // is is in 0...6

    dst_t * y = yy + sb*QK_K + 64*il + 2*ir;

    const half dall = __low2half(blk.dm);
    const half dmin = __high2half(blk.dm);

    const uint8_t * ql = blk.qs + 32*il + 2*ir;
    const uint8_t * qh = blk.qh + 2*ir;

    uint8_t sc, m;

    get_scale_min_k4(is + 0, blk.scales, sc, m);
    const half d1 = __hmul(dall, __int2half_rn(sc));
    const half m1 = __hmul(dmin, __int2half_rn(m));

    get_scale_min_k4(is + 1, blk.scales, sc, m);
    const half d2 = __hmul(dall, __int2half_rn(sc));
    const half m2 = __hmul(dmin, __int2half_rn(m));

    // Bit 2*il of qh supplies the fifth bit of the low-nibble values...
    uint8_t hm = 1 << (2*il);
    y[ 0] = __hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + ((qh[0] & hm) ? 16 : 0))), m1);
    y[ 1] = __hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + ((qh[1] & hm) ? 16 : 0))), m1);

    // ...and the next bit supplies it for the high-nibble values.
    hm <<= 1;
    y[32] = __hsub(__hmul(d2, __int2half_rn((ql[0] >>  4) + ((qh[0] & hm) ? 16 : 0))), m2);
    y[33] = __hsub(__hmul(d2, __int2half_rn((ql[1] >>  4) + ((qh[1] & hm) ? 16 : 0))), m2);
}
// Dequantizes Q6_K super-blocks; CUDA block blockIdx.x handles super-block
// blockIdx.x. Written for 64 threads per block: each thread writes 4 outputs,
// 32 apart. Every 6-bit quant combines a nibble of ql with a 2-bit field of
// qh, has 32 subtracted, and is multiplied by its int8 scale and by d.
template<typename dst_t>
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int sb = blockIdx.x;
    const block_q6_K & blk = static_cast<const block_q6_K *>(vx)[sb];

    // assume 64 threads
    const int tid = threadIdx.x;
    const int ip  = tid/32;         // ip is 0 or 1
    const int il  = tid - 32*ip;    // 0...31
    const int is  = 8*ip + il/16;

    dst_t * y = yy + sb*QK_K + 128*ip + il;

    const half d = blk.d;

    const uint8_t * ql = blk.ql + 64*ip + il;
    const uint8_t   qh = blk.qh[32*ip + il];
    const int8_t  * sc = blk.scales + is;

    y[ 0] = __hmul(d, __int2half_rn(sc[0] * (static_cast<int8_t>((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
    y[32] = __hmul(d, __int2half_rn(sc[2] * (static_cast<int8_t>((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
    y[64] = __hmul(d, __int2half_rn(sc[4] * (static_cast<int8_t>((ql[ 0]  >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
    y[96] = __hmul(d, __int2half_rn(sc[6] * (static_cast<int8_t>((ql[32]  >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
}
// Dequantizes IQ2_XXS super-blocks; CUDA block blockIdx.x handles super-block
// blockIdx.x and each thread writes 8 outputs. The 8 magnitudes come from one
// iq2xxs_grid entry selected by a byte of qs; signs and the 4-bit sub-block
// scale are unpacked from the 32-bit word formed by q2[2] and q2[3].
template<typename dst_t>
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int sb = blockIdx.x;
    const block_iq2_xxs & blk = static_cast<const block_iq2_xxs *>(vx)[sb];

    const int tid = threadIdx.x;
    const int il  = tid/8;   // 0...3
    const int ib  = tid%8;   // 0...7

    dst_t * y = yy + sb*QK_K + 32*ib + 8*il;

    const uint16_t * q2   = blk.qs + 4*ib;
    const uint8_t  * aux8 = reinterpret_cast<const uint8_t *>(q2);
    const uint8_t  * grid = reinterpret_cast<const uint8_t *>(iq2xxs_grid + aux8[il]);

    const uint32_t aux32 = q2[2] | (q2[3] << 16);
    const float    d     = __half2float(blk.d) * (0.5f + (aux32 >> 28)) * 0.25f;
    const uint8_t  signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];

    for (int j = 0; j < 8; ++j) {
        y[j] = __float2half(d * grid[j] * ((signs & kmask_iq2xs[j]) ? -1.f : 1.f));
    }
}
// Dequantizes IQ2_XS super-blocks; CUDA block blockIdx.x handles super-block
// blockIdx.x and each thread writes 8 outputs. The low 9 bits of q2[il]
// select an iq2xs_grid entry, the remaining high bits select the sign
// pattern, and a 4-bit field of scales[ib] gives the sub-block scale.
template<typename dst_t>
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int sb = blockIdx.x;
    const block_iq2_xs & blk = static_cast<const block_iq2_xs *>(vx)[sb];

    const int tid = threadIdx.x;
    const int il  = tid/8;   // 0...3
    const int ib  = tid%8;   // 0...7

    dst_t * y = yy + sb*QK_K + 32*ib + 8*il;

    const uint16_t * q2   = blk.qs + 4*ib;
    const uint8_t  * grid = reinterpret_cast<const uint8_t *>(iq2xs_grid + (q2[il] & 511));

    const float   d     = __half2float(blk.d) * (0.5f + ((blk.scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
    const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];

    for (int j = 0; j < 8; ++j) {
        y[j] = __float2half(d * grid[j] * ((signs & kmask_iq2xs[j]) ? -1.f : 1.f));
    }
}
// Dequantizes IQ2_S super-blocks; CUDA block blockIdx.x handles super-block
// blockIdx.x and each thread writes 8 outputs. The grid index is a qs byte
// extended with two bits taken from qh[ib]; the sign byte is read from the
// second part of qs (offset QK_K/8); a 4-bit field of scales[ib] gives the
// sub-block scale.
template<typename dst_t>
static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int sb = blockIdx.x;
    const block_iq2_s & blk = static_cast<const block_iq2_s *>(vx)[sb];

    const int tid = threadIdx.x;
    const int il  = tid/8;   // 0...3
    const int ib  = tid%8;   // 0...7

    dst_t * y = yy + sb*QK_K + 32*ib + 8*il;

    const uint8_t * grid = reinterpret_cast<const uint8_t *>(
        iq2s_grid + (blk.qs[4*ib+il] | ((blk.qh[ib] << (8-2*il)) & 0x300)));

    const float   d     = __half2float(blk.d) * (0.5f + ((blk.scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
    const uint8_t signs = blk.qs[QK_K/8+4*ib+il];

    for (int j = 0; j < 8; ++j) {
        y[j] = __float2half(d * grid[j] * ((signs & kmask_iq2xs[j]) ? -1.f : 1.f));
    }
}
// Dequantizes IQ3_XXS super-blocks; CUDA block blockIdx.x handles super-block
// blockIdx.x and each thread writes 8 outputs. Two consecutive qs bytes each
// select an iq3xxs_grid entry holding 4 magnitudes. Signs and the 4-bit
// sub-block scale are unpacked from a 32-bit word stored after the first
// QK_K/4 bytes of qs.
template<typename dst_t>
static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int sb = blockIdx.x;
    const block_iq3_xxs & blk = static_cast<const block_iq3_xxs *>(vx)[sb];

    const int tid = threadIdx.x;
    const int il  = tid/8;   // 0...3
    const int ib  = tid%8;   // 0...7

    dst_t * y = yy + sb*QK_K + 32*ib + 8*il;

    const uint8_t  * q3  = blk.qs + 8*ib;
    const uint16_t * gas = reinterpret_cast<const uint16_t *>(blk.qs + QK_K/4) + 2*ib;

    const uint8_t * grid1 = reinterpret_cast<const uint8_t *>(iq3xxs_grid + q3[2*il+0]);
    const uint8_t * grid2 = reinterpret_cast<const uint8_t *>(iq3xxs_grid + q3[2*il+1]);

    const uint32_t aux32 = gas[0] | (gas[1] << 16);
    const float    d     = __half2float(blk.d) * (0.5f + (aux32 >> 28)) * 0.5f;
    const uint8_t  signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];

    for (int j = 0; j < 4; ++j) {
        y[j+0] = __float2half(d * grid1[j] * ((signs & kmask_iq2xs[j+0]) ? -1.f : 1.f));
        y[j+4] = __float2half(d * grid2[j] * ((signs & kmask_iq2xs[j+4]) ? -1.f : 1.f));
    }
}
// Dequantizes IQ3_S super-blocks; CUDA block blockIdx.x handles super-block
// blockIdx.x and each thread writes 8 outputs. Two consecutive qs bytes, each
// extended with one bit from qh[ib], select iq3xs_grid entries holding 4
// magnitudes. Signs come from signs[] and a 4-bit field of scales[ib/2]
// gives the sub-block scale.
template<typename dst_t>
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int sb = blockIdx.x;
    const block_iq3_s & blk = static_cast<const block_iq3_s *>(vx)[sb];

    const int tid = threadIdx.x;
    const int il  = tid/8;   // 0...3
    const int ib  = tid%8;   // 0...7

    dst_t * y = yy + sb*QK_K + 32*ib + 8*il;

    const uint8_t * qs = blk.qs + 8*ib;

    const uint8_t * grid1 = reinterpret_cast<const uint8_t *>(
        iq3xs_grid + (qs[2*il+0] | ((blk.qh[ib] << (8-2*il)) & 256)));
    const uint8_t * grid2 = reinterpret_cast<const uint8_t *>(
        iq3xs_grid + (qs[2*il+1] | ((blk.qh[ib] << (7-2*il)) & 256)));

    const float   d     = __half2float(blk.d) * (0.5f + ((blk.scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f;
    const uint8_t signs = blk.signs[4*ib + il];

    for (int j = 0; j < 4; ++j) {
        y[j+0] = __float2half(d * grid1[j] * ((signs & kmask_iq2xs[j+0]) ? -1.f : 1.f));
        y[j+4] = __float2half(d * grid2[j] * ((signs & kmask_iq2xs[j+4]) ? -1.f : 1.f));
    }
}
// Dequantizes IQ1_S super-blocks; CUDA block blockIdx.x handles super-block
// blockIdx.x and each thread writes 8 outputs. An 11-bit index (a qs byte
// plus 3 bits of qh[ib]) selects a 32-bit iq1s_grid_gpu entry whose eight
// nibbles are the grid values. Bit 15 of qh[ib] picks the sign of the
// IQ1S_DELTA shift and bits 12..14 give the sub-block scale.
template<typename dst_t>
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int64_t sb = blockIdx.x;
    const block_iq1_s & blk = static_cast<const block_iq1_s *>(vx)[sb];

    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8;   // 0...3
    const int64_t ib  = tid%8;   // 0...7

    dst_t * y = yy + sb*QK_K + 32*ib + 8*il;

    const float delta = (blk.qh[ib] & 0x8000) ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
    const float d     = __half2float(blk.d) * (2*((blk.qh[ib] >> 12) & 7) + 1);

    // Split the packed entry into low nibbles (grid32[0]) and high nibbles
    // (grid32[1]), then read the eight results back as bytes through q.
    uint32_t grid32[2];
    const int8_t * q = reinterpret_cast<const int8_t *>(grid32);

    grid32[0] = iq1s_grid_gpu[blk.qs[4*ib+il] | (((blk.qh[ib] >> 3*il) & 7) << 8)];
    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
    grid32[0] &= 0x0f0f0f0f;

    for (int j = 0; j < 8; ++j) {
        y[j] = __float2half(d * (q[j] + delta));
    }
}
// Dequantizes IQ1_M super-blocks; CUDA block blockIdx.x handles super-block
// blockIdx.x and each thread writes 8 outputs. The fp16 super-block scale is
// reassembled from the top nibble of each of the four 16-bit scale words;
// 3-bit sub-block scales come from the remaining bits of those words. An
// 11-bit index (a qs byte plus 3 bits of qh) selects a 32-bit iq1s_grid_gpu
// entry, and bit 3 of the same qh nibble picks the sign of the IQ1M_DELTA
// shift.
template<typename dst_t>
static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int64_t sb = blockIdx.x;
    const block_iq1_m & blk = static_cast<const block_iq1_m *>(vx)[sb];

    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8;   // 0...3
    const int64_t ib  = tid%8;   // 0...7

    dst_t * y = yy + sb*QK_K + 32*ib + 8*il;

    const uint16_t * sc = reinterpret_cast<const uint16_t *>(blk.scales);

    iq1m_scale_t scale;
    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

    const int64_t ib16 = 2*ib + il/2;   // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
    const float   d    = __half2float(scale.f16) * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);

    const float delta = (blk.qh[2*ib+il/2] & (0x08 << 4*(il%2))) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;

    // Split the packed entry into low nibbles (grid32[0]) and high nibbles
    // (grid32[1]), then read the eight results back as bytes through q.
    uint32_t grid32[2];
    const int8_t * q = reinterpret_cast<const int8_t *>(grid32);

    grid32[0] = iq1s_grid_gpu[blk.qs[4*ib+il] | (((blk.qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
    grid32[0] &= 0x0f0f0f0f;

    for (int j = 0; j < 8; ++j) {
        y[j] = __float2half(d * (q[j] + delta));
    }
}
// Dequantizes IQ4_NL data in groups of QK_K values: CUDA block blockIdx.x
// handles QK_K/QK4_NL consecutive block_iq4_nl entries. Each thread reads 4
// qs bytes of entry ib and writes 8 outputs; low nibbles go to y[0..3] and
// high nibbles to y[16..19]. Nibbles index the kvalues_iq4nl table and the
// result is multiplied by the entry's scale d.
template<typename dst_t>
static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int group = blockIdx.x;
    const block_iq4_nl * x = static_cast<const block_iq4_nl *>(vx) + group*(QK_K/QK4_NL);

    const int tid = threadIdx.x;
    const int il  = tid/8;   // 0...3
    const int ib  = tid%8;   // 0...7

    dst_t * y = yy + group*QK_K + 32*ib + 4*il;

    const block_iq4_nl & blk = x[ib];
    const uint8_t * q4 = blk.qs + 4*il;
    const float     d  = __half2float(blk.d);

    for (int j = 0; j < 4; ++j) {
        y[j+ 0] = __float2half(d * kvalues_iq4nl[q4[j] & 0xf]);
        y[j+16] = __float2half(d * kvalues_iq4nl[q4[j] >>  4]);
    }
}
// Dequantizes IQ4_XS super-blocks; CUDA block blockIdx.x handles super-block
// blockIdx.x. Each thread reads 4 qs bytes and writes 8 outputs; low nibbles
// go to y[0..3] and high nibbles to y[16..19]. Nibbles index the
// kvalues_iq4nl table. The 6-bit sub-block scale is assembled from
// scales_l (low 4 bits) and scales_h (high 2 bits) and has 32 subtracted.
template<typename dst_t>
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int sb = blockIdx.x;
    const block_iq4_xs & blk = static_cast<const block_iq4_xs *>(vx)[sb];

    const int tid = threadIdx.x;
    const int il  = tid/8;   // 0...3
    const int ib  = tid%8;   // 0...7

    dst_t * y = yy + sb*QK_K + 32*ib + 4*il;

    const uint8_t * q4 = blk.qs + 16*ib + 4*il;
    const float     d  = __half2float(blk.d) *
        ((((blk.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((blk.scales_h >> 2*ib) & 3) << 4)) - 32);

    for (int j = 0; j < 4; ++j) {
        y[j+ 0] = __float2half(d * kvalues_iq4nl[q4[j] & 0xf]);
        y[j+16] = __float2half(d * kvalues_iq4nl[q4[j] >>  4]);
    }
}
// Host-side launcher for the generic dequantize_block kernel.
// Every thread produces two outputs, so one CUDA block of
// CUDA_DEQUANTIZE_BLOCK_SIZE threads covers 2 * CUDA_DEQUANTIZE_BLOCK_SIZE
// values; the grid size is k divided by that, rounded up.
//   vx     - quantized input
//   y      - dequantized output
//   k      - number of output values
//   stream - CUDA stream the kernel is enqueued on
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
    const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);

    dequantize_block<qk, qr, dequantize_kernel>
        <<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
// Launches dequantize_block_q2_K on `stream`: one CUDA block of 64 threads
// per QK_K output values (k / QK_K blocks, integer division).
template<typename dst_t>
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;
    dequantize_block_q2_K<<<num_super_blocks, 64, 0, stream>>>(vx, y);
}
// Launches dequantize_block_q3_K on `stream`: one CUDA block of 64 threads
// per QK_K output values (k / QK_K blocks, integer division).
template<typename dst_t>
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;
    dequantize_block_q3_K<<<num_super_blocks, 64, 0, stream>>>(vx, y);
}
// Launches dequantize_block_q4_K on `stream`: one CUDA block of 32 threads
// per QK_K output values (k / QK_K blocks, integer division).
template<typename dst_t>
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;
    dequantize_block_q4_K<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
// Launches dequantize_block_q5_K on `stream`: one CUDA block of 64 threads
// per QK_K output values (k / QK_K blocks, integer division).
template<typename dst_t>
static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;
    dequantize_block_q5_K<<<num_super_blocks, 64, 0, stream>>>(vx, y);
}
// Launches dequantize_block_q6_K on `stream`: one CUDA block of 64 threads
// per QK_K output values (k / QK_K blocks, integer division).
template<typename dst_t>
static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;
    dequantize_block_q6_K<<<num_super_blocks, 64, 0, stream>>>(vx, y);
}
// Launches dequantize_block_iq2_xxs on `stream`: one CUDA block of 32 threads
// per QK_K output values (k / QK_K blocks, integer division).
template<typename dst_t>
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;
    dequantize_block_iq2_xxs<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
// Launches dequantize_block_iq2_xs on `stream`: one CUDA block of 32 threads
// per QK_K output values (k / QK_K blocks, integer division).
template<typename dst_t>
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;
    dequantize_block_iq2_xs<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
// Launches dequantize_block_iq2_s on `stream`: one CUDA block of 32 threads
// per QK_K output values (k / QK_K blocks, integer division).
template<typename dst_t>
static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;
    dequantize_block_iq2_s<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
// Launches dequantize_block_iq3_xxs on `stream`: one CUDA block of 32 threads
// per QK_K output values (k / QK_K blocks, integer division).
template<typename dst_t>
static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;
    dequantize_block_iq3_xxs<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
// Launches dequantize_block_iq3_s on `stream`: one CUDA block of 32 threads
// per QK_K output values (k / QK_K blocks, integer division).
template<typename dst_t>
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;
    dequantize_block_iq3_s<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
// Launches dequantize_block_iq1_s on `stream`: one CUDA block of 32 threads
// per QK_K output values (k / QK_K blocks, integer division).
template<typename dst_t>
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;
    dequantize_block_iq1_s<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
// Launches dequantize_block_iq1_m on `stream`: one CUDA block of 32 threads
// per QK_K output values (k / QK_K blocks, integer division).
template<typename dst_t>
static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;
    dequantize_block_iq1_m<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
// Launches dequantize_block_iq4_nl on `stream`: one CUDA block of 32 threads
// per QK_K output values. Unlike the K-quant launchers, the block count is
// rounded up, so a trailing partial group of QK_K values still gets a block.
template<typename dst_t>
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_groups = (k + QK_K - 1) / QK_K;
    dequantize_block_iq4_nl<<<num_groups, 32, 0, stream>>>(vx, y);
}
// Launches dequantize_block_iq4_xs on `stream`: one CUDA block of 32 threads
// per QK_K output values. The block count is rounded up, so a trailing
// partial group of QK_K values still gets a block.
template<typename dst_t>
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_super_blocks = (k + QK_K - 1) / QK_K;
    dequantize_block_iq4_xs<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
// Maps a numeric quantization type id to the host launcher that dequantizes
// that format. Returns nullptr for ids that have no entry here.
// NOTE(review): the ids presumably mirror the GGML type enum; confirm against
// the caller.
static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
    switch (type) {
        // Simple block formats go through the generic two-values-per-thread kernel.
        case 2:
            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
        case 3:
            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
        case 6:
            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
        case 7:
            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case 8:
            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;

        // K-quants: one dedicated kernel per format.
        case 10:
            return dequantize_row_q2_K_cuda;
        case 11:
            return dequantize_row_q3_K_cuda;
        case 12:
            return dequantize_row_q4_K_cuda;
        case 13:
            return dequantize_row_q5_K_cuda;
        case 14:
            return dequantize_row_q6_K_cuda;

        // IQ formats.
        case 16:
            return dequantize_row_iq2_xxs_cuda;
        case 17:
            return dequantize_row_iq2_xs_cuda;
        case 18:
            return dequantize_row_iq3_xxs_cuda;
        case 19:
            return dequantize_row_iq1_s_cuda;
        case 20:
            return dequantize_row_iq4_nl_cuda;
        case 21:
            return dequantize_row_iq3_s_cuda;
        case 22:
            return dequantize_row_iq2_s_cuda;
        case 23:
            return dequantize_row_iq4_xs_cuda;
        case 29:
            return dequantize_row_iq1_m_cuda;

        default:
            return nullptr;
    }
}
\ No newline at end of file
csrc/quantization/gguf/ggml-common.h
0 → 100644
View file @
ad385667
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
// Number of values per super-block; the K-quant and IQ block structs below
// size their arrays as fractions of it.
#define QK_K 256
#define K_QUANTS_PER_ITERATION 2
#define WARP_SIZE 32
// Byte length of the packed scales[] array in block_q3_K and block_q5_K.
#define K_SCALE_SIZE 12
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
#define CUDA_QUANTIZE_BLOCK_SIZE 256
#define GGML_CUDA_DMMV_X 32
#define GGML_CUDA_MMV_Y 1
// Data Structures
// QK = number of values after dequantization
// QR = QK / number of values before dequantization
// QI = number of 32 bit integers before dequantization
// Q4_0: QK4_0 values per block, two 4-bit quants packed per byte.
#define QK4_0 32
#define QR4_0 2
#define QI4_0 (QK4_0 / (4 * QR4_0))
typedef struct {
    half    d;               // delta (per-block scale)
    uint8_t qs[QK4_0 / 2];   // nibbles / quants
} block_q4_0;
// Q4_1: QK4_1 values per block, two 4-bit quants packed per byte, with a
// per-block minimum in addition to the delta.
#define QK4_1 32
#define QR4_1 2
#define QI4_1 (QK4_1 / (4 * QR4_1))
typedef struct {
    half2   dm;              // dm.x = delta, dm.y = min
    uint8_t qs[QK4_1 / 2];   // nibbles / quants
} block_q4_1;
// Q5_0: QK5_0 values per block; low 4 bits packed in qs, fifth bits in qh.
#define QK5_0 32
#define QR5_0 2
#define QI5_0 (QK5_0 / (4 * QR5_0))
typedef struct {
    half    d;               // delta (per-block scale)
    uint8_t qh[4];           // 5-th bit of quants
    uint8_t qs[QK5_0 / 2];   // nibbles / quants
} block_q5_0;
// Q5_1: QK5_1 values per block; low 4 bits packed in qs, fifth bits in qh,
// with a per-block minimum in addition to the delta.
#define QK5_1 32
#define QR5_1 2
#define QI5_1 (QK5_1 / (4 * QR5_1))
typedef struct {
    half2   dm;              // dm.x = delta, dm.y = min
    uint8_t qh[4];           // 5-th bit of quants
    uint8_t qs[QK5_1 / 2];   // nibbles / quants
} block_q5_1;
// Q8_0: QK8_0 signed 8-bit quants per block with one scale.
#define QK8_0 32
#define QR8_0 1
#define QI8_0 (QK8_0 / (4 * QR8_0))
typedef struct {
    half   d;           // delta (per-block scale)
    int8_t qs[QK8_0];   // quants
} block_q8_0;
// Q8_1: QK8_1 signed 8-bit quants per block, with a delta and a sum.
#define QK8_1 32
#define QR8_1 1
#define QI8_1 (QK8_1 / (4 * QR8_1))
typedef struct {
    half2  ds;          // ds.x = delta, ds.y = sum
    // Sized with this block's own constant (the original used QK8_0, which
    // has the same value today) so the layout cannot silently follow a
    // change to the Q8_0 block size.
    int8_t qs[QK8_1];   // quants
} block_q8_1;
// Q2_K super-block of QK_K values: 2-bit quants with 4-bit scales and mins.
#define QR2_K 4
#define QI2_K (QK_K / (4*QR2_K))
typedef struct {
    uint8_t scales[QK_K/16];   // scales and mins, quantized with 4 bits
    uint8_t qs[QK_K/4];        // quants
    half2   dm;                // super-block scale for quantized scales/mins
} block_q2_K;
// Q3_K super-block of QK_K values: 2 low bits in qs plus a high bit in hmask,
// with 6-bit packed scales.
#define QR3_K 4
#define QI3_K (QK_K / (4*QR3_K))
typedef struct {
    uint8_t hmask[QK_K/8];          // quants - high bit
    uint8_t qs[QK_K/4];             // quants - low 2 bits
    uint8_t scales[K_SCALE_SIZE];   // scales, quantized with 6 bits
    half    d;                      // super-block scale
} block_q3_K;
// Q4_K super-block of QK_K values: 4-bit quants with 6-bit packed scales.
#define QR4_K 2
#define QI4_K (QK_K / (4*QR4_K))
typedef struct {
    half2   dm;                  // super-block scale for quantized scales/mins
    uint8_t scales[3*QK_K/64];   // scales, quantized with 6 bits
    uint8_t qs[QK_K/2];          // 4-bit quants
} block_q4_K;
// Q5_K super-block of QK_K values: 4 low bits in qs plus a high bit in qh,
// with 6-bit packed scales and mins.
#define QR5_K 2
#define QI5_K (QK_K / (4*QR5_K))
typedef struct {
    half2   dm;                     // super-block scale for quantized scales/mins
    uint8_t scales[K_SCALE_SIZE];   // scales and mins, quantized with 6 bits
    uint8_t qh[QK_K/8];             // quants, high bit
    uint8_t qs[QK_K/2];             // quants, low 4 bits
} block_q5_K;
// Q6_K super-block of QK_K values: 4 low bits in ql plus 2 high bits in qh,
// with signed 8-bit scales.
#define QR6_K 2
#define QI6_K (QK_K / (4*QR6_K))
typedef struct {
    uint8_t ql[QK_K/2];        // quants, lower 4 bits
    uint8_t qh[QK_K/4];        // quants, upper 2 bits
    int8_t  scales[QK_K/16];   // scales
    half    d;                 // delta
} block_q6_K;
// IQ2_XXS super-block of QK_K values: a scale plus QK_K/8 16-bit words of
// packed quant data.
#define QR2_XXS 8
#define QI2_XXS (QK_K / (4*QR2_XXS))
typedef struct {
    half     d;            // super-block scale
    uint16_t qs[QK_K/8];   // packed quant data
} block_iq2_xxs;
// IQ2_XS super-block of QK_K values: a scale, QK_K/8 16-bit words of packed
// quant data and one scale byte per 32 values.
#define QR2_XS 8
#define QI2_XS (QK_K / (4*QR2_XS))
typedef struct {
    half     d;                 // super-block scale
    uint16_t qs[QK_K/8];        // packed quant data
    uint8_t  scales[QK_K/32];   // packed sub-block scales
} block_iq2_xs;
// IQ2_S super-block of QK_K values: a scale, QK_K/4 bytes of packed quant
// data, extra high-bit bytes and one scale byte per 32 values.
#define QR2_S 8
#define QI2_S (QK_K / (4*QR2_S))
typedef struct {
    half    d;                 // super-block scale
    uint8_t qs[QK_K/4];        // packed quant data
    uint8_t qh[QK_K/32];       // packed high bits
    uint8_t scales[QK_K/32];   // packed sub-block scales
} block_iq2_s;
// IQ3_XXS super-block of QK_K values: a scale plus 3*(QK_K/8) bytes of
// packed quant data.
#define QR3_XXS 8
#define QI3_XXS (QK_K / (4*QR3_XXS))
typedef struct {
    half    d;                 // super-block scale
    uint8_t qs[3*(QK_K/8)];    // packed quant data
} block_iq3_xxs;
// IQ3_S super-block of QK_K values: a scale, packed quant data, high bits,
// sign bytes and IQ3S_N_SCALE bytes of packed sub-block scales.
#define QR3_XS 8
#define QI3_XS (QK_K / (4*QR3_XS))
// Parenthesized so the macro expands as a single value inside larger
// expressions (e.g. `a / IQ3S_N_SCALE`).
#define IQ3S_N_SCALE (QK_K/64)
typedef struct {
    half    d;                      // super-block scale
    uint8_t qs[QK_K/4];             // packed quant data
    uint8_t qh[QK_K/32];            // packed high bits
    uint8_t signs[QK_K/8];          // sign bits
    uint8_t scales[IQ3S_N_SCALE];   // packed sub-block scales
} block_iq3_s;
// IQ1_S super-block of QK_K values.
// 1.5625 bpw
#define QR1_S 8
#define QI1_S (QK_K / (4*QR1_S))
typedef struct {
    half     d;             // super-block scale
    uint8_t  qs[QK_K/8];    // packed quant data
    uint16_t qh[QK_K/32];   // packed high bits
} block_iq1_s;
// IQ1_M super-block of QK_K values.
// 1.75 bpw
#define QR1_M 8
#define QI1_M (QK_K / (4*QR1_M))
typedef struct {
    uint8_t qs[QK_K/8];        // grid index, low 8 bits
    uint8_t qh[QK_K/16];       // grid index, high 3 bits + grid shift bit (for two groups of 8)
    uint8_t scales[QK_K/32];   // 3-bit block scales (4-bit if QK_K == 64)
} block_iq1_m;
// Used by IQ1_M quants: lets a 16-bit pattern be written as an integer (u16)
// and read back as a half (f16).
typedef union {
    half     f16;
    uint16_t u16;
} iq1m_scale_t;
// IQ4_NL: QK4_NL values per block, two 4-bit codes packed per byte.
#define QK4_NL 32
#define QR4_NL 2
#define QI4_NL (QK4_NL / (4*QR4_NL))
typedef struct {
    half    d;              // per-block scale
    uint8_t qs[QK4_NL/2];   // packed 4-bit codes
} block_iq4_nl;
// IQ4_XS super-block of QK_K values: 4-bit codes with sub-block scales split
// into a low part (scales_l) and a high part (scales_h).
#define QR4_XS 8
#define QI4_XS (QK_K / (4*QR4_XS))
typedef struct {
    half     d;                   // super-block scale
    uint16_t scales_h;            // high bits of the sub-block scales
    uint8_t  scales_l[QK_K/64];   // low bits of the sub-block scales
    uint8_t  qs[QK_K/2];          // packed 4-bit codes
} block_iq4_xs;
// Constant table of 256 64-bit entries placed in device memory
// (`__device__`), with internal linkage per translation unit (`static`).
// The initializer supplies exactly 256 values. The entries appear to be in
// ascending numeric order, and every byte of every entry appears to be one
// of 0x08, 0x19 or 0x2b.
// NOTE(review): presumably the IQ2_XXS codebook, one entry = eight packed
// 8-bit magnitudes selected by an 8-bit index - confirm against the
// dequantize kernels. The values are data; do not edit them by hand.
static
const
__device__
uint64_t
iq2xxs_grid
[
256
]
=
{
0x0808080808080808
,
0x080808080808082b
,
0x0808080808081919
,
0x0808080808082b08
,
0x0808080808082b2b
,
0x0808080808190819
,
0x0808080808191908
,
0x08080808082b0808
,
0x08080808082b082b
,
0x08080808082b2b08
,
0x08080808082b2b2b
,
0x0808080819080819
,
0x0808080819081908
,
0x0808080819190808
,
0x0808080819192b08
,
0x08080808192b0819
,
0x08080808192b1908
,
0x080808082b080808
,
0x080808082b08082b
,
0x080808082b082b2b
,
0x080808082b2b082b
,
0x0808081908080819
,
0x0808081908081908
,
0x0808081908190808
,
0x0808081908191919
,
0x0808081919080808
,
0x080808192b081908
,
0x080808192b192b08
,
0x0808082b08080808
,
0x0808082b0808082b
,
0x0808082b082b082b
,
0x0808082b2b08082b
,
0x0808190808080819
,
0x0808190808081908
,
0x0808190808190808
,
0x08081908082b0819
,
0x08081908082b1908
,
0x0808190819080808
,
0x080819081908082b
,
0x0808190819082b08
,
0x08081908192b0808
,
0x080819082b080819
,
0x080819082b081908
,
0x080819082b190808
,
0x080819082b2b1908
,
0x0808191908080808
,
0x080819190808082b
,
0x0808191908082b08
,
0x08081919082b0808
,
0x080819191908192b
,
0x08081919192b2b19
,
0x080819192b080808
,
0x080819192b190819
,
0x0808192b08082b19
,
0x0808192b08190808
,
0x0808192b19080808
,
0x0808192b2b081908
,
0x0808192b2b2b1908
,
0x08082b0808080808
,
0x08082b0808081919
,
0x08082b0808082b08
,
0x08082b0808191908
,
0x08082b08082b2b08
,
0x08082b0819080819
,
0x08082b0819081908
,
0x08082b0819190808
,
0x08082b081919082b
,
0x08082b082b082b08
,
0x08082b1908081908
,
0x08082b1919080808
,
0x08082b2b0808082b
,
0x08082b2b08191908
,
0x0819080808080819
,
0x0819080808081908
,
0x0819080808190808
,
0x08190808082b0819
,
0x0819080819080808
,
0x08190808192b0808
,
0x081908082b081908
,
0x081908082b190808
,
0x081908082b191919
,
0x0819081908080808
,
0x0819081908082b08
,
0x08190819082b0808
,
0x0819081919190808
,
0x0819081919192b2b
,
0x081908192b080808
,
0x0819082b082b1908
,
0x0819082b19081919
,
0x0819190808080808
,
0x0819190808082b08
,
0x08191908082b0808
,
0x08191908082b1919
,
0x0819190819082b19
,
0x081919082b080808
,
0x0819191908192b08
,
0x08191919192b082b
,
0x0819192b08080808
,
0x0819192b0819192b
,
0x08192b0808080819
,
0x08192b0808081908
,
0x08192b0808190808
,
0x08192b0819080808
,
0x08192b082b080819
,
0x08192b1908080808
,
0x08192b1908081919
,
0x08192b192b2b0808
,
0x08192b2b19190819
,
0x082b080808080808
,
0x082b08080808082b
,
0x082b080808082b2b
,
0x082b080819081908
,
0x082b0808192b0819
,
0x082b08082b080808
,
0x082b08082b08082b
,
0x082b0819082b2b19
,
0x082b081919082b08
,
0x082b082b08080808
,
0x082b082b0808082b
,
0x082b190808080819
,
0x082b190808081908
,
0x082b190808190808
,
0x082b190819080808
,
0x082b19081919192b
,
0x082b191908080808
,
0x082b191919080819
,
0x082b1919192b1908
,
0x082b192b2b190808
,
0x082b2b0808082b08
,
0x082b2b08082b0808
,
0x082b2b082b191908
,
0x082b2b2b19081908
,
0x1908080808080819
,
0x1908080808081908
,
0x1908080808190808
,
0x1908080808192b08
,
0x19080808082b0819
,
0x19080808082b1908
,
0x1908080819080808
,
0x1908080819082b08
,
0x190808081919192b
,
0x19080808192b0808
,
0x190808082b080819
,
0x190808082b081908
,
0x190808082b190808
,
0x1908081908080808
,
0x19080819082b0808
,
0x19080819192b0819
,
0x190808192b080808
,
0x190808192b081919
,
0x1908082b08080819
,
0x1908082b08190808
,
0x1908082b19082b08
,
0x1908082b1919192b
,
0x1908082b192b2b08
,
0x1908190808080808
,
0x1908190808082b08
,
0x19081908082b0808
,
0x190819082b080808
,
0x190819082b192b19
,
0x190819190819082b
,
0x19081919082b1908
,
0x1908192b08080808
,
0x19082b0808080819
,
0x19082b0808081908
,
0x19082b0808190808
,
0x19082b0819080808
,
0x19082b0819081919
,
0x19082b1908080808
,
0x19082b1919192b08
,
0x19082b19192b0819
,
0x19082b192b08082b
,
0x19082b2b19081919
,
0x19082b2b2b190808
,
0x1919080808080808
,
0x1919080808082b08
,
0x1919080808190819
,
0x1919080808192b19
,
0x19190808082b0808
,
0x191908082b080808
,
0x191908082b082b08
,
0x1919081908081908
,
0x191908191908082b
,
0x191908192b2b1908
,
0x1919082b2b190819
,
0x191919082b190808
,
0x191919082b19082b
,
0x1919191908082b2b
,
0x1919192b08080819
,
0x1919192b19191908
,
0x19192b0808080808
,
0x19192b0808190819
,
0x19192b0808192b19
,
0x19192b08192b1908
,
0x19192b1919080808
,
0x19192b2b08082b08
,
0x192b080808081908
,
0x192b080808190808
,
0x192b080819080808
,
0x192b0808192b2b08
,
0x192b081908080808
,
0x192b081919191919
,
0x192b082b08192b08
,
0x192b082b192b0808
,
0x192b190808080808
,
0x192b190808081919
,
0x192b191908190808
,
0x192b19190819082b
,
0x192b19192b081908
,
0x192b2b081908082b
,
0x2b08080808080808
,
0x2b0808080808082b
,
0x2b08080808082b2b
,
0x2b08080819080819
,
0x2b0808082b08082b
,
0x2b08081908081908
,
0x2b08081908192b08
,
0x2b08081919080808
,
0x2b08082b08190819
,
0x2b08190808080819
,
0x2b08190808081908
,
0x2b08190808190808
,
0x2b08190808191919
,
0x2b08190819080808
,
0x2b081908192b0808
,
0x2b08191908080808
,
0x2b0819191908192b
,
0x2b0819192b191908
,
0x2b08192b08082b19
,
0x2b08192b19080808
,
0x2b08192b192b0808
,
0x2b082b080808082b
,
0x2b082b1908081908
,
0x2b082b2b08190819
,
0x2b19080808081908
,
0x2b19080808190808
,
0x2b190808082b1908
,
0x2b19080819080808
,
0x2b1908082b2b0819
,
0x2b1908190819192b
,
0x2b1908192b080808
,
0x2b19082b19081919
,
0x2b19190808080808
,
0x2b191908082b082b
,
0x2b19190819081908
,
0x2b19191919190819
,
0x2b192b082b080819
,
0x2b192b19082b0808
,
0x2b2b08080808082b
,
0x2b2b080819190808
,
0x2b2b08082b081919
,
0x2b2b081908082b19
,
0x2b2b082b08080808
,
0x2b2b190808192b08
,
0x2b2b2b0819190808
,
0x2b2b2b1908081908
,
};
// Constant table of 512 64-bit entries placed in device memory
// (`__device__`), with internal linkage per translation unit (`static`).
// The initializer supplies exactly 512 values. The entries appear to be in
// ascending numeric order, and every byte of every entry appears to be one
// of 0x08, 0x19 or 0x2b.
// NOTE(review): presumably the IQ2_XS codebook, one entry = eight packed
// 8-bit magnitudes selected by a 9-bit index - confirm against the
// dequantize kernels. The values are data; do not edit them by hand.
static
const
__device__
uint64_t
iq2xs_grid
[
512
]
=
{
0x0808080808080808
,
0x080808080808082b
,
0x0808080808081919
,
0x0808080808082b08
,
0x0808080808082b2b
,
0x0808080808190819
,
0x0808080808191908
,
0x080808080819192b
,
0x0808080808192b19
,
0x08080808082b0808
,
0x08080808082b082b
,
0x08080808082b1919
,
0x08080808082b2b08
,
0x0808080819080819
,
0x0808080819081908
,
0x080808081908192b
,
0x0808080819082b19
,
0x0808080819190808
,
0x080808081919082b
,
0x0808080819191919
,
0x0808080819192b08
,
0x08080808192b0819
,
0x08080808192b1908
,
0x080808082b080808
,
0x080808082b08082b
,
0x080808082b081919
,
0x080808082b082b08
,
0x080808082b190819
,
0x080808082b191908
,
0x080808082b192b19
,
0x080808082b2b0808
,
0x0808081908080819
,
0x0808081908081908
,
0x080808190808192b
,
0x0808081908082b19
,
0x0808081908190808
,
0x080808190819082b
,
0x0808081908191919
,
0x0808081908192b08
,
0x0808081908192b2b
,
0x08080819082b0819
,
0x08080819082b1908
,
0x0808081919080808
,
0x080808191908082b
,
0x0808081919081919
,
0x0808081919082b08
,
0x0808081919190819
,
0x0808081919191908
,
0x08080819192b0808
,
0x08080819192b2b08
,
0x080808192b080819
,
0x080808192b081908
,
0x080808192b190808
,
0x0808082b08080808
,
0x0808082b0808082b
,
0x0808082b08081919
,
0x0808082b08082b08
,
0x0808082b08190819
,
0x0808082b08191908
,
0x0808082b082b0808
,
0x0808082b19080819
,
0x0808082b19081908
,
0x0808082b19190808
,
0x0808082b19191919
,
0x0808082b2b080808
,
0x0808082b2b082b2b
,
0x0808190808080819
,
0x0808190808081908
,
0x080819080808192b
,
0x0808190808082b19
,
0x0808190808190808
,
0x080819080819082b
,
0x0808190808191919
,
0x0808190808192b08
,
0x08081908082b0819
,
0x08081908082b1908
,
0x0808190819080808
,
0x080819081908082b
,
0x0808190819081919
,
0x0808190819082b08
,
0x0808190819190819
,
0x0808190819191908
,
0x080819081919192b
,
0x08081908192b0808
,
0x080819082b080819
,
0x080819082b081908
,
0x080819082b190808
,
0x0808191908080808
,
0x080819190808082b
,
0x0808191908081919
,
0x0808191908082b08
,
0x0808191908190819
,
0x0808191908191908
,
0x08081919082b0808
,
0x0808191919080819
,
0x0808191919081908
,
0x0808191919190808
,
0x08081919192b0819
,
0x080819192b080808
,
0x0808192b08080819
,
0x0808192b08081908
,
0x0808192b08190808
,
0x0808192b082b192b
,
0x0808192b19080808
,
0x0808192b1908082b
,
0x0808192b2b081908
,
0x08082b0808080808
,
0x08082b080808082b
,
0x08082b0808081919
,
0x08082b0808082b08
,
0x08082b0808082b2b
,
0x08082b0808190819
,
0x08082b0808191908
,
0x08082b08082b0808
,
0x08082b08082b1919
,
0x08082b0819080819
,
0x08082b0819081908
,
0x08082b0819190808
,
0x08082b0819192b08
,
0x08082b082b080808
,
0x08082b082b2b0808
,
0x08082b082b2b2b2b
,
0x08082b1908080819
,
0x08082b1908081908
,
0x08082b1908190808
,
0x08082b1919080808
,
0x08082b192b080819
,
0x08082b192b082b19
,
0x08082b2b08080808
,
0x08082b2b082b0808
,
0x08082b2b082b2b08
,
0x08082b2b2b19192b
,
0x08082b2b2b2b0808
,
0x0819080808080819
,
0x0819080808081908
,
0x081908080808192b
,
0x0819080808082b19
,
0x0819080808190808
,
0x081908080819082b
,
0x0819080808191919
,
0x0819080808192b08
,
0x08190808082b0819
,
0x08190808082b1908
,
0x0819080819080808
,
0x081908081908082b
,
0x0819080819081919
,
0x0819080819082b08
,
0x0819080819190819
,
0x0819080819191908
,
0x08190808192b0808
,
0x08190808192b2b2b
,
0x081908082b080819
,
0x081908082b081908
,
0x081908082b190808
,
0x0819081908080808
,
0x081908190808082b
,
0x0819081908081919
,
0x0819081908082b08
,
0x0819081908190819
,
0x0819081908191908
,
0x08190819082b0808
,
0x0819081919080819
,
0x0819081919081908
,
0x0819081919190808
,
0x081908192b080808
,
0x081908192b191908
,
0x081908192b19192b
,
0x0819082b08080819
,
0x0819082b08081908
,
0x0819082b0808192b
,
0x0819082b08190808
,
0x0819082b19080808
,
0x0819082b192b0808
,
0x0819190808080808
,
0x081919080808082b
,
0x0819190808081919
,
0x0819190808082b08
,
0x0819190808190819
,
0x0819190808191908
,
0x08191908082b0808
,
0x0819190819080819
,
0x0819190819081908
,
0x0819190819082b19
,
0x0819190819190808
,
0x08191908192b1908
,
0x081919082b080808
,
0x0819191908080819
,
0x0819191908081908
,
0x0819191908190808
,
0x0819191919080808
,
0x0819192b08080808
,
0x0819192b08191908
,
0x0819192b19082b19
,
0x08192b0808080819
,
0x08192b0808081908
,
0x08192b0808190808
,
0x08192b080819082b
,
0x08192b0819080808
,
0x08192b0819191908
,
0x08192b082b08192b
,
0x08192b1908080808
,
0x08192b1908081919
,
0x08192b19192b192b
,
0x08192b2b19190819
,
0x08192b2b2b2b2b19
,
0x082b080808080808
,
0x082b08080808082b
,
0x082b080808081919
,
0x082b080808082b08
,
0x082b080808082b2b
,
0x082b080808190819
,
0x082b080808191908
,
0x082b0808082b0808
,
0x082b080819080819
,
0x082b080819081908
,
0x082b080819190808
,
0x082b08082b080808
,
0x082b08082b2b0808
,
0x082b081908080819
,
0x082b081908081908
,
0x082b081908190808
,
0x082b081919080808
,
0x082b081919082b08
,
0x082b0819192b1919
,
0x082b082b08080808
,
0x082b082b082b082b
,
0x082b082b2b080808
,
0x082b082b2b2b2b08
,
0x082b190808080819
,
0x082b190808081908
,
0x082b190808190808
,
0x082b1908082b2b19
,
0x082b190819080808
,
0x082b191908080808
,
0x082b191919080819
,
0x082b19191919082b
,
0x082b19192b192b19
,
0x082b192b08080819
,
0x082b192b08192b2b
,
0x082b192b2b2b192b
,
0x082b2b0808080808
,
0x082b2b0808082b08
,
0x082b2b0808082b2b
,
0x082b2b08082b0808
,
0x082b2b0819191919
,
0x082b2b082b082b08
,
0x082b2b082b2b082b
,
0x082b2b19192b2b08
,
0x082b2b192b190808
,
0x082b2b2b08082b08
,
0x082b2b2b082b0808
,
0x082b2b2b2b08082b
,
0x082b2b2b2b082b08
,
0x082b2b2b2b082b2b
,
0x1908080808080819
,
0x1908080808081908
,
0x190808080808192b
,
0x1908080808082b19
,
0x1908080808190808
,
0x190808080819082b
,
0x1908080808191919
,
0x1908080808192b08
,
0x19080808082b0819
,
0x19080808082b1908
,
0x1908080819080808
,
0x190808081908082b
,
0x1908080819081919
,
0x1908080819082b08
,
0x1908080819082b2b
,
0x1908080819190819
,
0x1908080819191908
,
0x19080808192b0808
,
0x19080808192b1919
,
0x190808082b080819
,
0x190808082b081908
,
0x190808082b190808
,
0x1908081908080808
,
0x190808190808082b
,
0x1908081908081919
,
0x1908081908082b08
,
0x1908081908190819
,
0x1908081908191908
,
0x19080819082b0808
,
0x1908081919080819
,
0x1908081919081908
,
0x1908081919190808
,
0x190808192b080808
,
0x190808192b081919
,
0x190808192b2b082b
,
0x1908082b08080819
,
0x1908082b08081908
,
0x1908082b08190808
,
0x1908082b0819082b
,
0x1908082b082b2b19
,
0x1908082b19080808
,
0x1908190808080808
,
0x190819080808082b
,
0x1908190808081919
,
0x1908190808082b08
,
0x1908190808190819
,
0x1908190808191908
,
0x1908190808192b19
,
0x19081908082b0808
,
0x1908190819080819
,
0x1908190819081908
,
0x1908190819190808
,
0x190819082b080808
,
0x190819082b191908
,
0x1908191908080819
,
0x1908191908081908
,
0x1908191908190808
,
0x19081919082b1908
,
0x1908191919080808
,
0x190819192b192b2b
,
0x1908192b08080808
,
0x1908192b08082b2b
,
0x1908192b19081908
,
0x1908192b19190808
,
0x19082b0808080819
,
0x19082b0808081908
,
0x19082b0808190808
,
0x19082b0819080808
,
0x19082b0819081919
,
0x19082b0819191908
,
0x19082b08192b082b
,
0x19082b1908080808
,
0x19082b1908190819
,
0x19082b1919081908
,
0x19082b1919190808
,
0x19082b19192b2b19
,
0x19082b2b08081908
,
0x1919080808080808
,
0x191908080808082b
,
0x1919080808081919
,
0x1919080808082b08
,
0x1919080808190819
,
0x1919080808191908
,
0x19190808082b0808
,
0x19190808082b2b08
,
0x1919080819080819
,
0x1919080819081908
,
0x1919080819190808
,
0x191908082b080808
,
0x1919081908080819
,
0x1919081908081908
,
0x1919081908190808
,
0x1919081908191919
,
0x1919081919080808
,
0x191908191908082b
,
0x1919082b08080808
,
0x1919082b19081908
,
0x1919082b2b2b2b2b
,
0x1919190808080819
,
0x1919190808081908
,
0x1919190808190808
,
0x19191908082b0819
,
0x1919190819080808
,
0x19191908192b0808
,
0x191919082b080819
,
0x191919082b2b0819
,
0x1919191908080808
,
0x1919191908082b08
,
0x191919192b080808
,
0x191919192b082b08
,
0x1919192b082b0819
,
0x1919192b192b2b08
,
0x1919192b2b2b0819
,
0x19192b0808080808
,
0x19192b0808191908
,
0x19192b0819080819
,
0x19192b0819190808
,
0x19192b082b192b19
,
0x19192b1908192b2b
,
0x19192b1919080808
,
0x19192b191908082b
,
0x19192b2b2b081919
,
0x192b080808080819
,
0x192b080808081908
,
0x192b080808190808
,
0x192b080819080808
,
0x192b080819191908
,
0x192b0808192b082b
,
0x192b08082b08192b
,
0x192b08082b2b2b19
,
0x192b081908080808
,
0x192b082b082b1908
,
0x192b082b19082b2b
,
0x192b082b2b19082b
,
0x192b190808080808
,
0x192b19080819192b
,
0x192b191908190808
,
0x192b191919080808
,
0x192b191919081919
,
0x192b19192b2b1908
,
0x192b2b0808080819
,
0x192b2b08192b2b2b
,
0x192b2b19082b1919
,
0x192b2b2b0808192b
,
0x192b2b2b19191908
,
0x192b2b2b192b082b
,
0x2b08080808080808
,
0x2b0808080808082b
,
0x2b08080808081919
,
0x2b08080808082b08
,
0x2b08080808190819
,
0x2b08080808191908
,
0x2b080808082b0808
,
0x2b080808082b2b2b
,
0x2b08080819080819
,
0x2b08080819081908
,
0x2b08080819190808
,
0x2b0808082b080808
,
0x2b0808082b08082b
,
0x2b0808082b2b2b08
,
0x2b0808082b2b2b2b
,
0x2b08081908080819
,
0x2b08081908081908
,
0x2b0808190808192b
,
0x2b08081908190808
,
0x2b08081919080808
,
0x2b08081919190819
,
0x2b08081919192b19
,
0x2b08082b08080808
,
0x2b08082b082b0808
,
0x2b08082b2b080808
,
0x2b08082b2b08082b
,
0x2b08082b2b2b0808
,
0x2b08082b2b2b2b08
,
0x2b08190808080819
,
0x2b08190808081908
,
0x2b08190808190808
,
0x2b0819080819082b
,
0x2b08190808191919
,
0x2b08190819080808
,
0x2b081908192b0808
,
0x2b0819082b082b19
,
0x2b08191908080808
,
0x2b08191919081908
,
0x2b0819192b2b1919
,
0x2b08192b08192b08
,
0x2b08192b192b2b2b
,
0x2b082b0808080808
,
0x2b082b0808082b08
,
0x2b082b08082b1919
,
0x2b082b0819192b2b
,
0x2b082b082b080808
,
0x2b082b082b08082b
,
0x2b082b082b2b2b08
,
0x2b082b190808192b
,
0x2b082b2b082b082b
,
0x2b082b2b2b080808
,
0x2b082b2b2b082b08
,
0x2b082b2b2b19192b
,
0x2b082b2b2b2b2b08
,
0x2b19080808080819
,
0x2b19080808081908
,
0x2b19080808190808
,
0x2b19080819080808
,
0x2b1908081919192b
,
0x2b1908082b081908
,
0x2b19081908080808
,
0x2b190819082b082b
,
0x2b190819192b1908
,
0x2b19082b1919192b
,
0x2b19082b2b082b19
,
0x2b19190808080808
,
0x2b19190808081919
,
0x2b19190819081908
,
0x2b19190819190808
,
0x2b19190819192b08
,
0x2b191919082b2b19
,
0x2b1919192b190808
,
0x2b1919192b19082b
,
0x2b19192b19080819
,
0x2b192b0819190819
,
0x2b192b082b2b192b
,
0x2b192b1919082b19
,
0x2b192b2b08191919
,
0x2b192b2b192b0808
,
0x2b2b080808080808
,
0x2b2b08080808082b
,
0x2b2b080808082b08
,
0x2b2b080808082b2b
,
0x2b2b0808082b0808
,
0x2b2b0808082b2b2b
,
0x2b2b08082b2b0808
,
0x2b2b081919190819
,
0x2b2b081919192b19
,
0x2b2b08192b2b192b
,
0x2b2b082b08080808
,
0x2b2b082b0808082b
,
0x2b2b082b08082b08
,
0x2b2b082b082b2b2b
,
0x2b2b082b2b080808
,
0x2b2b082b2b2b0808
,
0x2b2b190819080808
,
0x2b2b19082b191919
,
0x2b2b192b192b1919
,
0x2b2b192b2b192b08
,
0x2b2b2b0808082b2b
,
0x2b2b2b08082b0808
,
0x2b2b2b08082b082b
,
0x2b2b2b08082b2b08
,
0x2b2b2b082b2b0808
,
0x2b2b2b082b2b2b08
,
0x2b2b2b1908081908
,
0x2b2b2b192b081908
,
0x2b2b2b192b08192b
,
0x2b2b2b2b082b2b08
,
0x2b2b2b2b082b2b2b
,
0x2b2b2b2b2b190819
,
0x2b2b2b2b2b2b2b2b
,
};
static
const
__device__
uint64_t
iq2s_grid
[
1024
]
=
{
0x0808080808080808
,
0x080808080808082b
,
0x0808080808081919
,
0x0808080808082b08
,
0x0808080808082b2b
,
0x0808080808190819
,
0x0808080808191908
,
0x080808080819192b
,
0x0808080808192b19
,
0x08080808082b0808
,
0x08080808082b082b
,
0x08080808082b1919
,
0x08080808082b2b08
,
0x0808080819080819
,
0x0808080819081908
,
0x080808081908192b
,
0x0808080819082b19
,
0x0808080819190808
,
0x080808081919082b
,
0x0808080819191919
,
0x0808080819192b08
,
0x08080808192b0819
,
0x08080808192b1908
,
0x08080808192b192b
,
0x08080808192b2b19
,
0x080808082b080808
,
0x080808082b08082b
,
0x080808082b081919
,
0x080808082b082b08
,
0x080808082b190819
,
0x080808082b191908
,
0x080808082b2b0808
,
0x080808082b2b1919
,
0x080808082b2b2b2b
,
0x0808081908080819
,
0x0808081908081908
,
0x080808190808192b
,
0x0808081908082b19
,
0x0808081908190808
,
0x080808190819082b
,
0x0808081908191919
,
0x0808081908192b08
,
0x08080819082b0819
,
0x08080819082b1908
,
0x0808081919080808
,
0x080808191908082b
,
0x0808081919081919
,
0x0808081919082b08
,
0x0808081919190819
,
0x0808081919191908
,
0x080808191919192b
,
0x0808081919192b19
,
0x08080819192b0808
,
0x08080819192b1919
,
0x08080819192b2b08
,
0x080808192b080819
,
0x080808192b081908
,
0x080808192b190808
,
0x080808192b19082b
,
0x080808192b191919
,
0x080808192b2b0819
,
0x080808192b2b1908
,
0x0808082b08080808
,
0x0808082b0808082b
,
0x0808082b08081919
,
0x0808082b08082b08
,
0x0808082b08190819
,
0x0808082b08191908
,
0x0808082b082b0808
,
0x0808082b082b2b2b
,
0x0808082b19080819
,
0x0808082b19081908
,
0x0808082b1908192b
,
0x0808082b19082b19
,
0x0808082b19190808
,
0x0808082b19191919
,
0x0808082b2b080808
,
0x0808082b2b081919
,
0x0808082b2b082b2b
,
0x0808082b2b191908
,
0x0808082b2b2b082b
,
0x0808190808080819
,
0x0808190808081908
,
0x080819080808192b
,
0x0808190808082b19
,
0x0808190808190808
,
0x080819080819082b
,
0x0808190808191919
,
0x0808190808192b08
,
0x08081908082b0819
,
0x08081908082b1908
,
0x08081908082b192b
,
0x08081908082b2b19
,
0x0808190819080808
,
0x080819081908082b
,
0x0808190819081919
,
0x0808190819082b08
,
0x0808190819082b2b
,
0x0808190819190819
,
0x0808190819191908
,
0x080819081919192b
,
0x0808190819192b19
,
0x08081908192b0808
,
0x08081908192b082b
,
0x08081908192b1919
,
0x080819082b080819
,
0x080819082b081908
,
0x080819082b08192b
,
0x080819082b082b19
,
0x080819082b190808
,
0x080819082b191919
,
0x080819082b192b08
,
0x080819082b2b0819
,
0x080819082b2b1908
,
0x0808191908080808
,
0x080819190808082b
,
0x0808191908081919
,
0x0808191908082b08
,
0x0808191908082b2b
,
0x0808191908190819
,
0x0808191908191908
,
0x080819190819192b
,
0x0808191908192b19
,
0x08081919082b0808
,
0x08081919082b1919
,
0x08081919082b2b08
,
0x0808191919080819
,
0x0808191919081908
,
0x080819191908192b
,
0x0808191919082b19
,
0x0808191919190808
,
0x080819191919082b
,
0x0808191919191919
,
0x0808191919192b08
,
0x08081919192b0819
,
0x08081919192b1908
,
0x080819192b080808
,
0x080819192b08082b
,
0x080819192b081919
,
0x080819192b082b08
,
0x080819192b190819
,
0x080819192b191908
,
0x080819192b2b0808
,
0x0808192b08080819
,
0x0808192b08081908
,
0x0808192b0808192b
,
0x0808192b08082b19
,
0x0808192b08190808
,
0x0808192b08191919
,
0x0808192b19080808
,
0x0808192b19081919
,
0x0808192b19082b08
,
0x0808192b19190819
,
0x0808192b19191908
,
0x0808192b192b0808
,
0x0808192b2b080819
,
0x0808192b2b081908
,
0x0808192b2b190808
,
0x08082b0808080808
,
0x08082b080808082b
,
0x08082b0808081919
,
0x08082b0808082b08
,
0x08082b0808190819
,
0x08082b0808191908
,
0x08082b080819192b
,
0x08082b0808192b19
,
0x08082b08082b0808
,
0x08082b08082b1919
,
0x08082b08082b2b2b
,
0x08082b0819080819
,
0x08082b0819081908
,
0x08082b081908192b
,
0x08082b0819082b19
,
0x08082b0819190808
,
0x08082b081919082b
,
0x08082b0819191919
,
0x08082b0819192b08
,
0x08082b08192b0819
,
0x08082b08192b1908
,
0x08082b082b080808
,
0x08082b082b081919
,
0x08082b082b191908
,
0x08082b082b2b2b2b
,
0x08082b1908080819
,
0x08082b1908081908
,
0x08082b1908190808
,
0x08082b190819082b
,
0x08082b1908191919
,
0x08082b1908192b08
,
0x08082b19082b0819
,
0x08082b1919080808
,
0x08082b1919081919
,
0x08082b1919082b08
,
0x08082b1919190819
,
0x08082b1919191908
,
0x08082b19192b0808
,
0x08082b192b080819
,
0x08082b192b190808
,
0x08082b2b08080808
,
0x08082b2b08190819
,
0x08082b2b08191908
,
0x08082b2b082b082b
,
0x08082b2b082b2b08
,
0x08082b2b082b2b2b
,
0x08082b2b19190808
,
0x08082b2b2b192b19
,
0x0819080808080819
,
0x0819080808081908
,
0x081908080808192b
,
0x0819080808082b19
,
0x0819080808190808
,
0x081908080819082b
,
0x0819080808191919
,
0x0819080808192b08
,
0x08190808082b0819
,
0x08190808082b1908
,
0x08190808082b192b
,
0x0819080819080808
,
0x081908081908082b
,
0x0819080819081919
,
0x0819080819082b08
,
0x0819080819190819
,
0x0819080819191908
,
0x081908081919192b
,
0x0819080819192b19
,
0x08190808192b0808
,
0x08190808192b082b
,
0x08190808192b1919
,
0x08190808192b2b08
,
0x081908082b080819
,
0x081908082b081908
,
0x081908082b08192b
,
0x081908082b190808
,
0x081908082b191919
,
0x081908082b192b08
,
0x081908082b2b0819
,
0x081908082b2b1908
,
0x0819081908080808
,
0x081908190808082b
,
0x0819081908081919
,
0x0819081908082b08
,
0x0819081908082b2b
,
0x0819081908190819
,
0x0819081908191908
,
0x081908190819192b
,
0x0819081908192b19
,
0x08190819082b0808
,
0x08190819082b082b
,
0x08190819082b1919
,
0x08190819082b2b08
,
0x0819081919080819
,
0x0819081919081908
,
0x081908191908192b
,
0x0819081919082b19
,
0x0819081919190808
,
0x081908191919082b
,
0x0819081919191919
,
0x0819081919192b08
,
0x08190819192b0819
,
0x08190819192b1908
,
0x081908192b080808
,
0x081908192b08082b
,
0x081908192b081919
,
0x081908192b082b08
,
0x081908192b190819
,
0x081908192b191908
,
0x0819082b08080819
,
0x0819082b08081908
,
0x0819082b08082b19
,
0x0819082b08190808
,
0x0819082b08191919
,
0x0819082b082b0819
,
0x0819082b082b1908
,
0x0819082b19080808
,
0x0819082b19081919
,
0x0819082b19190819
,
0x0819082b19191908
,
0x0819082b2b080819
,
0x0819082b2b081908
,
0x0819082b2b190808
,
0x0819190808080808
,
0x081919080808082b
,
0x0819190808081919
,
0x0819190808082b08
,
0x0819190808190819
,
0x0819190808191908
,
0x081919080819192b
,
0x0819190808192b19
,
0x08191908082b0808
,
0x08191908082b1919
,
0x08191908082b2b08
,
0x0819190819080819
,
0x0819190819081908
,
0x081919081908192b
,
0x0819190819082b19
,
0x0819190819190808
,
0x081919081919082b
,
0x0819190819191919
,
0x0819190819192b08
,
0x08191908192b0819
,
0x08191908192b1908
,
0x081919082b080808
,
0x081919082b08082b
,
0x081919082b081919
,
0x081919082b082b08
,
0x081919082b190819
,
0x081919082b191908
,
0x081919082b2b0808
,
0x0819191908080819
,
0x0819191908081908
,
0x081919190808192b
,
0x0819191908082b19
,
0x0819191908190808
,
0x081919190819082b
,
0x0819191908191919
,
0x0819191908192b08
,
0x08191919082b0819
,
0x08191919082b1908
,
0x0819191919080808
,
0x081919191908082b
,
0x0819191919081919
,
0x0819191919082b08
,
0x0819191919190819
,
0x0819191919191908
,
0x08191919192b0808
,
0x081919192b080819
,
0x081919192b081908
,
0x081919192b190808
,
0x0819192b08080808
,
0x0819192b08081919
,
0x0819192b08082b08
,
0x0819192b08190819
,
0x0819192b08191908
,
0x0819192b082b0808
,
0x0819192b19080819
,
0x0819192b19081908
,
0x0819192b19190808
,
0x0819192b2b080808
,
0x0819192b2b2b2b2b
,
0x08192b0808080819
,
0x08192b0808081908
,
0x08192b080808192b
,
0x08192b0808082b19
,
0x08192b0808190808
,
0x08192b0808191919
,
0x08192b0808192b08
,
0x08192b08082b0819
,
0x08192b0819080808
,
0x08192b081908082b
,
0x08192b0819081919
,
0x08192b0819082b08
,
0x08192b0819190819
,
0x08192b0819191908
,
0x08192b08192b0808
,
0x08192b082b080819
,
0x08192b082b081908
,
0x08192b1908080808
,
0x08192b190808082b
,
0x08192b1908081919
,
0x08192b1908082b08
,
0x08192b1908190819
,
0x08192b1908191908
,
0x08192b19082b0808
,
0x08192b1919080819
,
0x08192b1919081908
,
0x08192b1919190808
,
0x08192b19192b2b19
,
0x08192b192b2b082b
,
0x08192b2b08081908
,
0x08192b2b08190808
,
0x08192b2b19080808
,
0x08192b2b1919192b
,
0x082b080808080808
,
0x082b08080808082b
,
0x082b080808081919
,
0x082b080808082b08
,
0x082b080808190819
,
0x082b080808191908
,
0x082b08080819192b
,
0x082b080808192b19
,
0x082b0808082b0808
,
0x082b0808082b1919
,
0x082b0808082b2b2b
,
0x082b080819080819
,
0x082b080819081908
,
0x082b080819190808
,
0x082b08081919082b
,
0x082b080819191919
,
0x082b0808192b1908
,
0x082b08082b080808
,
0x082b08082b082b2b
,
0x082b08082b191908
,
0x082b08082b2b2b2b
,
0x082b081908080819
,
0x082b081908081908
,
0x082b081908190808
,
0x082b08190819082b
,
0x082b081908191919
,
0x082b0819082b0819
,
0x082b081919080808
,
0x082b08191908082b
,
0x082b081919081919
,
0x082b081919190819
,
0x082b081919191908
,
0x082b0819192b0808
,
0x082b08192b080819
,
0x082b08192b081908
,
0x082b08192b190808
,
0x082b082b08080808
,
0x082b082b08082b2b
,
0x082b082b082b082b
,
0x082b082b082b2b08
,
0x082b082b082b2b2b
,
0x082b082b19081908
,
0x082b082b19190808
,
0x082b082b2b082b08
,
0x082b082b2b082b2b
,
0x082b082b2b2b2b08
,
0x082b190808080819
,
0x082b190808081908
,
0x082b19080808192b
,
0x082b190808082b19
,
0x082b190808190808
,
0x082b190808191919
,
0x082b190808192b08
,
0x082b1908082b0819
,
0x082b1908082b1908
,
0x082b190819080808
,
0x082b19081908082b
,
0x082b190819081919
,
0x082b190819082b08
,
0x082b190819190819
,
0x082b190819191908
,
0x082b1908192b0808
,
0x082b19082b080819
,
0x082b19082b081908
,
0x082b19082b190808
,
0x082b191908080808
,
0x082b191908081919
,
0x082b191908082b08
,
0x082b191908190819
,
0x082b191908191908
,
0x082b1919082b0808
,
0x082b191919080819
,
0x082b191919081908
,
0x082b191919190808
,
0x082b1919192b192b
,
0x082b19192b080808
,
0x082b192b08080819
,
0x082b192b08081908
,
0x082b192b08190808
,
0x082b192b19080808
,
0x082b192b19192b19
,
0x082b2b0808080808
,
0x082b2b0808081919
,
0x082b2b0808190819
,
0x082b2b0808191908
,
0x082b2b0819080819
,
0x082b2b0819081908
,
0x082b2b0819190808
,
0x082b2b082b082b2b
,
0x082b2b082b2b2b2b
,
0x082b2b1908080819
,
0x082b2b1908081908
,
0x082b2b1908190808
,
0x082b2b192b191919
,
0x082b2b2b08082b2b
,
0x082b2b2b082b082b
,
0x082b2b2b192b1908
,
0x082b2b2b2b082b08
,
0x082b2b2b2b082b2b
,
0x1908080808080819
,
0x1908080808081908
,
0x190808080808192b
,
0x1908080808082b19
,
0x1908080808190808
,
0x190808080819082b
,
0x1908080808191919
,
0x1908080808192b08
,
0x1908080808192b2b
,
0x19080808082b0819
,
0x19080808082b1908
,
0x19080808082b192b
,
0x1908080819080808
,
0x190808081908082b
,
0x1908080819081919
,
0x1908080819082b08
,
0x1908080819082b2b
,
0x1908080819190819
,
0x1908080819191908
,
0x190808081919192b
,
0x1908080819192b19
,
0x19080808192b0808
,
0x19080808192b082b
,
0x19080808192b1919
,
0x190808082b080819
,
0x190808082b081908
,
0x190808082b190808
,
0x190808082b191919
,
0x190808082b192b08
,
0x190808082b2b0819
,
0x190808082b2b1908
,
0x1908081908080808
,
0x190808190808082b
,
0x1908081908081919
,
0x1908081908082b08
,
0x1908081908190819
,
0x1908081908191908
,
0x190808190819192b
,
0x1908081908192b19
,
0x19080819082b0808
,
0x19080819082b082b
,
0x19080819082b1919
,
0x1908081919080819
,
0x1908081919081908
,
0x190808191908192b
,
0x1908081919082b19
,
0x1908081919190808
,
0x190808191919082b
,
0x1908081919191919
,
0x1908081919192b08
,
0x19080819192b0819
,
0x19080819192b1908
,
0x190808192b080808
,
0x190808192b08082b
,
0x190808192b081919
,
0x190808192b082b08
,
0x190808192b190819
,
0x190808192b191908
,
0x190808192b2b0808
,
0x1908082b08080819
,
0x1908082b08081908
,
0x1908082b08190808
,
0x1908082b0819082b
,
0x1908082b08191919
,
0x1908082b08192b08
,
0x1908082b082b1908
,
0x1908082b19080808
,
0x1908082b19081919
,
0x1908082b19082b08
,
0x1908082b19190819
,
0x1908082b19191908
,
0x1908082b192b0808
,
0x1908082b2b080819
,
0x1908082b2b081908
,
0x1908190808080808
,
0x190819080808082b
,
0x1908190808081919
,
0x1908190808082b08
,
0x1908190808082b2b
,
0x1908190808190819
,
0x1908190808191908
,
0x190819080819192b
,
0x1908190808192b19
,
0x19081908082b0808
,
0x19081908082b082b
,
0x19081908082b1919
,
0x19081908082b2b08
,
0x1908190819080819
,
0x1908190819081908
,
0x190819081908192b
,
0x1908190819082b19
,
0x1908190819190808
,
0x190819081919082b
,
0x1908190819191919
,
0x1908190819192b08
,
0x19081908192b0819
,
0x19081908192b1908
,
0x190819082b080808
,
0x190819082b08082b
,
0x190819082b081919
,
0x190819082b082b08
,
0x190819082b190819
,
0x190819082b191908
,
0x190819082b2b0808
,
0x1908191908080819
,
0x1908191908081908
,
0x190819190808192b
,
0x1908191908082b19
,
0x1908191908190808
,
0x190819190819082b
,
0x1908191908191919
,
0x1908191908192b08
,
0x19081919082b0819
,
0x19081919082b1908
,
0x1908191919080808
,
0x190819191908082b
,
0x1908191919081919
,
0x1908191919082b08
,
0x1908191919190819
,
0x1908191919191908
,
0x19081919192b0808
,
0x19081919192b2b2b
,
0x190819192b080819
,
0x190819192b081908
,
0x190819192b190808
,
0x1908192b08080808
,
0x1908192b0808082b
,
0x1908192b08081919
,
0x1908192b08082b08
,
0x1908192b08190819
,
0x1908192b08191908
,
0x1908192b082b0808
,
0x1908192b19080819
,
0x1908192b19081908
,
0x1908192b19190808
,
0x1908192b2b080808
,
0x1908192b2b2b1919
,
0x19082b0808080819
,
0x19082b0808081908
,
0x19082b0808082b19
,
0x19082b0808190808
,
0x19082b080819082b
,
0x19082b0808191919
,
0x19082b0808192b08
,
0x19082b08082b0819
,
0x19082b08082b1908
,
0x19082b0819080808
,
0x19082b081908082b
,
0x19082b0819081919
,
0x19082b0819082b08
,
0x19082b0819190819
,
0x19082b0819191908
,
0x19082b08192b0808
,
0x19082b082b081908
,
0x19082b082b190808
,
0x19082b1908080808
,
0x19082b190808082b
,
0x19082b1908081919
,
0x19082b1908082b08
,
0x19082b1908190819
,
0x19082b1908191908
,
0x19082b19082b0808
,
0x19082b1919080819
,
0x19082b1919081908
,
0x19082b1919190808
,
0x19082b192b080808
,
0x19082b192b19192b
,
0x19082b2b08080819
,
0x19082b2b08081908
,
0x19082b2b08190808
,
0x19082b2b19080808
,
0x1919080808080808
,
0x191908080808082b
,
0x1919080808081919
,
0x1919080808082b08
,
0x1919080808190819
,
0x1919080808191908
,
0x191908080819192b
,
0x1919080808192b19
,
0x19190808082b0808
,
0x19190808082b082b
,
0x19190808082b1919
,
0x19190808082b2b08
,
0x1919080819080819
,
0x1919080819081908
,
0x191908081908192b
,
0x1919080819082b19
,
0x1919080819190808
,
0x191908081919082b
,
0x1919080819191919
,
0x1919080819192b08
,
0x19190808192b0819
,
0x19190808192b1908
,
0x191908082b080808
,
0x191908082b08082b
,
0x191908082b081919
,
0x191908082b082b08
,
0x191908082b190819
,
0x191908082b191908
,
0x1919081908080819
,
0x1919081908081908
,
0x191908190808192b
,
0x1919081908082b19
,
0x1919081908190808
,
0x191908190819082b
,
0x1919081908191919
,
0x1919081908192b08
,
0x19190819082b0819
,
0x19190819082b1908
,
0x1919081919080808
,
0x191908191908082b
,
0x1919081919081919
,
0x1919081919082b08
,
0x1919081919190819
,
0x1919081919191908
,
0x19190819192b0808
,
0x191908192b080819
,
0x191908192b081908
,
0x191908192b190808
,
0x1919082b08080808
,
0x1919082b08081919
,
0x1919082b08082b08
,
0x1919082b08190819
,
0x1919082b08191908
,
0x1919082b082b0808
,
0x1919082b19080819
,
0x1919082b19081908
,
0x1919082b19190808
,
0x1919082b192b2b19
,
0x1919082b2b080808
,
0x1919190808080819
,
0x1919190808081908
,
0x191919080808192b
,
0x1919190808082b19
,
0x1919190808190808
,
0x191919080819082b
,
0x1919190808191919
,
0x1919190808192b08
,
0x19191908082b0819
,
0x19191908082b1908
,
0x1919190819080808
,
0x191919081908082b
,
0x1919190819081919
,
0x1919190819082b08
,
0x1919190819190819
,
0x1919190819191908
,
0x19191908192b0808
,
0x191919082b080819
,
0x191919082b081908
,
0x191919082b190808
,
0x1919191908080808
,
0x191919190808082b
,
0x1919191908081919
,
0x1919191908082b08
,
0x1919191908190819
,
0x1919191908191908
,
0x19191919082b0808
,
0x1919191919080819
,
0x1919191919081908
,
0x1919191919190808
,
0x191919192b080808
,
0x1919192b08080819
,
0x1919192b08081908
,
0x1919192b08190808
,
0x1919192b082b192b
,
0x1919192b19080808
,
0x19192b0808080808
,
0x19192b080808082b
,
0x19192b0808081919
,
0x19192b0808082b08
,
0x19192b0808190819
,
0x19192b0808191908
,
0x19192b08082b0808
,
0x19192b0819080819
,
0x19192b0819081908
,
0x19192b0819190808
,
0x19192b0819192b2b
,
0x19192b082b080808
,
0x19192b1908080819
,
0x19192b1908081908
,
0x19192b1908190808
,
0x19192b1919080808
,
0x19192b2b08080808
,
0x19192b2b08192b19
,
0x19192b2b2b081919
,
0x19192b2b2b2b2b08
,
0x192b080808080819
,
0x192b080808081908
,
0x192b08080808192b
,
0x192b080808190808
,
0x192b08080819082b
,
0x192b080808191919
,
0x192b080808192b08
,
0x192b0808082b0819
,
0x192b0808082b1908
,
0x192b080819080808
,
0x192b080819081919
,
0x192b080819082b08
,
0x192b080819190819
,
0x192b080819191908
,
0x192b0808192b0808
,
0x192b08082b081908
,
0x192b08082b190808
,
0x192b081908080808
,
0x192b08190808082b
,
0x192b081908081919
,
0x192b081908082b08
,
0x192b081908190819
,
0x192b081908191908
,
0x192b0819082b0808
,
0x192b081919080819
,
0x192b081919081908
,
0x192b081919190808
,
0x192b08192b080808
,
0x192b08192b192b19
,
0x192b082b08081908
,
0x192b082b08190808
,
0x192b082b19080808
,
0x192b082b1919192b
,
0x192b082b2b2b0819
,
0x192b190808080808
,
0x192b190808081919
,
0x192b190808082b08
,
0x192b190808190819
,
0x192b190808191908
,
0x192b1908082b0808
,
0x192b190819080819
,
0x192b190819081908
,
0x192b190819190808
,
0x192b19082b080808
,
0x192b191908080819
,
0x192b191908081908
,
0x192b191908190808
,
0x192b191919080808
,
0x192b191919082b2b
,
0x192b1919192b2b08
,
0x192b19192b19082b
,
0x192b192b08080808
,
0x192b192b2b191908
,
0x192b2b0808080819
,
0x192b2b0808081908
,
0x192b2b0808190808
,
0x192b2b08192b1919
,
0x192b2b082b192b08
,
0x192b2b1908080808
,
0x192b2b19082b2b2b
,
0x192b2b2b1908082b
,
0x192b2b2b2b2b0819
,
0x2b08080808080808
,
0x2b0808080808082b
,
0x2b08080808081919
,
0x2b08080808082b08
,
0x2b08080808190819
,
0x2b08080808191908
,
0x2b08080808192b19
,
0x2b080808082b0808
,
0x2b080808082b1919
,
0x2b08080819080819
,
0x2b08080819081908
,
0x2b08080819190808
,
0x2b0808081919082b
,
0x2b08080819191919
,
0x2b08080819192b08
,
0x2b080808192b0819
,
0x2b0808082b080808
,
0x2b0808082b081919
,
0x2b0808082b190819
,
0x2b0808082b191908
,
0x2b08081908080819
,
0x2b08081908081908
,
0x2b08081908082b19
,
0x2b08081908190808
,
0x2b0808190819082b
,
0x2b08081908191919
,
0x2b08081908192b08
,
0x2b080819082b0819
,
0x2b080819082b1908
,
0x2b08081919080808
,
0x2b0808191908082b
,
0x2b08081919081919
,
0x2b08081919082b08
,
0x2b08081919190819
,
0x2b08081919191908
,
0x2b0808192b080819
,
0x2b0808192b081908
,
0x2b0808192b190808
,
0x2b0808192b2b2b19
,
0x2b08082b08080808
,
0x2b08082b08081919
,
0x2b08082b08082b2b
,
0x2b08082b08190819
,
0x2b08082b08191908
,
0x2b08082b19080819
,
0x2b08082b19081908
,
0x2b08082b19190808
,
0x2b08190808080819
,
0x2b08190808081908
,
0x2b0819080808192b
,
0x2b08190808082b19
,
0x2b08190808190808
,
0x2b0819080819082b
,
0x2b08190808191919
,
0x2b08190808192b08
,
0x2b081908082b0819
,
0x2b08190819080808
,
0x2b0819081908082b
,
0x2b08190819081919
,
0x2b08190819082b08
,
0x2b08190819190819
,
0x2b08190819191908
,
0x2b081908192b0808
,
0x2b0819082b080819
,
0x2b0819082b081908
,
0x2b0819082b190808
,
0x2b08191908080808
,
0x2b0819190808082b
,
0x2b08191908081919
,
0x2b08191908082b08
,
0x2b08191908190819
,
0x2b08191908191908
,
0x2b081919082b0808
,
0x2b08191919080819
,
0x2b08191919081908
,
0x2b08191919190808
,
0x2b0819192b080808
,
0x2b0819192b082b2b
,
0x2b08192b08080819
,
0x2b08192b08081908
,
0x2b08192b08190808
,
0x2b08192b082b2b19
,
0x2b08192b19080808
,
0x2b082b0808080808
,
0x2b082b0808081919
,
0x2b082b0808190819
,
0x2b082b0808191908
,
0x2b082b0819080819
,
0x2b082b0819081908
,
0x2b082b0819190808
,
0x2b082b082b2b082b
,
0x2b082b1908080819
,
0x2b082b1908081908
,
0x2b082b1919080808
,
0x2b082b19192b1919
,
0x2b082b2b082b082b
,
0x2b082b2b19192b08
,
0x2b082b2b19192b2b
,
0x2b082b2b2b08082b
,
0x2b082b2b2b2b082b
,
0x2b19080808080819
,
0x2b19080808081908
,
0x2b19080808082b19
,
0x2b19080808190808
,
0x2b1908080819082b
,
0x2b19080808191919
,
0x2b19080808192b08
,
0x2b190808082b1908
,
0x2b19080819080808
,
0x2b1908081908082b
,
0x2b19080819081919
,
0x2b19080819082b08
,
0x2b19080819190819
,
0x2b19080819191908
,
0x2b190808192b0808
,
0x2b1908082b080819
,
0x2b1908082b081908
,
0x2b1908082b190808
,
0x2b19081908080808
,
0x2b19081908081919
,
0x2b19081908190819
,
0x2b19081908191908
,
0x2b19081919080819
,
0x2b19081919081908
,
0x2b19081919190808
,
0x2b19081919192b2b
,
0x2b19082b08080819
,
0x2b19082b08081908
,
0x2b19082b08190808
,
0x2b19082b19080808
,
0x2b19082b2b2b192b
,
0x2b19190808080808
,
0x2b1919080808082b
,
0x2b19190808081919
,
0x2b19190808082b08
,
0x2b19190808190819
,
0x2b19190808191908
,
0x2b191908082b0808
,
0x2b19190819080819
,
0x2b19190819081908
,
0x2b19190819190808
,
0x2b1919082b080808
,
0x2b1919082b19192b
,
0x2b19191908080819
,
0x2b19191908081908
,
0x2b19191908190808
,
0x2b19191919080808
,
0x2b1919192b192b08
,
0x2b1919192b2b0819
,
0x2b19192b08080808
,
0x2b19192b1908192b
,
0x2b19192b192b1908
,
0x2b192b0808080819
,
0x2b192b0808081908
,
0x2b192b0808190808
,
0x2b192b08082b192b
,
0x2b192b0819080808
,
0x2b192b082b2b2b19
,
0x2b192b1908080808
,
0x2b192b1919082b19
,
0x2b192b191919082b
,
0x2b192b2b2b190808
,
0x2b2b080808080808
,
0x2b2b080808081919
,
0x2b2b080808082b2b
,
0x2b2b080808191908
,
0x2b2b0808082b082b
,
0x2b2b0808082b2b2b
,
0x2b2b080819080819
,
0x2b2b080819081908
,
0x2b2b080819190808
,
0x2b2b08082b2b082b
,
0x2b2b08082b2b2b2b
,
0x2b2b081919080808
,
0x2b2b0819192b1919
,
0x2b2b082b0808082b
,
0x2b2b082b08082b2b
,
0x2b2b082b082b082b
,
0x2b2b082b082b2b08
,
0x2b2b082b082b2b2b
,
0x2b2b082b2b08082b
,
0x2b2b082b2b082b08
,
0x2b2b082b2b082b2b
,
0x2b2b082b2b2b2b08
,
0x2b2b190808080819
,
0x2b2b190808081908
,
0x2b2b190808190808
,
0x2b2b190819080808
,
0x2b2b19082b082b19
,
0x2b2b19082b2b1908
,
0x2b2b191908080808
,
0x2b2b191908192b19
,
0x2b2b192b19190819
,
0x2b2b2b0808082b2b
,
0x2b2b2b08082b2b08
,
0x2b2b2b082b2b082b
,
0x2b2b2b1919191908
,
0x2b2b2b192b08192b
,
0x2b2b2b2b08082b08
,
0x2b2b2b2b08082b2b
,
0x2b2b2b2b082b0808
,
0x2b2b2b2b082b082b
,
0x2b2b2b2b082b2b08
,
0x2b2b2b2b2b082b08
,
0x2b2b2b2b2b2b2b2b
,
};
static
const
__device__
uint32_t
iq3xxs_grid
[
256
]
=
{
0x04040404
,
0x04040414
,
0x04040424
,
0x04040c0c
,
0x04040c1c
,
0x04040c3e
,
0x04041404
,
0x04041414
,
0x04041c0c
,
0x04042414
,
0x04043e1c
,
0x04043e2c
,
0x040c040c
,
0x040c041c
,
0x040c0c04
,
0x040c0c14
,
0x040c140c
,
0x040c142c
,
0x040c1c04
,
0x040c1c14
,
0x040c240c
,
0x040c2c24
,
0x040c3e04
,
0x04140404
,
0x04140414
,
0x04140424
,
0x04140c0c
,
0x04141404
,
0x04141414
,
0x04141c0c
,
0x04141c1c
,
0x04141c3e
,
0x04142c0c
,
0x04142c3e
,
0x04143e2c
,
0x041c040c
,
0x041c043e
,
0x041c0c04
,
0x041c0c14
,
0x041c142c
,
0x041c3e04
,
0x04240c1c
,
0x04241c3e
,
0x04242424
,
0x04242c3e
,
0x04243e1c
,
0x04243e2c
,
0x042c040c
,
0x042c043e
,
0x042c1c14
,
0x042c2c14
,
0x04341c2c
,
0x04343424
,
0x043e0c04
,
0x043e0c24
,
0x043e0c34
,
0x043e241c
,
0x043e340c
,
0x0c04040c
,
0x0c04041c
,
0x0c040c04
,
0x0c040c14
,
0x0c04140c
,
0x0c04141c
,
0x0c041c04
,
0x0c041c14
,
0x0c041c24
,
0x0c04243e
,
0x0c042c04
,
0x0c0c0404
,
0x0c0c0414
,
0x0c0c0c0c
,
0x0c0c1404
,
0x0c0c1414
,
0x0c14040c
,
0x0c14041c
,
0x0c140c04
,
0x0c140c14
,
0x0c14140c
,
0x0c141c04
,
0x0c143e14
,
0x0c1c0404
,
0x0c1c0414
,
0x0c1c1404
,
0x0c1c1c0c
,
0x0c1c2434
,
0x0c1c3434
,
0x0c24040c
,
0x0c24042c
,
0x0c242c04
,
0x0c2c1404
,
0x0c2c1424
,
0x0c2c2434
,
0x0c2c3e0c
,
0x0c34042c
,
0x0c3e1414
,
0x0c3e2404
,
0x14040404
,
0x14040414
,
0x14040c0c
,
0x14040c1c
,
0x14041404
,
0x14041414
,
0x14041434
,
0x14041c0c
,
0x14042414
,
0x140c040c
,
0x140c041c
,
0x140c042c
,
0x140c0c04
,
0x140c0c14
,
0x140c140c
,
0x140c1c04
,
0x140c341c
,
0x140c343e
,
0x140c3e04
,
0x14140404
,
0x14140414
,
0x14140c0c
,
0x14140c3e
,
0x14141404
,
0x14141414
,
0x14141c3e
,
0x14142404
,
0x14142c2c
,
0x141c040c
,
0x141c0c04
,
0x141c0c24
,
0x141c3e04
,
0x141c3e24
,
0x14241c2c
,
0x14242c1c
,
0x142c041c
,
0x142c143e
,
0x142c240c
,
0x142c3e24
,
0x143e040c
,
0x143e041c
,
0x143e0c34
,
0x143e242c
,
0x1c04040c
,
0x1c040c04
,
0x1c040c14
,
0x1c04140c
,
0x1c04141c
,
0x1c042c04
,
0x1c04342c
,
0x1c043e14
,
0x1c0c0404
,
0x1c0c0414
,
0x1c0c1404
,
0x1c0c1c0c
,
0x1c0c2424
,
0x1c0c2434
,
0x1c14040c
,
0x1c14041c
,
0x1c140c04
,
0x1c14142c
,
0x1c142c14
,
0x1c143e14
,
0x1c1c0c0c
,
0x1c1c1c1c
,
0x1c241c04
,
0x1c24243e
,
0x1c243e14
,
0x1c2c0404
,
0x1c2c0434
,
0x1c2c1414
,
0x1c2c2c2c
,
0x1c340c24
,
0x1c341c34
,
0x1c34341c
,
0x1c3e1c1c
,
0x1c3e3404
,
0x24040424
,
0x24040c3e
,
0x24041c2c
,
0x24041c3e
,
0x24042c1c
,
0x24042c3e
,
0x240c3e24
,
0x24141404
,
0x24141c3e
,
0x24142404
,
0x24143404
,
0x24143434
,
0x241c043e
,
0x241c242c
,
0x24240424
,
0x24242c0c
,
0x24243424
,
0x242c142c
,
0x242c241c
,
0x242c3e04
,
0x243e042c
,
0x243e0c04
,
0x243e0c14
,
0x243e1c04
,
0x2c040c14
,
0x2c04240c
,
0x2c043e04
,
0x2c0c0404
,
0x2c0c0434
,
0x2c0c1434
,
0x2c0c2c2c
,
0x2c140c24
,
0x2c141c14
,
0x2c143e14
,
0x2c1c0414
,
0x2c1c2c1c
,
0x2c240c04
,
0x2c24141c
,
0x2c24143e
,
0x2c243e14
,
0x2c2c0414
,
0x2c2c1c0c
,
0x2c342c04
,
0x2c3e1424
,
0x2c3e2414
,
0x34041424
,
0x34042424
,
0x34042434
,
0x34043424
,
0x340c140c
,
0x340c340c
,
0x34140c3e
,
0x34143424
,
0x341c1c04
,
0x341c1c34
,
0x34242424
,
0x342c042c
,
0x342c2c14
,
0x34341c1c
,
0x343e041c
,
0x343e140c
,
0x3e04041c
,
0x3e04042c
,
0x3e04043e
,
0x3e040c04
,
0x3e041c14
,
0x3e042c14
,
0x3e0c1434
,
0x3e0c2404
,
0x3e140c14
,
0x3e14242c
,
0x3e142c14
,
0x3e1c0404
,
0x3e1c0c2c
,
0x3e1c1c1c
,
0x3e1c3404
,
0x3e24140c
,
0x3e24240c
,
0x3e2c0404
,
0x3e2c0414
,
0x3e2c1424
,
0x3e341c04
,
};
static
const
__device__
uint32_t
iq3xs_grid
[
512
]
=
{
0x04040404
,
0x0404040c
,
0x04040414
,
0x0404042c
,
0x0404043e
,
0x04040c04
,
0x04040c0c
,
0x04040c14
,
0x04040c24
,
0x04040c34
,
0x04041404
,
0x0404140c
,
0x0404142c
,
0x04041c1c
,
0x04042404
,
0x04042414
,
0x0404242c
,
0x0404243e
,
0x04042c0c
,
0x04042c1c
,
0x04043404
,
0x04043414
,
0x04043e0c
,
0x04043e24
,
0x04043e3e
,
0x040c0404
,
0x040c040c
,
0x040c0414
,
0x040c0424
,
0x040c0c04
,
0x040c0c0c
,
0x040c0c2c
,
0x040c1404
,
0x040c141c
,
0x040c143e
,
0x040c1c0c
,
0x040c1c2c
,
0x040c2424
,
0x040c340c
,
0x040c342c
,
0x040c3e14
,
0x04140404
,
0x0414040c
,
0x0414042c
,
0x0414043e
,
0x04140c04
,
0x04140c1c
,
0x04140c34
,
0x0414140c
,
0x0414142c
,
0x04141c04
,
0x04141c24
,
0x04142414
,
0x0414242c
,
0x0414243e
,
0x04142c0c
,
0x04142c1c
,
0x04143e04
,
0x04143e1c
,
0x041c041c
,
0x041c0c0c
,
0x041c0c2c
,
0x041c1404
,
0x041c1414
,
0x041c1c0c
,
0x041c1c1c
,
0x041c1c34
,
0x041c2424
,
0x041c2c04
,
0x041c2c14
,
0x041c343e
,
0x041c3e0c
,
0x041c3e2c
,
0x04240404
,
0x04240c1c
,
0x04240c3e
,
0x0424140c
,
0x04241424
,
0x04241c14
,
0x04242404
,
0x0424241c
,
0x04242c0c
,
0x04243e04
,
0x042c0414
,
0x042c0424
,
0x042c1404
,
0x042c1414
,
0x042c1434
,
0x042c1c1c
,
0x042c240c
,
0x042c242c
,
0x042c243e
,
0x042c3434
,
0x042c3e1c
,
0x04340434
,
0x04340c0c
,
0x04340c1c
,
0x04341c0c
,
0x04342c14
,
0x04343e0c
,
0x043e0404
,
0x043e0414
,
0x043e0424
,
0x043e1404
,
0x043e1414
,
0x043e1434
,
0x043e1c1c
,
0x043e2c04
,
0x043e2c24
,
0x0c040404
,
0x0c04040c
,
0x0c040414
,
0x0c040424
,
0x0c040c04
,
0x0c040c0c
,
0x0c040c1c
,
0x0c040c2c
,
0x0c040c3e
,
0x0c041404
,
0x0c041414
,
0x0c041c0c
,
0x0c041c24
,
0x0c041c34
,
0x0c042c24
,
0x0c042c34
,
0x0c04340c
,
0x0c043e14
,
0x0c0c0404
,
0x0c0c040c
,
0x0c0c041c
,
0x0c0c0434
,
0x0c0c0c04
,
0x0c0c0c24
,
0x0c0c140c
,
0x0c0c1c04
,
0x0c0c1c1c
,
0x0c0c240c
,
0x0c0c2c04
,
0x0c0c2c14
,
0x0c0c3e04
,
0x0c0c3e34
,
0x0c140404
,
0x0c140c14
,
0x0c140c2c
,
0x0c140c3e
,
0x0c141404
,
0x0c141424
,
0x0c141c14
,
0x0c142404
,
0x0c14241c
,
0x0c142c2c
,
0x0c143404
,
0x0c143e14
,
0x0c1c040c
,
0x0c1c0424
,
0x0c1c043e
,
0x0c1c0c04
,
0x0c1c0c1c
,
0x0c1c140c
,
0x0c1c143e
,
0x0c1c1c04
,
0x0c1c1c24
,
0x0c1c240c
,
0x0c1c3414
,
0x0c1c3e04
,
0x0c24041c
,
0x0c24042c
,
0x0c240c14
,
0x0c240c24
,
0x0c241c0c
,
0x0c241c1c
,
0x0c242414
,
0x0c242434
,
0x0c242c04
,
0x0c242c24
,
0x0c2c040c
,
0x0c2c0c04
,
0x0c2c0c1c
,
0x0c2c140c
,
0x0c2c1c04
,
0x0c2c1c14
,
0x0c2c2c0c
,
0x0c341404
,
0x0c341424
,
0x0c34143e
,
0x0c342424
,
0x0c342434
,
0x0c3e040c
,
0x0c3e041c
,
0x0c3e0c04
,
0x0c3e0c14
,
0x0c3e140c
,
0x0c3e1c2c
,
0x0c3e240c
,
0x0c3e3414
,
0x0c3e3e04
,
0x14040404
,
0x1404040c
,
0x1404041c
,
0x1404042c
,
0x1404043e
,
0x14040c04
,
0x14040c14
,
0x14040c24
,
0x14040c34
,
0x1404140c
,
0x1404141c
,
0x1404143e
,
0x14041c04
,
0x14041c14
,
0x1404240c
,
0x1404241c
,
0x1404242c
,
0x14042c04
,
0x14042c14
,
0x1404343e
,
0x14043e04
,
0x14043e1c
,
0x14043e2c
,
0x140c0404
,
0x140c0414
,
0x140c0c04
,
0x140c0c1c
,
0x140c0c3e
,
0x140c1414
,
0x140c142c
,
0x140c1c0c
,
0x140c1c24
,
0x140c2414
,
0x140c2c0c
,
0x1414040c
,
0x14140424
,
0x1414043e
,
0x1414140c
,
0x1414141c
,
0x14141c04
,
0x14141c3e
,
0x1414240c
,
0x14142c1c
,
0x14142c3e
,
0x14143e0c
,
0x14143e24
,
0x141c0404
,
0x141c0414
,
0x141c042c
,
0x141c0c0c
,
0x141c1414
,
0x141c1424
,
0x141c1c0c
,
0x141c1c1c
,
0x141c2414
,
0x141c2c04
,
0x141c3434
,
0x1424040c
,
0x1424043e
,
0x14241404
,
0x1424141c
,
0x14241c14
,
0x14241c2c
,
0x1424240c
,
0x14243e14
,
0x14243e2c
,
0x142c0424
,
0x142c0c0c
,
0x142c1414
,
0x142c1c3e
,
0x142c2404
,
0x142c2c1c
,
0x142c3e04
,
0x14340404
,
0x14340414
,
0x1434043e
,
0x1434140c
,
0x14342c2c
,
0x1434340c
,
0x143e042c
,
0x143e0c0c
,
0x143e1434
,
0x143e1c04
,
0x143e241c
,
0x143e2c04
,
0x1c040414
,
0x1c040c0c
,
0x1c040c1c
,
0x1c040c2c
,
0x1c040c3e
,
0x1c041414
,
0x1c041c0c
,
0x1c041c1c
,
0x1c041c2c
,
0x1c042414
,
0x1c042424
,
0x1c04243e
,
0x1c042c0c
,
0x1c04341c
,
0x1c043e0c
,
0x1c0c040c
,
0x1c0c041c
,
0x1c0c042c
,
0x1c0c0c24
,
0x1c0c140c
,
0x1c0c141c
,
0x1c0c2404
,
0x1c0c3404
,
0x1c0c3e14
,
0x1c0c3e34
,
0x1c140404
,
0x1c140c14
,
0x1c141404
,
0x1c141c14
,
0x1c141c24
,
0x1c142c04
,
0x1c1c040c
,
0x1c1c0c04
,
0x1c1c0c24
,
0x1c1c140c
,
0x1c1c141c
,
0x1c1c143e
,
0x1c1c1c04
,
0x1c1c240c
,
0x1c1c241c
,
0x1c1c243e
,
0x1c1c2c2c
,
0x1c1c3e1c
,
0x1c24041c
,
0x1c240c0c
,
0x1c240c34
,
0x1c241414
,
0x1c241c0c
,
0x1c242c14
,
0x1c243404
,
0x1c243424
,
0x1c2c040c
,
0x1c2c0c04
,
0x1c2c0c14
,
0x1c2c142c
,
0x1c2c1c14
,
0x1c2c2424
,
0x1c2c2c34
,
0x1c2c3e1c
,
0x1c340c34
,
0x1c34240c
,
0x1c3e040c
,
0x1c3e041c
,
0x1c3e1404
,
0x1c3e1414
,
0x1c3e1c2c
,
0x24040404
,
0x24040424
,
0x24040c14
,
0x24041404
,
0x24041424
,
0x2404143e
,
0x24041c14
,
0x2404240c
,
0x24042c04
,
0x24043e04
,
0x240c0414
,
0x240c043e
,
0x240c0c0c
,
0x240c0c1c
,
0x240c1414
,
0x240c1c04
,
0x240c1c2c
,
0x240c241c
,
0x240c2c0c
,
0x240c2c2c
,
0x2414040c
,
0x2414041c
,
0x24140c04
,
0x24140c2c
,
0x2414140c
,
0x24141c1c
,
0x24142404
,
0x24142c3e
,
0x24143414
,
0x24143e04
,
0x241c0424
,
0x241c0c0c
,
0x241c0c1c
,
0x241c1404
,
0x241c1414
,
0x241c1c0c
,
0x241c1c2c
,
0x24240404
,
0x24240414
,
0x24241424
,
0x24241c3e
,
0x24242404
,
0x24243e0c
,
0x242c042c
,
0x242c043e
,
0x242c140c
,
0x242c3414
,
0x24340c1c
,
0x24341c24
,
0x24343404
,
0x243e0c04
,
0x243e0c2c
,
0x243e1c04
,
0x243e241c
,
0x243e2c0c
,
0x2c040414
,
0x2c040c04
,
0x2c040c24
,
0x2c041414
,
0x2c042404
,
0x2c042424
,
0x2c04243e
,
0x2c042c14
,
0x2c043434
,
0x2c043e24
,
0x2c0c040c
,
0x2c0c041c
,
0x2c0c042c
,
0x2c0c0c14
,
0x2c0c140c
,
0x2c0c1c14
,
0x2c0c3e14
,
0x2c140404
,
0x2c140c0c
,
0x2c14141c
,
0x2c141c04
,
0x2c141c34
,
0x2c142c1c
,
0x2c1c0414
,
0x2c1c043e
,
0x2c1c0c04
,
0x2c1c143e
,
0x2c1c2424
,
0x2c1c2c0c
,
0x2c1c342c
,
0x2c1c3e1c
,
0x2c24040c
,
0x2c240424
,
0x2c241404
,
0x2c241c14
,
0x2c242434
,
0x2c2c0c14
,
0x2c2c1434
,
0x2c2c2c0c
,
0x2c2c2c1c
,
0x2c342414
,
0x2c3e0414
,
0x2c3e0424
,
0x2c3e1414
,
0x34040c0c
,
0x34040c1c
,
0x34040c2c
,
0x34041c0c
,
0x34041c1c
,
0x34043404
,
0x340c0404
,
0x340c1404
,
0x340c143e
,
0x340c3424
,
0x34140c14
,
0x34141c24
,
0x34142414
,
0x34142c2c
,
0x34143414
,
0x34143e04
,
0x341c0404
,
0x341c0c24
,
0x341c140c
,
0x341c2404
,
0x3424142c
,
0x3424241c
,
0x34243414
,
0x342c0404
,
0x342c041c
,
0x342c1c24
,
0x342c3404
,
0x3434042c
,
0x34342404
,
0x343e0c0c
,
0x343e0c1c
,
0x3e040404
,
0x3e040424
,
0x3e04043e
,
0x3e041404
,
0x3e041414
,
0x3e041c34
,
0x3e042404
,
0x3e042c24
,
0x3e043414
,
0x3e0c0414
,
0x3e0c0c0c
,
0x3e0c1424
,
0x3e0c241c
,
0x3e0c242c
,
0x3e14040c
,
0x3e140424
,
0x3e140c04
,
0x3e140c34
,
0x3e14140c
,
0x3e141c04
,
0x3e142c0c
,
0x3e1c0414
,
0x3e1c1c14
,
0x3e1c1c2c
,
0x3e1c2c1c
,
0x3e24040c
,
0x3e24042c
,
0x3e240c1c
,
0x3e241404
,
0x3e242c04
,
0x3e2c1414
,
0x3e2c2414
,
0x3e340414
,
0x3e341c0c
,
0x3e3e0404
,
};
#define IQ1S_DELTA 0.125f
#define IQ1M_DELTA 0.125f
static
const
__device__
uint64_t
iq1s_grid_gpu
[
2048
]
=
{
0x00000000
,
0x00000002
,
0x00000101
,
0x00000200
,
0x00000202
,
0x00010001
,
0x00010101
,
0x00020000
,
0x00020002
,
0x00020200
,
0x00020202
,
0x01000101
,
0x01010001
,
0x01010100
,
0x01010102
,
0x01020101
,
0x02000000
,
0x02000002
,
0x02000200
,
0x02000202
,
0x02010101
,
0x02020000
,
0x02020002
,
0x02020200
,
0x02020202
,
0x00000110
,
0x00000111
,
0x00010011
,
0x00010110
,
0x00010112
,
0x00010211
,
0x00010212
,
0x00020111
,
0x01000011
,
0x01000112
,
0x01000211
,
0x01010012
,
0x01010111
,
0x01010212
,
0x01020011
,
0x01020110
,
0x01020112
,
0x01020210
,
0x02000111
,
0x02010011
,
0x02010110
,
0x02010112
,
0x02020111
,
0x00000020
,
0x00000022
,
0x00000220
,
0x00000222
,
0x00010121
,
0x00020020
,
0x00020022
,
0x00020220
,
0x00020222
,
0x01000121
,
0x01010021
,
0x01010221
,
0x01020120
,
0x01020221
,
0x02000020
,
0x02000022
,
0x02000220
,
0x02000222
,
0x02010021
,
0x02010121
,
0x02010221
,
0x02020020
,
0x02020022
,
0x02020220
,
0x02020222
,
0x00011001
,
0x00011100
,
0x00011102
,
0x00021101
,
0x01001001
,
0x01001201
,
0x01011101
,
0x01011202
,
0x01021100
,
0x01021101
,
0x02011001
,
0x02011201
,
0x02021101
,
0x00001011
,
0x00001110
,
0x00001111
,
0x00001112
,
0x00011111
,
0x00011210
,
0x00011212
,
0x00021211
,
0x01001010
,
0x01001111
,
0x01001212
,
0x01011010
,
0x01011011
,
0x01011110
,
0x01011111
,
0x01011112
,
0x01011211
,
0x01021010
,
0x01021012
,
0x01021111
,
0x01021210
,
0x01021212
,
0x02001011
,
0x02011011
,
0x02011111
,
0x02011210
,
0x02011212
,
0x02021011
,
0x02021110
,
0x02021111
,
0x02021112
,
0x02021211
,
0x00011120
,
0x00011221
,
0x01001021
,
0x01001120
,
0x01011020
,
0x01011022
,
0x01011121
,
0x01011220
,
0x01021020
,
0x01021021
,
0x01021122
,
0x01021221
,
0x02001121
,
0x02011021
,
0x02011120
,
0x02011221
,
0x00002000
,
0x00002002
,
0x00002200
,
0x00002202
,
0x00012101
,
0x00022000
,
0x00022002
,
0x00022200
,
0x00022202
,
0x01002101
,
0x01012001
,
0x01012102
,
0x01022101
,
0x02002000
,
0x02002002
,
0x02002200
,
0x02002202
,
0x02012101
,
0x02022000
,
0x02022002
,
0x02022200
,
0x02022202
,
0x00002111
,
0x00012011
,
0x00012110
,
0x00012211
,
0x00022110
,
0x00022111
,
0x01002011
,
0x01012010
,
0x01012011
,
0x01012111
,
0x01022011
,
0x01022110
,
0x01022211
,
0x02012011
,
0x02012110
,
0x02012112
,
0x02012211
,
0x02022111
,
0x00002020
,
0x00002022
,
0x00002220
,
0x00002222
,
0x00012121
,
0x00022020
,
0x00022022
,
0x00022220
,
0x00022222
,
0x01002121
,
0x01012021
,
0x01012221
,
0x01022021
,
0x01022121
,
0x02002020
,
0x02002022
,
0x02002121
,
0x02002220
,
0x02002222
,
0x02012121
,
0x02022020
,
0x02022022
,
0x02022220
,
0x02022222
,
0x00110000
,
0x00110001
,
0x00110100
,
0x00110201
,
0x00120100
,
0x00120101
,
0x01100001
,
0x01100100
,
0x01110000
,
0x01110101
,
0x01110200
,
0x01120001
,
0x01120100
,
0x01120101
,
0x01120201
,
0x02110001
,
0x02110100
,
0x02110102
,
0x02120001
,
0x02120101
,
0x00100011
,
0x00100110
,
0x00100112
,
0x00100211
,
0x00110010
,
0x00110012
,
0x00110111
,
0x00110210
,
0x00120011
,
0x00120110
,
0x00120211
,
0x01100111
,
0x01100212
,
0x01110010
,
0x01110011
,
0x01110012
,
0x01110110
,
0x01110111
,
0x01110112
,
0x01110211
,
0x01120010
,
0x01120111
,
0x02100110
,
0x02110012
,
0x02110111
,
0x02120011
,
0x02120110
,
0x00110021
,
0x00110120
,
0x00110122
,
0x00120121
,
0x01100020
,
0x01100122
,
0x01100221
,
0x01110022
,
0x01110121
,
0x01110220
,
0x01110222
,
0x01120120
,
0x01120122
,
0x02100121
,
0x02110021
,
0x02110120
,
0x02110122
,
0x02120121
,
0x00101001
,
0x00101102
,
0x00101201
,
0x00111100
,
0x00111101
,
0x00111200
,
0x00111201
,
0x00121001
,
0x00121102
,
0x01101001
,
0x01101101
,
0x01101102
,
0x01101200
,
0x01101202
,
0x01111001
,
0x01111100
,
0x01111101
,
0x01111102
,
0x01111201
,
0x01121002
,
0x01121101
,
0x01121200
,
0x02101100
,
0x02101201
,
0x02111000
,
0x02111100
,
0x02111101
,
0x02111200
,
0x02111201
,
0x02111202
,
0x02121001
,
0x02121100
,
0x02121101
,
0x02121201
,
0x00101012
,
0x00101111
,
0x00101212
,
0x00111011
,
0x00111110
,
0x00111111
,
0x00111112
,
0x00111211
,
0x00121010
,
0x00121012
,
0x00121111
,
0x00121210
,
0x00121212
,
0x01101011
,
0x01101110
,
0x01101111
,
0x01101112
,
0x01111011
,
0x01111012
,
0x01111110
,
0x01111111
,
0x01111112
,
0x01111211
,
0x01111212
,
0x01121011
,
0x01121110
,
0x01121111
,
0x01121112
,
0x01121211
,
0x02101010
,
0x02101012
,
0x02101110
,
0x02101111
,
0x02101210
,
0x02101212
,
0x02111010
,
0x02111011
,
0x02111110
,
0x02111111
,
0x02111112
,
0x02111211
,
0x02111212
,
0x02121010
,
0x02121012
,
0x02121111
,
0x00101021
,
0x00101120
,
0x00101121
,
0x00101122
,
0x00111121
,
0x00111122
,
0x00111220
,
0x00111222
,
0x00121021
,
0x00121122
,
0x01101020
,
0x01101022
,
0x01101120
,
0x01101121
,
0x01101220
,
0x01101222
,
0x01111021
,
0x01111121
,
0x01111122
,
0x01111220
,
0x01111221
,
0x01121021
,
0x01121120
,
0x01121121
,
0x01121220
,
0x01121221
,
0x01121222
,
0x02101122
,
0x02101222
,
0x02111022
,
0x02111121
,
0x02121120
,
0x02121221
,
0x00112001
,
0x00112102
,
0x00122101
,
0x01102001
,
0x01102100
,
0x01102102
,
0x01102201
,
0x01112000
,
0x01112101
,
0x01112200
,
0x01112202
,
0x01122000
,
0x01122001
,
0x01122100
,
0x01122102
,
0x01122201
,
0x02102101
,
0x02112001
,
0x02112100
,
0x02122101
,
0x00112010
,
0x00112012
,
0x00112111
,
0x00112212
,
0x00122011
,
0x00122111
,
0x01102012
,
0x01102110
,
0x01102111
,
0x01102210
,
0x01112011
,
0x01112110
,
0x01112111
,
0x01112112
,
0x01112211
,
0x01112212
,
0x01122010
,
0x01122111
,
0x01122212
,
0x02102211
,
0x02112011
,
0x02112012
,
0x02112111
,
0x02112210
,
0x02122011
,
0x02122112
,
0x02122211
,
0x00102221
,
0x00112122
,
0x00122120
,
0x00122122
,
0x01102120
,
0x01102122
,
0x01102221
,
0x01112020
,
0x01112022
,
0x01112121
,
0x01112220
,
0x01122021
,
0x01122122
,
0x01122221
,
0x02102121
,
0x02112021
,
0x02112122
,
0x02112222
,
0x00200000
,
0x00200002
,
0x00200200
,
0x00200202
,
0x00210101
,
0x00220000
,
0x00220002
,
0x00220101
,
0x00220200
,
0x00220202
,
0x01200101
,
0x01210001
,
0x01210201
,
0x01220001
,
0x01220101
,
0x02200000
,
0x02200002
,
0x02200200
,
0x02200202
,
0x02210101
,
0x02220000
,
0x02220002
,
0x02220101
,
0x02220200
,
0x02220202
,
0x00200111
,
0x00210011
,
0x00210110
,
0x00210211
,
0x00220111
,
0x01200012
,
0x01200110
,
0x01200211
,
0x01210111
,
0x01210210
,
0x01210212
,
0x01220011
,
0x01220110
,
0x01220111
,
0x01220112
,
0x02200111
,
0x02210010
,
0x02210112
,
0x02210211
,
0x02220111
,
0x00200021
,
0x00200220
,
0x00200222
,
0x00210021
,
0x00210121
,
0x00220020
,
0x00220022
,
0x00220220
,
0x00220222
,
0x01200121
,
0x01210021
,
0x01210122
,
0x01210221
,
0x01220121
,
0x02200021
,
0x02200220
,
0x02200222
,
0x02210021
,
0x02210121
,
0x02220020
,
0x02220022
,
0x02220220
,
0x02220222
,
0x00201101
,
0x00211100
,
0x00211102
,
0x00211201
,
0x00221101
,
0x01201100
,
0x01201101
,
0x01201102
,
0x01201201
,
0x01211002
,
0x01211101
,
0x01211200
,
0x01211202
,
0x01221102
,
0x02201101
,
0x02211001
,
0x02211100
,
0x02211201
,
0x02221001
,
0x02221101
,
0x00201211
,
0x00211111
,
0x00221011
,
0x00221211
,
0x01201010
,
0x01201111
,
0x01201210
,
0x01211011
,
0x01211110
,
0x01211111
,
0x01211211
,
0x01221012
,
0x01221111
,
0x01221210
,
0x02201211
,
0x02211010
,
0x02211110
,
0x02211111
,
0x02211210
,
0x02211212
,
0x02221011
,
0x02221110
,
0x02221112
,
0x02221211
,
0x00201121
,
0x00211020
,
0x00211022
,
0x00211221
,
0x00221121
,
0x01201021
,
0x01201221
,
0x01211121
,
0x01221020
,
0x01221021
,
0x01221221
,
0x02201120
,
0x02201122
,
0x02211020
,
0x02211222
,
0x00202000
,
0x00202002
,
0x00202200
,
0x00202202
,
0x00212101
,
0x00222000
,
0x00222002
,
0x00222200
,
0x00222202
,
0x01202101
,
0x01212001
,
0x01212100
,
0x01222101
,
0x02202000
,
0x02202002
,
0x02202200
,
0x02202202
,
0x02222000
,
0x02222002
,
0x02222200
,
0x02222202
,
0x00202211
,
0x00212011
,
0x00212110
,
0x00212211
,
0x00222111
,
0x01202112
,
0x01202211
,
0x01212012
,
0x01212111
,
0x01222011
,
0x01222110
,
0x01222112
,
0x01222211
,
0x02202111
,
0x02212010
,
0x02212112
,
0x02212211
,
0x02222110
,
0x02222111
,
0x00202020
,
0x00202022
,
0x00202220
,
0x00202222
,
0x00222020
,
0x00222022
,
0x00222220
,
0x00222222
,
0x01202121
,
0x01212021
,
0x01212122
,
0x01212221
,
0x01222121
,
0x02202020
,
0x02202022
,
0x02202220
,
0x02202222
,
0x02212121
,
0x02222020
,
0x02222022
,
0x02222220
,
0x02222222
,
0x10000101
,
0x10010001
,
0x10010102
,
0x10020101
,
0x11000201
,
0x11010002
,
0x11010101
,
0x11010200
,
0x11010202
,
0x11020001
,
0x11020100
,
0x11020102
,
0x12010100
,
0x12010201
,
0x12020001
,
0x12020102
,
0x10000010
,
0x10000011
,
0x10000110
,
0x10000112
,
0x10000211
,
0x10010012
,
0x10010111
,
0x10010112
,
0x10010210
,
0x10010212
,
0x10020011
,
0x10020112
,
0x10020211
,
0x11000111
,
0x11000210
,
0x11000212
,
0x11010011
,
0x11010110
,
0x11010111
,
0x11010112
,
0x11010211
,
0x11010212
,
0x11020111
,
0x11020210
,
0x11020212
,
0x12000011
,
0x12000110
,
0x12000112
,
0x12010010
,
0x12010012
,
0x12010111
,
0x12020010
,
0x12020011
,
0x12020012
,
0x10000121
,
0x10010021
,
0x10010120
,
0x10010122
,
0x10020121
,
0x11000021
,
0x11010022
,
0x11010121
,
0x11010222
,
0x11020120
,
0x11020221
,
0x12000221
,
0x12010120
,
0x12020121
,
0x10001001
,
0x10011101
,
0x10011201
,
0x10021201
,
0x11001101
,
0x11001200
,
0x11001202
,
0x11011001
,
0x11011100
,
0x11011101
,
0x11011102
,
0x11021001
,
0x11021002
,
0x11021101
,
0x11021200
,
0x11021202
,
0x12001001
,
0x12001102
,
0x12001201
,
0x12011000
,
0x12011002
,
0x12011101
,
0x12021000
,
0x12021001
,
0x12021201
,
0x10001011
,
0x10001012
,
0x10001111
,
0x10001212
,
0x10011011
,
0x10011110
,
0x10011111
,
0x10011112
,
0x10011211
,
0x10021010
,
0x10021111
,
0x10021212
,
0x11001011
,
0x11001110
,
0x11001111
,
0x11001112
,
0x11001211
,
0x11011010
,
0x11011011
,
0x11011110
,
0x11011111
,
0x11011112
,
0x11011210
,
0x11011211
,
0x11021011
,
0x11021110
,
0x11021111
,
0x11021112
,
0x11021211
,
0x12001012
,
0x12001110
,
0x12001111
,
0x12001210
,
0x12011011
,
0x12011110
,
0x12011111
,
0x12011112
,
0x12011211
,
0x12011212
,
0x12021111
,
0x12021210
,
0x12021212
,
0x10001021
,
0x10001121
,
0x10001221
,
0x10011120
,
0x10011121
,
0x10011220
,
0x10011222
,
0x10021021
,
0x10021120
,
0x10021221
,
0x11001020
,
0x11001022
,
0x11001121
,
0x11001220
,
0x11011020
,
0x11011021
,
0x11011022
,
0x11011121
,
0x11011122
,
0x11011221
,
0x11021022
,
0x11021121
,
0x11021220
,
0x12001021
,
0x12001121
,
0x12001222
,
0x12011120
,
0x12011121
,
0x12021021
,
0x12021120
,
0x12021122
,
0x10002101
,
0x10012001
,
0x10012101
,
0x10012202
,
0x10022101
,
0x11002002
,
0x11002201
,
0x11012000
,
0x11012101
,
0x11012200
,
0x11022001
,
0x11022100
,
0x11022102
,
0x11022201
,
0x12002101
,
0x12012001
,
0x12012100
,
0x12012102
,
0x12012201
,
0x12022101
,
0x10002011
,
0x10002111
,
0x10002112
,
0x10002212
,
0x10012010
,
0x10012110
,
0x10012111
,
0x10012210
,
0x10022011
,
0x10022110
,
0x10022112
,
0x11002010
,
0x11002111
,
0x11002212
,
0x11012011
,
0x11012012
,
0x11012110
,
0x11012111
,
0x11012112
,
0x11012211
,
0x11022010
,
0x11022012
,
0x11022111
,
0x11022112
,
0x11022212
,
0x12002112
,
0x12002211
,
0x12012012
,
0x12012111
,
0x12012112
,
0x12012210
,
0x12022011
,
0x12022110
,
0x12022112
,
0x12022211
,
0x10012122
,
0x11002120
,
0x11002122
,
0x11002221
,
0x11012121
,
0x11012220
,
0x11012222
,
0x11022120
,
0x11022221
,
0x12012120
,
0x12022121
,
0x10100001
,
0x10100100
,
0x10100101
,
0x10100102
,
0x10100201
,
0x10110002
,
0x10110101
,
0x10110202
,
0x10120001
,
0x10120100
,
0x10120201
,
0x11100000
,
0x11100101
,
0x11100200
,
0x11110001
,
0x11110100
,
0x11110101
,
0x11110102
,
0x11110201
,
0x11120101
,
0x11120200
,
0x12100102
,
0x12100201
,
0x12110101
,
0x12110200
,
0x12120000
,
0x12120001
,
0x12120102
,
0x12120201
,
0x10100111
,
0x10100210
,
0x10100211
,
0x10100212
,
0x10110011
,
0x10110110
,
0x10110111
,
0x10110112
,
0x10110210
,
0x10110211
,
0x10120010
,
0x10120111
,
0x10120112
,
0x10120210
,
0x10120212
,
0x11100011
,
0x11100110
,
0x11100111
,
0x11100112
,
0x11100211
,
0x11110010
,
0x11110011
,
0x11110012
,
0x11110110
,
0x11110111
,
0x11110112
,
0x11110210
,
0x11110211
,
0x11110212
,
0x11120011
,
0x11120110
,
0x11120111
,
0x11120112
,
0x11120211
,
0x12100012
,
0x12100111
,
0x12110011
,
0x12110110
,
0x12110111
,
0x12110112
,
0x12110211
,
0x12120010
,
0x12120111
,
0x12120212
,
0x10100021
,
0x10100122
,
0x10110022
,
0x10110121
,
0x10110222
,
0x10120021
,
0x10120120
,
0x11100022
,
0x11100121
,
0x11100222
,
0x11110021
,
0x11110120
,
0x11110121
,
0x11110122
,
0x11110221
,
0x11120022
,
0x11120121
,
0x12100121
,
0x12110020
,
0x12110022
,
0x12110121
,
0x12110221
,
0x12110222
,
0x12120120
,
0x10101100
,
0x10101101
,
0x10111001
,
0x10111100
,
0x10111101
,
0x10111102
,
0x10111200
,
0x10111201
,
0x10121001
,
0x10121101
,
0x10121200
,
0x10121202
,
0x11101001
,
0x11101100
,
0x11101101
,
0x11101102
,
0x11101201
,
0x11101202
,
0x11111000
,
0x11111001
,
0x11111100
,
0x11111101
,
0x11111102
,
0x11111200
,
0x11111201
,
0x11111202
,
0x11121001
,
0x11121002
,
0x11121100
,
0x11121101
,
0x11121102
,
0x11121201
,
0x12101000
,
0x12101200
,
0x12101202
,
0x12111001
,
0x12111100
,
0x12111101
,
0x12111102
,
0x12111201
,
0x12121001
,
0x12121100
,
0x12121101
,
0x12121202
,
0x10101011
,
0x10101012
,
0x10101110
,
0x10101111
,
0x10101112
,
0x10101211
,
0x10111010
,
0x10111011
,
0x10111012
,
0x10111110
,
0x10111111
,
0x10111112
,
0x10111211
,
0x10111212
,
0x10121011
,
0x10121110
,
0x10121111
,
0x10121112
,
0x10121211
,
0x11101010
,
0x11101011
,
0x11101012
,
0x11101110
,
0x11101111
,
0x11101112
,
0x11101210
,
0x11101211
,
0x11111010
,
0x11111011
,
0x11111012
,
0x11111110
,
0x11111111
,
0x11111112
,
0x11111210
,
0x11111211
,
0x11111212
,
0x11121010
,
0x11121011
,
0x11121110
,
0x11121111
,
0x11121112
,
0x11121210
,
0x11121211
,
0x11121212
,
0x12101011
,
0x12101110
,
0x12101111
,
0x12101211
,
0x12101212
,
0x12111010
,
0x12111011
,
0x12111110
,
0x12111111
,
0x12111112
,
0x12111210
,
0x12111211
,
0x12121011
,
0x12121110
,
0x12121111
,
0x12121112
,
0x12121211
,
0x10101020
,
0x10101021
,
0x10101022
,
0x10101120
,
0x10101122
,
0x10101220
,
0x10101221
,
0x10111021
,
0x10111120
,
0x10111121
,
0x10111220
,
0x10111221
,
0x10121020
,
0x10121021
,
0x10121022
,
0x10121120
,
0x10121121
,
0x10121122
,
0x10121220
,
0x10121221
,
0x11101021
,
0x11101121
,
0x11101122
,
0x11101220
,
0x11101221
,
0x11101222
,
0x11111020
,
0x11111021
,
0x11111022
,
0x11111120
,
0x11111121
,
0x11111122
,
0x11111220
,
0x11111221
,
0x11111222
,
0x11121021
,
0x11121120
,
0x11121121
,
0x11121221
,
0x12101022
,
0x12101121
,
0x12101122
,
0x12101220
,
0x12101221
,
0x12101222
,
0x12111021
,
0x12111121
,
0x12111222
,
0x12121022
,
0x12121121
,
0x12121122
,
0x12121220
,
0x12121221
,
0x10102100
,
0x10102101
,
0x10102102
,
0x10102201
,
0x10112000
,
0x10112101
,
0x10112200
,
0x10122001
,
0x10122202
,
0x11102101
,
0x11102200
,
0x11102202
,
0x11112001
,
0x11112100
,
0x11112101
,
0x11112102
,
0x11112200
,
0x11112201
,
0x11122000
,
0x11122002
,
0x11122100
,
0x11122101
,
0x12102002
,
0x12102201
,
0x12112000
,
0x12112002
,
0x12112101
,
0x12112200
,
0x12122001
,
0x12122201
,
0x10102011
,
0x10102012
,
0x10102111
,
0x10102212
,
0x10112011
,
0x10112110
,
0x10112111
,
0x10112112
,
0x10112211
,
0x10122111
,
0x11102011
,
0x11102110
,
0x11102111
,
0x11102112
,
0x11102211
,
0x11112010
,
0x11112011
,
0x11112012
,
0x11112110
,
0x11112111
,
0x11112112
,
0x11112210
,
0x11112211
,
0x11112212
,
0x11122011
,
0x11122110
,
0x11122111
,
0x11122112
,
0x11122211
,
0x12102011
,
0x12102111
,
0x12102211
,
0x12112011
,
0x12112110
,
0x12112111
,
0x12112112
,
0x12112210
,
0x12112211
,
0x12122111
,
0x10102120
,
0x10102220
,
0x10112121
,
0x10112222
,
0x10122020
,
0x10122121
,
0x10122122
,
0x10122221
,
0x11102121
,
0x11102220
,
0x11102221
,
0x11112021
,
0x11112121
,
0x11112122
,
0x11112220
,
0x11112221
,
0x11122022
,
0x11122121
,
0x11122220
,
0x11122222
,
0x12102021
,
0x12102222
,
0x12112022
,
0x12112121
,
0x12112122
,
0x12112220
,
0x12112222
,
0x12122021
,
0x10200101
,
0x10210100
,
0x10210102
,
0x10210201
,
0x10220101
,
0x11200100
,
0x11210000
,
0x11210101
,
0x11210102
,
0x11210200
,
0x11210202
,
0x11220001
,
0x11220100
,
0x11220102
,
0x11220201
,
0x12200001
,
0x12210102
,
0x12220101
,
0x10200011
,
0x10200110
,
0x10200112
,
0x10200211
,
0x10210012
,
0x10210111
,
0x10220011
,
0x10220012
,
0x10220112
,
0x10220211
,
0x11200111
,
0x11200211
,
0x11210011
,
0x11210111
,
0x11210112
,
0x11210211
,
0x11220111
,
0x11220112
,
0x11220212
,
0x12200110
,
0x12200212
,
0x12210012
,
0x12210111
,
0x12220011
,
0x12220112
,
0x12220211
,
0x10210021
,
0x10210122
,
0x10210221
,
0x11200020
,
0x11200021
,
0x11200122
,
0x11210121
,
0x11210122
,
0x11210220
,
0x11220020
,
0x12200121
,
0x12210021
,
0x12210122
,
0x12220121
,
0x10211001
,
0x10211002
,
0x10211101
,
0x10211102
,
0x10211202
,
0x10221001
,
0x10221102
,
0x10221201
,
0x11201000
,
0x11201002
,
0x11201101
,
0x11201200
,
0x11201202
,
0x11211001
,
0x11211100
,
0x11211101
,
0x11211102
,
0x11211201
,
0x11211202
,
0x11221000
,
0x11221002
,
0x11221101
,
0x12201100
,
0x12201101
,
0x12201201
,
0x12211000
,
0x12211002
,
0x12211100
,
0x12211101
,
0x12211102
,
0x12211200
,
0x12211202
,
0x12221001
,
0x12221100
,
0x12221201
,
0x10201111
,
0x10201210
,
0x10201212
,
0x10211011
,
0x10211111
,
0x10211112
,
0x10211211
,
0x11201110
,
0x11201111
,
0x11201112
,
0x11201211
,
0x11211010
,
0x11211011
,
0x11211110
,
0x11211111
,
0x11211112
,
0x11211211
,
0x11221011
,
0x11221110
,
0x11221111
,
0x11221112
,
0x11221211
,
0x12201112
,
0x12201211
,
0x12201212
,
0x12211011
,
0x12211111
,
0x12211112
,
0x12211211
,
0x12211212
,
0x12221012
,
0x12221111
,
0x12221112
,
0x12221210
,
0x10201022
,
0x10201221
,
0x10211121
,
0x10221020
,
0x10221122
,
0x10221220
,
0x10221221
,
0x11201020
,
0x11201121
,
0x11201220
,
0x11201222
,
0x11211021
,
0x11211120
,
0x11211121
,
0x11211122
,
0x11211220
,
0x11211222
,
0x11221020
,
0x11221121
,
0x11221220
,
0x12201020
,
0x12201022
,
0x12201121
,
0x12201222
,
0x12211120
,
0x12211122
,
0x12211220
,
0x12211221
,
0x12221020
,
0x12221120
,
0x12221122
,
0x12221222
,
0x10212102
,
0x10212201
,
0x10222101
,
0x11202001
,
0x11212002
,
0x11212101
,
0x11212202
,
0x11222001
,
0x11222201
,
0x12202101
,
0x12212001
,
0x12212200
,
0x12222102
,
0x10202011
,
0x10202110
,
0x10212010
,
0x10212111
,
0x10222011
,
0x10222110
,
0x10222112
,
0x10222211
,
0x11202010
,
0x11202011
,
0x11202111
,
0x11202112
,
0x11202210
,
0x11212011
,
0x11212110
,
0x11212111
,
0x11212112
,
0x11212211
,
0x11222010
,
0x11222111
,
0x11222212
,
0x12202012
,
0x12202110
,
0x12202212
,
0x12212111
,
0x12222011
,
0x12222110
,
0x12222111
,
0x12222211
,
0x10212021
,
0x10212122
,
0x10212220
,
0x11202021
,
0x11202120
,
0x11202221
,
0x11212020
,
0x11212121
,
0x11212220
,
0x11212222
,
0x11222120
,
0x11222121
,
0x11222221
,
0x12202122
,
0x12212120
,
0x12212220
,
0x12212222
,
0x12222122
,
0x20000000
,
0x20000002
,
0x20000200
,
0x20000202
,
0x20020000
,
0x20020002
,
0x20020200
,
0x20020202
,
0x21000101
,
0x21010000
,
0x21010001
,
0x21010100
,
0x21010102
,
0x21010201
,
0x21020101
,
0x22000000
,
0x22000002
,
0x22000200
,
0x22000202
,
0x22010101
,
0x22020000
,
0x22020002
,
0x22020200
,
0x22020202
,
0x20000111
,
0x20010011
,
0x20010110
,
0x20010112
,
0x20010211
,
0x20020111
,
0x21000011
,
0x21000110
,
0x21000211
,
0x21010010
,
0x21010012
,
0x21010111
,
0x21010112
,
0x21010210
,
0x21010211
,
0x21020110
,
0x21020112
,
0x21020211
,
0x22000111
,
0x22000211
,
0x22010110
,
0x22010112
,
0x22010211
,
0x22020111
,
0x20000020
,
0x20000022
,
0x20000220
,
0x20000222
,
0x20010121
,
0x20020020
,
0x20020022
,
0x20020220
,
0x20020222
,
0x21010021
,
0x21010120
,
0x21010221
,
0x21020121
,
0x22000020
,
0x22000022
,
0x22000220
,
0x22000222
,
0x22010121
,
0x22020020
,
0x22020022
,
0x22020220
,
0x22020222
,
0x20011100
,
0x20011201
,
0x21001001
,
0x21001100
,
0x21011001
,
0x21011101
,
0x21011202
,
0x21021001
,
0x21021100
,
0x21021201
,
0x22011100
,
0x22011201
,
0x20001011
,
0x20001211
,
0x20011012
,
0x20011111
,
0x20011212
,
0x20021112
,
0x20021211
,
0x21001010
,
0x21001011
,
0x21001111
,
0x21001210
,
0x21011011
,
0x21011110
,
0x21011111
,
0x21011112
,
0x21011211
,
0x21011212
,
0x21021111
,
0x21021112
,
0x21021210
,
0x21021212
,
0x22001011
,
0x22001110
,
0x22001112
,
0x22001211
,
0x22011010
,
0x22011012
,
0x22011111
,
0x22011210
,
0x22021112
,
0x20011021
,
0x20011122
,
0x20011221
,
0x20021121
,
0x21001021
,
0x21001120
,
0x21001221
,
0x21001222
,
0x21011020
,
0x21011121
,
0x21011221
,
0x21011222
,
0x21021021
,
0x21021122
,
0x21021222
,
0x22001121
,
0x22011021
,
0x22011222
,
0x22021120
,
0x20002000
,
0x20002002
,
0x20002200
,
0x20002202
,
0x20012101
,
0x20022000
,
0x20022002
,
0x20022200
,
0x20022202
,
0x21002001
,
0x21002101
,
0x21012001
,
0x21012100
,
0x21012201
,
0x21022101
,
0x21022201
,
0x22002000
,
0x22002002
,
0x22002200
,
0x22002202
,
0x22012101
,
0x22022000
,
0x22022002
,
0x22022200
,
0x22022202
,
0x20002111
,
0x20002112
,
0x20012011
,
0x20012110
,
0x20012112
,
0x20022111
,
0x21002011
,
0x21002110
,
0x21002112
,
0x21002211
,
0x21012010
,
0x21012012
,
0x21012111
,
0x21012212
,
0x21022011
,
0x21022110
,
0x22002111
,
0x22012112
,
0x22012211
,
0x22022111
,
0x20002020
,
0x20002022
,
0x20002220
,
0x20002222
,
0x20012121
,
0x20022020
,
0x20022022
,
0x20022220
,
0x20022222
,
0x21002121
,
0x21012021
,
0x21012120
,
0x21012122
,
0x22002020
,
0x22002022
,
0x22002220
,
0x22002222
,
0x22012121
,
0x22022020
,
0x22022022
,
0x22022220
,
0x22022222
,
0x20100101
,
0x20110001
,
0x20110102
,
0x20110200
,
0x20110201
,
0x20120101
,
0x21100001
,
0x21100102
,
0x21100201
,
0x21110101
,
0x21110200
,
0x21110202
,
0x21120201
,
0x21120202
,
0x22100101
,
0x22110001
,
0x22110100
,
0x22110102
,
0x22110201
,
0x22120101
,
0x20100011
,
0x20100110
,
0x20100112
,
0x20100211
,
0x20110010
,
0x20110111
,
0x20110210
,
0x20110212
,
0x20120011
,
0x20120110
,
0x20120112
,
0x20120211
,
0x21100010
,
0x21100111
,
0x21110010
,
0x21110011
,
0x21110110
,
0x21110111
,
0x21110112
,
0x21110211
,
0x21120012
,
0x21120111
,
0x22100110
,
0x22100112
,
0x22110012
,
0x22110111
,
0x22110210
,
0x22120011
,
0x22120110
,
0x22120112
,
0x22120211
,
0x20100121
,
0x20110021
,
0x20110120
,
0x20110221
,
0x20120121
,
0x21100120
,
0x21100122
,
0x21100221
,
0x21110020
,
0x21110022
,
0x21110121
,
0x21110220
,
0x21120122
,
0x21120221
,
0x22100121
,
0x22110120
,
0x22110122
,
0x22120221
,
0x20101001
,
0x20101100
,
0x20101102
,
0x20111000
,
0x20111101
,
0x20111200
,
0x20121102
,
0x21101000
,
0x21101202
,
0x21111001
,
0x21111100
,
0x21111101
,
0x21111102
,
0x21111200
,
0x21111201
,
0x21121000
,
0x21121001
,
0x21121002
,
0x21121101
,
0x22101100
,
0x22101102
,
0x22111002
,
0x22111100
,
0x22111101
,
0x22111200
,
0x22121001
,
0x22121201
,
0x20101010
,
0x20101111
,
0x20101210
,
0x20101212
,
0x20111010
,
0x20111011
,
0x20111110
,
0x20111111
,
0x20111112
,
0x20111211
,
0x20121011
,
0x20121111
,
0x20121211
,
0x20121212
,
0x21101011
,
0x21101110
,
0x21101111
,
0x21101112
,
0x21101211
,
0x21111010
,
0x21111011
,
0x21111012
,
0x21111110
,
0x21111111
,
0x21111112
,
0x21111210
,
0x21111211
,
0x21111212
,
0x21121011
,
0x21121110
,
0x21121111
,
0x21121112
,
0x21121211
,
0x22101011
,
0x22101111
,
0x22101210
,
0x22111011
,
0x22111012
,
0x22111110
,
0x22111111
,
0x22111112
,
0x22111211
,
0x22111212
,
0x22121010
,
0x22121012
,
0x22121111
,
0x22121210
,
0x22121212
,
0x20101021
,
0x20101120
,
0x20111020
,
0x20111121
,
0x20111221
,
0x20121020
,
0x20121122
,
0x20121221
,
0x21101121
,
0x21101220
,
0x21101221
,
0x21111021
,
0x21111022
,
0x21111121
,
0x21111122
,
0x21111221
,
0x21121121
,
0x21121220
,
0x22101022
,
0x22101120
,
0x22101221
,
0x22101222
,
0x22111022
,
0x22111120
,
0x22111121
,
0x22121120
,
0x22121122
,
0x22121221
,
0x20102101
,
0x20112102
,
0x20112201
,
0x20122101
,
0x21102001
,
0x21102102
,
0x21112000
,
0x21112002
,
0x21112101
,
0x21112102
,
0x21112202
,
0x21122100
,
0x21122101
,
0x22102101
,
0x22112001
,
0x22112102
,
0x22112201
,
0x22122101
,
0x20102110
,
0x20102112
,
0x20102211
,
0x20112010
,
0x20112012
,
0x20112111
,
0x20112210
,
0x20112212
,
0x20122010
,
0x20122011
,
0x20122110
,
0x20122112
,
0x21102010
,
0x21102012
,
0x21102111
,
0x21102210
,
0x21102212
,
0x21112011
,
0x21112110
,
0x21112111
,
0x21112112
,
0x21112211
,
0x21122012
,
0x21122111
,
0x21122112
,
0x21122212
,
0x22102011
,
0x22102110
,
0x22112010
,
0x22112012
,
0x22112111
,
0x22112212
,
0x22122011
,
0x22122112
,
0x20102121
,
0x20112121
,
0x20122121
,
0x21102120
,
0x21102122
,
0x21102221
,
0x21112020
,
0x21112121
,
0x21112220
,
0x21122021
,
0x22102121
,
0x22112021
,
0x22112120
,
0x22112121
,
0x22112122
,
0x20200000
,
0x20200002
,
0x20200200
,
0x20200202
,
0x20210101
,
0x20220000
,
0x20220002
,
0x20220200
,
0x20220202
,
0x21200101
,
0x21210001
,
0x21210100
,
0x21210102
,
0x21210201
,
0x22200000
,
0x22200002
,
0x22200200
,
0x22200202
,
0x22210101
,
0x22220000
,
0x22220002
,
0x22220200
,
0x22220202
,
0x20200111
,
0x20200211
,
0x20210011
,
0x20210110
,
0x20210112
,
0x20210211
,
0x20210212
,
0x21200112
,
0x21200211
,
0x21210011
,
0x21210111
,
0x21210210
,
0x21210212
,
0x21220011
,
0x21220110
,
0x22200111
,
0x22210010
,
0x22210012
,
0x22210112
,
0x22210211
,
0x20200022
,
0x20200220
,
0x20200222
,
0x20210020
,
0x20210221
,
0x20220022
,
0x20220220
,
0x20220222
,
0x21200121
,
0x21210021
,
0x21210122
,
0x21210221
,
0x21220121
,
0x22200020
,
0x22200022
,
0x22200220
,
0x22200222
,
0x22210121
,
0x22220020
,
0x22220022
,
0x22220220
,
0x22220222
,
0x20211201
,
0x20221101
,
0x21201001
,
0x21201100
,
0x21211000
,
0x21211100
,
0x21211101
,
0x21211200
,
0x21211202
,
0x21221001
,
0x21221101
,
0x21221102
,
0x21221200
,
0x21221201
,
0x22201101
,
0x20201112
,
0x20201211
,
0x20211010
,
0x20211012
,
0x20211111
,
0x20211210
,
0x20221112
,
0x20221211
,
0x21201012
,
0x21201111
,
0x21211011
,
0x21211110
,
0x21211111
,
0x21211112
,
0x21211211
,
0x21221111
,
0x21221212
,
0x22201011
,
0x22201110
,
0x22201111
,
0x22201112
,
0x22201211
,
0x22211012
,
0x22211111
,
0x22211210
,
0x20201121
,
0x20211021
,
0x20211122
,
0x20211222
,
0x20221021
,
0x20221121
,
0x21201120
,
0x21201122
,
0x21201222
,
0x21211022
,
0x21211121
,
0x21211122
,
0x21211220
,
0x21221020
,
0x21221022
,
0x22201122
,
0x22211020
,
0x22211121
,
0x22211122
,
0x22211221
,
0x22221021
,
0x22221120
,
0x22221122
,
0x20202000
,
0x20202002
,
0x20202200
,
0x20202202
,
0x20222000
,
0x20222002
,
0x20222200
,
0x20222202
,
0x21212001
,
0x21212100
,
0x21212102
,
0x21212201
,
0x22202000
,
0x22202002
,
0x22202200
,
0x22202202
,
0x22212101
,
0x22222000
,
0x22222002
,
0x22222200
,
0x22222202
,
0x20202111
,
0x20212110
,
0x20212211
,
0x20222011
,
0x20222111
,
0x21202011
,
0x21212010
,
0x21212111
,
0x21212212
,
0x21222011
,
0x21222112
,
0x21222211
,
0x22212010
,
0x22212112
,
0x20202020
,
0x20202022
,
0x20202220
,
0x20202222
,
0x20222020
,
0x20222022
,
0x20222220
,
0x20222222
,
0x21212021
,
0x21212120
,
0x21212122
,
0x22202020
,
0x22202022
,
0x22202220
,
0x22202222
,
0x22212121
,
0x22222020
,
0x22222022
,
0x22222220
,
0x22222222
,
};
// 128-entry device lookup table indexed by a 7-bit value i.
// Entry i equals i with bit 7 (0x80) additionally set whenever i has an odd
// number of set bits, so every entry has an even population count
// (e.g. 1 -> 129, 3 -> 3, 7 -> 135, 127 -> 255).
// NOTE(review): presumably maps a 7-bit sign index to 8 per-element sign bits
// for the IQ2/IQ3 quant formats -- confirm against the users of this table.
static const __device__ uint8_t ksigns_iq2xs[128] = {
      0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
    144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
    160,  33,  34, 163,  36, 165, 166,  39,  40, 169, 170,  43, 172,  45,  46, 175,
     48, 177, 178,  51, 180,  53,  54, 183, 184,  57,  58, 187,  60, 189, 190,  63,
    192,  65,  66, 195,  68, 197, 198,  71,  72, 201, 202,  75, 204,  77,  78, 207,
     80, 209, 210,  83, 212,  85,  86, 215, 216,  89,  90, 219,  92, 221, 222,  95,
     96, 225, 226,  99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
};
// 128-entry device lookup table of 64-bit byte masks.
// Entry i is ksigns_iq2xs[i] with every bit expanded to a whole byte: byte b of
// the 64-bit word is 0xff when bit b of ksigns_iq2xs[i] is set and 0x00
// otherwise (e.g. i = 1 -> 129 = bits 0 and 7 -> 0xff000000000000ff).
// NOTE(review): presumably used to apply eight per-element signs at once --
// confirm against the users of this table.
static const __device__ uint64_t ksigns64[128] = {
    0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
    0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
    0xff000000ff000000, 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff,
    0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, 0x00000000ffffffff,
    0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff,
    0x000000ff00ff0000, 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff,
    0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff,
    0xff0000ffffff0000, 0x000000ffffff00ff, 0x000000ffffffff00, 0xff0000ffffffffff,
    0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff,
    0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff,
    0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff,
    0xff00ff00ffff0000, 0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff,
    0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, 0x0000ffff0000ffff,
    0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00, 0xff00ffff00ffffff,
    0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff,
    0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff,
    0xffff000000000000, 0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff,
    0x00ff000000ff0000, 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff,
    0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, 0x00ff0000ff00ffff,
    0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff,
    0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff,
    0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff,
    0xffff00ffff000000, 0x00ff00ffff0000ff, 0x00ff00ffff00ff00, 0xffff00ffff00ffff,
    0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00, 0x00ff00ffffffffff,
    0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff,
    0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff,
    0xffffff00ff000000, 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff,
    0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, 0x00ffff00ffffffff,
    0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00, 0xffffffff0000ffff,
    0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff,
    0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
    0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
};
// Single-bit masks: entry b is (1 << b) for b = 0..7, i.e. the mask that
// selects bit b of a byte.
static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
// 16-entry signed lookup table, strictly increasing from -127 to 113 with
// non-uniform spacing.
// NOTE(review): presumably the non-linear codebook that a 4-bit IQ4_NL/IQ4_XS
// index is mapped through during dequantization -- confirm against its users.
static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
// Scalar / paired half-precision types used for dequantized values.
using dfloat = half;  // dequantize float
using dfloat2 = half2;

// Callback that dequantizes into `v` (a pair of values) from the quantized
// buffer `vx`, addressed by block index `ib` and in-block quant index `iqs`.
using dequantize_kernel_t = void (*)(const void* vx, const int ib, const int iqs, dfloat2& v);

// Host-side launcher signature: takes a quantized source `x`, an fp16
// destination `y`, an element count `k`, and the stream to launch on.
using to_fp16_cuda_t = void (*)(const void* __restrict__ x, dfloat* __restrict__ y, int k, cudaStream_t stream);

// Dot product between one quantized weight block (`vbq`) and q8_1 activation
// data (`bq8_1`) at quant index `iqs`; returns the partial sum as float.
using vec_dot_q_cuda_t = float (*)(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
                                   const int& iqs);

// Receives the addresses of the four weight-tile pointers used by the tiled
// matmul kernel (x_ql / x_dm / x_qh / x_sc) so the callee can set them.
using allocate_tiles_cuda_t = void (*)(int** x_ql, half2** x_dm, int** x_qh, int** x_sc);

// Loads weight data from `vx` into the four tile buffers.
using load_tiles_cuda_t = void (*)(const void* __restrict__ vx, int* __restrict__ x_ql, half2* __restrict__ x_dm,
                                   int* __restrict__ x_qh, int* __restrict__ x_sc, const int& i_offset,
                                   const int& i_max, const int& k, const int& blocks_per_row);

// Tile-level dot product: combines the weight tiles (x_*) with the activation
// tiles (y_qs / y_ms) for tile coordinates (i, j) at quant position k.
using vec_dot_q_mul_mat_cuda_t = float (*)(const int* __restrict__ x_ql, const half2* __restrict__ x_dm,
                                           const int* __restrict__ x_qh, const int* __restrict__ x_sc,
                                           const int* __restrict__ y_qs, const half2* __restrict__ y_ms,
                                           const int& i, const int& j, const int& k);
// Utility function
#if defined(USE_ROCM)
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif
// Vector of four int8_t lanes, declared with the Clang ext_vector_type
// attribute; used below to reinterpret a 32-bit int as four signed bytes.
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
// ROCm replacement for the CUDA __vsubss4 intrinsic: treats `a` and `b` as four
// packed signed bytes and returns the packed per-byte difference a - b,
// saturated to the int8_t range.
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
  const int8x4_t lhs = reinterpret_cast<const int8x4_t&>(a);
  const int8x4_t rhs = reinterpret_cast<const int8x4_t&>(b);
#if __has_builtin(__builtin_elementwise_sub_sat)
  // The compiler provides a saturating element-wise subtract: use it directly.
  const int8x4_t diff = __builtin_elementwise_sub_sat(lhs, rhs);
  return reinterpret_cast<const int&>(diff);
#else
  // Manual fallback: subtract in 16 bits so the result cannot wrap, then clamp.
  int8x4_t diff;
#pragma unroll
  for (int lane = 0; lane < 4; ++lane) {
    int16_t wide = lhs[lane] - rhs[lane];
    wide = wide > std::numeric_limits<int8_t>::max() ? std::numeric_limits<int8_t>::max() : wide;
    wide = wide < std::numeric_limits<int8_t>::min() ? std::numeric_limits<int8_t>::min() : wide;
    diff[lane] = wide;
  }
  return reinterpret_cast<int&>(diff);
#endif  // __has_builtin(__builtin_elementwise_sub_sat)
}
// ROCm replacement for the CUDA __dp4a intrinsic: treats `a` and `b` as four
// packed signed bytes, computes their dot product and returns it added to `c`.
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#if __has_builtin(__builtin_amdgcn_sdot4)
  // Hardware dot-product instruction; the trailing `false` is passed through
  // unchanged (presumably the clamp/saturate flag -- confirm in the AMDGPU docs).
  return __builtin_amdgcn_sdot4(a, b, c, false);
#else
  const int8x4_t lhs = reinterpret_cast<const int8x4_t&>(a);
  const int8x4_t rhs = reinterpret_cast<const int8x4_t&>(b);
  // Sum the four lane products first, then fold the sum into the accumulator.
  int dot = 0;
  dot += lhs[0] * rhs[0];
  dot += lhs[1] * rhs[1];
  dot += lhs[2] * rhs[2];
  dot += lhs[3] * rhs[3];
  return c + dot;
#endif
}
#endif // defined(USE_ROCM)
csrc/quantization/gguf/gguf_kernel.cu
0 → 100644
View file @
ad385667
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "ggml-common.h"
#include "vecdotq.cuh"
#include "dequantize.cuh"
#include "mmvq.cuh"
#include "mmq.cuh"
// Q8 gemv
// Quantizes fp16 activations `x` (rows of `kx` values) into q8_1 blocks in `vy`,
// whose rows are `kx_padded` values wide. One thread handles one output value:
// blockIdx.x / threadIdx.x select the column, blockIdx.y / threadIdx.y the row.
// Columns in [kx, kx_padded) are quantized as zeros.
static __global__ void quantize_q8_1(const half* __restrict__ x, void* __restrict__ vy, const int kx,
                                     const int kx_padded) {
  const int col = blockDim.x * blockIdx.x + threadIdx.x;
  if (col >= kx_padded) {
    return;
  }

  const int src_row = blockDim.y * blockIdx.y + threadIdx.y;
  const int flat_idx = src_row * kx_padded + col;

  block_q8_1* blocks = (block_q8_1*)vy;

  const int block_idx = flat_idx / QK8_1;  // which q8_1 block
  const int slot = flat_idx % QK8_1;       // position inside that block

  // Padding columns contribute zero.
  float value = 0.0f;
  if (col < kx) {
    value = __half2float(x[src_row * kx + col]);
  }

  float group_amax = fabsf(value);
  float group_sum = value;

  // Butterfly reduction over 32 lanes: afterwards every lane holds the max
  // magnitude and the sum of the 32 values handled by its warp.
  // NOTE(review): this equals "per q8_1 block" only if QK8_1 == 32 -- confirm.
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    group_amax = fmaxf(group_amax, __shfl_xor_sync(0xffffffff, group_amax, offset, 32));
    group_sum += __shfl_xor_sync(0xffffffff, group_sum, offset, 32);
  }

  // Scale so that the largest magnitude maps to 127.
  const float scale = group_amax / 127;

  // An all-zero group would divide by zero, so store 0 directly in that case.
  int8_t quant = 0;
  if (!(group_amax == 0.0f)) {
    quant = roundf(value / scale);
  }
  blocks[block_idx].qs[slot] = quant;

  // Only the first slot of each block writes the block header (scale, sum).
  if (slot <= 0) {
    blocks[block_idx].ds.x = __float2half(scale);
    blocks[block_idx].ds.y = __float2half(group_sum);
  }
}
// Launches quantize_q8_1 over `ky` rows of `kx` fp16 values each.
//
// x      : fp16 input, read as ky rows of kx values
// vy     : output buffer for the q8_1 blocks; rows are padded to a multiple of
//          512 values, so it must hold ky * kx_padded quantized values
// kx     : number of valid values per row
// ky     : number of rows (becomes the grid's y dimension)
// stream : CUDA stream to launch on
static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx, const int ky, cudaStream_t stream) {
  // Round the row length up to the next multiple of 512. Kept as int because
  // the kernel parameter is int (the original widened to int64_t and then
  // implicitly narrowed again at the launch).
  const int kx_padded = (kx + 512 - 1) / 512 * 512;

  const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
  const dim3 num_blocks(block_num_x, ky, 1);
  // The block size must be the same constant the grid size was derived from;
  // the original used CUDA_DEQUANTIZE_BLOCK_SIZE here, which only covers every
  // column exactly once while the two macros happen to be equal.
  const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
  quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
}
// Dequantizes a GGML-quantized weight tensor into a new fp16 tensor.
//
// W    : quantized weight buffer (passed to the converter as raw bytes)
// type : GGML quantization type id, used to select the converter
// m, n : shape of the dequantized result
//
// Returns a freshly allocated {m, n} fp16 tensor on W's device.
// Raises (TORCH_CHECK) if no converter exists for `type` or if m * n does not
// fit the converter's `int` element count.
torch::Tensor ggml_dequantize(torch::Tensor W,  // quant weight
                              int64_t type, int64_t m, int64_t n) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(W));

  // Resolve the converter before allocating anything, and never call through a
  // null function pointer.
  // NOTE(review): presumably ggml_get_to_fp16_cuda returns nullptr for
  // unsupported types -- confirm in dequantize.cuh.
  const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(type);
  TORCH_CHECK(to_fp16_cuda != nullptr, "ggml_dequantize: unsupported GGML quantization type ", type);

  // The converter takes the element count as int; reject sizes that would be
  // silently truncated by the conversion.
  const int64_t numel = m * n;
  const int k = static_cast<int>(numel);
  TORCH_CHECK(static_cast<int64_t>(k) == numel, "ggml_dequantize: m * n = ", numel,
              " does not fit in a 32-bit element count");

  auto options = torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
  at::Tensor DW = torch::empty({m, n}, options);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

  to_fp16_cuda((void*)W.data_ptr(), (half*)DW.data_ptr(), k, stream);
  return DW;
}
// Matrix-vector product between a GGML-quantized weight and a single fp16
// activation row, with the activation quantized to q8_1 on the fly.
//
// W    : quantized weight buffer
// X    : input; X.sizes()[1] is the vector length and only one row is quantized
//        (data is read as fp16 -- NOTE(review): dtype is not checked here)
// type : GGML quantization type id of W, selects the kernel below
// row  : number of weight rows, i.e. the length of the result
//
// Returns a {1, row} fp16 tensor on W's device.
// Raises (TORCH_CHECK) for a type id without a kernel; previously such a call
// silently returned the uninitialised output buffer.
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W,  // quant weight
                                  torch::Tensor X,  // input
                                  int64_t type, int64_t row) {
  int col = X.sizes()[1];
  // Activation rows are padded to a multiple of 512 values before quantization.
  const int padded = (col + 512 - 1) / 512 * 512;
  const at::cuda::OptionalCUDAGuard device_guard(device_of(X));

  auto options = torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
  at::Tensor Y = torch::empty({1, row}, options);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

  // Scratch buffer for the quantized activations: 9 int32 words per group of 32
  // values (presumably sizeof(block_q8_1) == 36 bytes -- confirm).
  options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
  at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options);
  quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col, 1, stream);

  // Every kernel takes the same arguments; fetch the raw pointers once.
  void* w_ptr = (void*)W.data_ptr();
  void* xq_ptr = (void*)quant_X.data_ptr();
  half* y_ptr = (half*)Y.data_ptr();

  switch (type) {
    case 2:
      mul_mat_vec_q4_0_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    case 3:
      mul_mat_vec_q4_1_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    case 6:
      mul_mat_vec_q5_0_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    case 7:
      mul_mat_vec_q5_1_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    case 8:
      mul_mat_vec_q8_0_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    case 10:
      mul_mat_vec_q2_K_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    case 11:
      mul_mat_vec_q3_K_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    case 12:
      mul_mat_vec_q4_K_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    case 13:
      mul_mat_vec_q5_K_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    case 14:
      mul_mat_vec_q6_K_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    case 16:
      mul_mat_vec_iq2_xxs_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    case 17:
      mul_mat_vec_iq2_xs_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    case 18:
      mul_mat_vec_iq3_xxs_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    case 19:
      mul_mat_vec_iq1_s_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    case 20:
      mul_mat_vec_iq4_nl_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    case 21:
      mul_mat_vec_iq3_s_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    case 22:
      mul_mat_vec_iq2_s_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    case 23:
      mul_mat_vec_iq4_xs_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    case 29:
      mul_mat_vec_iq1_m_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, stream);
      break;
    default:
      // No kernel ran, so Y still holds whatever torch::empty left in it:
      // fail loudly instead of returning garbage.
      TORCH_CHECK(false, "ggml_mul_mat_vec_a8: unsupported GGML quantization type ", type);
  }
  return Y;
}
// Batched matrix product between a GGML-quantized weight and fp16 activations,
// with the activations quantized to q8_1 on the fly.
//
// W    : quantized weight buffer
// X    : input; X.sizes()[0] is the batch size and X.sizes()[1] the row length
//        (data is read as fp16 -- NOTE(review): dtype is not checked here)
// type : GGML quantization type id of W, selects the kernel below
// row  : number of weight rows, i.e. the width of the result
//
// Returns a {batch, row} fp16 tensor on W's device.
// Raises (TORCH_CHECK) for a type id without a kernel; previously such a call
// silently returned the uninitialised output buffer.
torch::Tensor ggml_mul_mat_a8(torch::Tensor W,  // quant weight
                              torch::Tensor X,  // input
                              int64_t type, int64_t row) {
  int col = X.sizes()[1];
  // Activation rows are padded to a multiple of 512 values before quantization.
  int padded = (col + 512 - 1) / 512 * 512;
  int batch = X.sizes()[0];
  const at::cuda::OptionalCUDAGuard device_guard(device_of(X));

  auto options = torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
  at::Tensor Y = torch::empty({batch, row}, options);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

  // Scratch buffer for the quantized activations: 9 int32 words per group of 32
  // values (presumably sizeof(block_q8_1) == 36 bytes -- confirm).
  options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
  at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options);
  quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col, batch, stream);

  // Every kernel takes the same arguments; fetch the raw pointers once.
  void* w_ptr = (void*)W.data_ptr();
  void* xq_ptr = (void*)quant_X.data_ptr();
  half* y_ptr = (half*)Y.data_ptr();

  // Launcher arguments are (ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst) =
  // (col, row, batch, padded, row).
  switch (type) {
    case 2:
      ggml_mul_mat_q4_0_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, batch, padded, row, stream);
      break;
    case 3:
      ggml_mul_mat_q4_1_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, batch, padded, row, stream);
      break;
    case 6:
      ggml_mul_mat_q5_0_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, batch, padded, row, stream);
      break;
    case 7:
      ggml_mul_mat_q5_1_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, batch, padded, row, stream);
      break;
    case 8:
      ggml_mul_mat_q8_0_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, batch, padded, row, stream);
      break;
    case 10:
      ggml_mul_mat_q2_K_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, batch, padded, row, stream);
      break;
    case 11:
      ggml_mul_mat_q3_K_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, batch, padded, row, stream);
      break;
    case 12:
      ggml_mul_mat_q4_K_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, batch, padded, row, stream);
      break;
    case 13:
      ggml_mul_mat_q5_K_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, batch, padded, row, stream);
      break;
    case 14:
      ggml_mul_mat_q6_K_q8_1_cuda(w_ptr, xq_ptr, y_ptr, col, row, batch, padded, row, stream);
      break;
    default:
      // No kernel ran, so Y still holds whatever torch::empty left in it:
      // fail loudly instead of returning garbage.
      TORCH_CHECK(false, "ggml_mul_mat_a8: unsupported GGML quantization type ", type);
  }
  return Y;
}
csrc/quantization/gguf/mmq.cuh
0 → 100644
View file @
ad385667
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu
// Generic tiled quantized matrix multiplication body, shared by the per-format
// mul_mat_q* kernels.
//
// Template parameters:
//   qk, qr, qi     : per-format constants; qk divides ncols_x to give the number
//                    of weight blocks per row, qi sets how many blocks one warp
//                    covers per step, qr is the number of inner passes per step
//   need_sum       : true  -> the activation tile keeps the full half2 header,
//                    false -> only its low half is kept, pre-converted to float
//   block_q_t      : weight block type that `vx` is cast to
//   mmq_x, mmq_y   : destination tile size handled by one thread block
//                    (mmq_x columns by mmq_y rows)
//   nwarps         : stride of threadIdx.y over the tile columns
//   allocate_tiles : callback given the addresses of the four weight-tile pointers
//   load_tiles     : callback that loads weight data into those tiles
//   vdr            : step of the innermost k loop
//   vec_dot        : callback computing one tile-level dot product
//
// Arguments:
//   vx  : quantized weights, read as block_q_t
//   vy  : quantized activations, read as block_q8_1
//   dst : fp16 output, written at dst[col * nrows_dst + row]
//   ncols_x / nrows_x / ncols_y / nrows_y / nrows_dst : matrix extents
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
          allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr,
          vec_dot_q_mul_mat_cuda_t vec_dot>
static __device__ __forceinline__ void mul_mat_q(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    const int blocks_per_row_x = ncols_x / qk;
    const int blocks_per_col_y = nrows_y / QK8_1;
    const int blocks_per_warp = WARP_SIZE / qi;

    const int & ncols_dst = ncols_y;

    // First destination row / column owned by this thread block; the weight
    // rows and activation columns use the same origins.
    const int row_dst_0 = blockIdx.x*mmq_y;
    const int & row_x_0 = row_dst_0;

    const int col_dst_0 = blockIdx.y*mmq_x;
    const int & col_y_0 = col_dst_0;

    // Weight-tile pointers, filled in by the format-specific allocate_tiles
    // (presumably pointing at __shared__ arrays declared there -- confirm).
    int   * tile_x_ql = nullptr;
    half2 * tile_x_dm = nullptr;
    int   * tile_x_qh = nullptr;
    int   * tile_x_sc = nullptr;

    allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);

    // Shared-memory activation tile: quantized values and per-block headers.
    __shared__ int    tile_y_qs[mmq_x * WARP_SIZE];
    __shared__ half2  tile_y_ds[mmq_x * WARP_SIZE/QI8_1];

    // Per-thread accumulators, one per (row group, column group) this thread owns.
    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};

    // Walk along the shared dimension, blocks_per_warp weight blocks at a time.
    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {

        load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
                   threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x);

#pragma unroll
        for (int ir = 0; ir < qr; ++ir) {
            const int kqs = ir*WARP_SIZE + threadIdx.x;
            const int kbxd = kqs / QI8_1;

            // Stage the quantized activation values for this pass.
#pragma unroll
            for (int i = 0; i < mmq_x; i += nwarps) {
                // Clamp the column so threads past the last column re-read it
                // instead of accessing memory out of bounds.
                const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses

                const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];

                const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE;
                tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
            }

            // Stage the matching activation block headers.
#pragma unroll
            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
                const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
                const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
                const int col_y_eff = min(col_y_0 + ids, ncols_y-1);

                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
                const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds;
                half2       * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
                if (need_sum) {
                    *dsi_dst = *dsi_src;
                } else {
                    // Reuse the half2 slot as a float holding only the low half
                    // of the source header.
                    float * dfi_dst = (float *) dsi_dst;
                    *dfi_dst = __low2float(*dsi_src);
                }
            }

            // All tiles must be fully written before any thread reads them.
            __syncthreads();

// #pragma unroll // unrolling this loop causes too much register pressure
            for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
#pragma unroll
                for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
                    for (int i = 0; i < mmq_y; i += WARP_SIZE) {
                        sum[i/WARP_SIZE][j/nwarps] += vec_dot(
                            tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
                            threadIdx.x + i, threadIdx.y + j, k);
                    }
                }
            }

            // Do not overwrite the tiles until every thread has finished reading.
            __syncthreads();
        }
    }

    // Write the accumulators back, skipping anything outside the destination.
#pragma unroll
    for (int j = 0; j < mmq_x; j += nwarps) {
        const int col_dst = col_dst_0 + j + threadIdx.y;

        // col_dst only grows with j, so once it is out of range this thread has
        // nothing left to write.
        if (col_dst >= ncols_dst) {
            return;
        }

#pragma unroll
        for (int i = 0; i < mmq_y; i += WARP_SIZE) {
            const int row_dst = row_dst_0 + threadIdx.x + i;

            if (row_dst >= nrows_dst) {
                continue;
            }

            dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE][j/nwarps]);
        }
    }
}
// Tile configuration for the q4_0 matmul kernel: ROCm builds use 64x128 tiles
// with 8 warps, all other builds use 4x32 tiles with 4 warps.
#if defined(USE_ROCM)
#define MMQ_X_Q4_0 64
#define MMQ_Y_Q4_0 128
#define NWARPS_Q4_0 8
#else
#define MMQ_X_Q4_0 4
#define MMQ_Y_Q4_0 32
#define NWARPS_Q4_0 4
#endif

// Kernel entry point for q4_0 weights times q8_1 activations: instantiates
// mul_mat_q with the q4_0 constants and helpers and forwards all arguments.
// `need_check` is passed on to load_tiles_q4_0 (the launcher sets it when
// nrows_x is not a multiple of the tile height).
template <bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE * NWARPS_Q4_0, 2)
#endif
mul_mat_q4_0(const void* __restrict__ vx, const void* __restrict__ vy, half* __restrict__ dst,
             const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
             const int nrows_dst) {
  constexpr int tile_cols = MMQ_X_Q4_0;
  constexpr int tile_rows = MMQ_Y_Q4_0;
  constexpr int warp_count = NWARPS_Q4_0;

  mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, tile_cols, tile_rows, warp_count,
            allocate_tiles_q4_0<tile_rows>, load_tiles_q4_0<tile_rows, warp_count, need_check>,
            VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>(vx, vy, dst, ncols_x, nrows_x, ncols_y,
                                                           nrows_y, nrows_dst);
}
static
void
ggml_mul_mat_q4_0_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
cudaStream_t
stream
)
{
int
mmq_x
=
MMQ_X_Q4_0
;
int
mmq_y
=
MMQ_Y_Q4_0
;
int
nwarps
=
NWARPS_Q4_0
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
mul_mat_q4_0
<
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
);
}
else
{
const
bool
need_check
=
true
;
mul_mat_q4_0
<
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
);
}
}
// Tile configuration for the q4_1 matmul kernel: ROCm builds use 64x128 tiles
// with 8 warps, all other builds use 4x32 tiles with 4 warps.
#if defined(USE_ROCM)
#define MMQ_X_Q4_1 64
#define MMQ_Y_Q4_1 128
#define NWARPS_Q4_1 8
#else
#define MMQ_X_Q4_1 4
#define MMQ_Y_Q4_1 32
#define NWARPS_Q4_1 4
#endif

// Kernel entry point for q4_1 weights times q8_1 activations: instantiates
// mul_mat_q with the q4_1 constants and helpers and forwards all arguments.
// `need_check` is passed on to load_tiles_q4_1 (the launcher sets it when
// nrows_x is not a multiple of the tile height).
template <bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE * NWARPS_Q4_1, 2)
#endif
mul_mat_q4_1(const void* __restrict__ vx, const void* __restrict__ vy, half* __restrict__ dst,
             const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
             const int nrows_dst) {
  constexpr int tile_cols = MMQ_X_Q4_1;
  constexpr int tile_rows = MMQ_Y_Q4_1;
  constexpr int warp_count = NWARPS_Q4_1;

  mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, tile_cols, tile_rows, warp_count,
            allocate_tiles_q4_1<tile_rows>, load_tiles_q4_1<tile_rows, warp_count, need_check>,
            VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>(vx, vy, dst, ncols_x, nrows_x, ncols_y,
                                                           nrows_y, nrows_dst);
}
// Host-side launcher for the q4_1 x q8_1 tiled matrix multiplication.
//
// Enqueues mul_mat_q4_1 on `stream`. The grid has ceil(nrows_x / MMQ_Y_Q4_1)
// blocks in x and ceil(ncols_y / MMQ_X_Q4_1) blocks in y; every block has
// WARP_SIZE x NWARPS_Q4_1 threads. When nrows_x is an exact multiple of
// MMQ_Y_Q4_1 the need_check=false instantiation is launched, otherwise the
// need_check=true one. All pointer and size arguments are passed to the kernel
// unchanged.
static void ggml_mul_mat_q4_1_q8_1_cuda(
    const void * vx, const void * vy, half * dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst, cudaStream_t stream) {
    // Tile configuration; constant like in the sibling launchers.
    constexpr int mmq_x = MMQ_X_Q4_1;
    constexpr int mmq_y = MMQ_Y_Q4_1;
    constexpr int nwarps = NWARPS_Q4_1;

    // Ceil-division so that partially filled tiles still get a block.
    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);

    if (nrows_x % mmq_y == 0) {
        // Rows fill the tiles exactly.
        mul_mat_q4_1<false><<<block_nums, block_dims, 0, stream>>>(
            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        // Last row tile is only partially filled.
        mul_mat_q4_1<true><<<block_nums, block_dims, 0, stream>>>(
            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}
// Tile configuration for the q5_0 MMQ kernel and its launcher:
// MMQ_X_Q5_0 / MMQ_Y_Q5_0 are the tile extents the launcher divides ncols_y /
// nrows_x by to size the grid, NWARPS_Q5_0 is the block's y-dimension.
// ROCm builds select a different configuration than the default build.
#if defined(USE_ROCM)
#define MMQ_X_Q5_0 64
#define MMQ_Y_Q5_0 128
#define NWARPS_Q5_0 8
#else
#define MMQ_X_Q5_0 4
#define MMQ_Y_Q5_0 32
#define NWARPS_Q5_0 4
#endif
// __global__ entry point for the q5_0 x q8_1 tiled matrix multiplication.
//
// The kernel body is a single call into the generic mul_mat_q template,
// instantiated with the q5_0 block type, the q5_0 tile helpers and the
// MMQ_X_Q5_0 / MMQ_Y_Q5_0 / NWARPS_Q5_0 tile configuration. `need_check` is
// only forwarded to load_tiles_q5_0 (presumably it enables bounds checks for
// partially filled row tiles -- confirm in load_tiles_q5_0). On ROCm builds the
// kernel additionally carries __launch_bounds__(WARP_SIZE * NWARPS_Q5_0, 2).
template <bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE * NWARPS_Q5_0, 2)
#endif
mul_mat_q5_0(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst) {
    // The tile constants are plain macros, so they are used directly as
    // template arguments.
    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0,
              MMQ_X_Q5_0, MMQ_Y_Q5_0, NWARPS_Q5_0,
              allocate_tiles_q5_0<MMQ_Y_Q5_0>,
              load_tiles_q5_0<MMQ_Y_Q5_0, NWARPS_Q5_0, need_check>,
              VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>(
        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
static
void
ggml_mul_mat_q5_0_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
MMQ_X_Q5_0
;
const
int
mmq_y
=
MMQ_Y_Q5_0
;
const
int
nwarps
=
NWARPS_Q5_0
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
mul_mat_q5_0
<
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
);
}
else
{
const
bool
need_check
=
true
;
mul_mat_q5_0
<
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
);
}
}
// Tile configuration for the q5_1 MMQ kernel and its launcher:
// MMQ_X_Q5_1 / MMQ_Y_Q5_1 are the tile extents the launcher divides ncols_y /
// nrows_x by to size the grid, NWARPS_Q5_1 is the block's y-dimension.
// ROCm builds select a different configuration than the default build.
#if defined(USE_ROCM)
#define MMQ_X_Q5_1 64
#define MMQ_Y_Q5_1 128
#define NWARPS_Q5_1 8
#else
#define MMQ_X_Q5_1 4
#define MMQ_Y_Q5_1 32
#define NWARPS_Q5_1 4
#endif
// __global__ entry point for the q5_1 x q8_1 tiled matrix multiplication.
//
// The kernel body is a single call into the generic mul_mat_q template,
// instantiated with the q5_1 block type, the q5_1 tile helpers and the
// MMQ_X_Q5_1 / MMQ_Y_Q5_1 / NWARPS_Q5_1 tile configuration. `need_check` is
// only forwarded to load_tiles_q5_1 (presumably it enables bounds checks for
// partially filled row tiles -- confirm in load_tiles_q5_1). On ROCm builds the
// kernel additionally carries __launch_bounds__(WARP_SIZE * NWARPS_Q5_1, 2).
template <bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE * NWARPS_Q5_1, 2)
#endif
mul_mat_q5_1(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst) {
    // The tile constants are plain macros, so they are used directly as
    // template arguments.
    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1,
              MMQ_X_Q5_1, MMQ_Y_Q5_1, NWARPS_Q5_1,
              allocate_tiles_q5_1<MMQ_Y_Q5_1>,
              load_tiles_q5_1<MMQ_Y_Q5_1, NWARPS_Q5_1, need_check>,
              VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>(
        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
static
void
ggml_mul_mat_q5_1_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
MMQ_X_Q5_1
;
const
int
mmq_y
=
MMQ_Y_Q5_1
;
const
int
nwarps
=
NWARPS_Q5_1
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
mul_mat_q5_1
<
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
);
}
else
{
const
bool
need_check
=
true
;
mul_mat_q5_1
<
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
);
}
}
// Tile configuration for the q8_0 MMQ kernel and its launcher:
// MMQ_X_Q8_0 / MMQ_Y_Q8_0 are the tile extents the launcher divides ncols_y /
// nrows_x by to size the grid, NWARPS_Q8_0 is the block's y-dimension.
// ROCm builds select a different configuration than the default build.
#if defined(USE_ROCM)
#define MMQ_X_Q8_0 64
#define MMQ_Y_Q8_0 128
#define NWARPS_Q8_0 8
#else
#define MMQ_X_Q8_0 4
#define MMQ_Y_Q8_0 32
#define NWARPS_Q8_0 4
#endif
// __global__ entry point for the q8_0 x q8_1 tiled matrix multiplication.
//
// The kernel body is a single call into the generic mul_mat_q template,
// instantiated with the q8_0 block type, the q8_0 tile helpers and the
// MMQ_X_Q8_0 / MMQ_Y_Q8_0 / NWARPS_Q8_0 tile configuration. `need_check` is
// only forwarded to load_tiles_q8_0 (presumably it enables bounds checks for
// partially filled row tiles -- confirm in load_tiles_q8_0). On ROCm builds the
// kernel additionally carries __launch_bounds__(WARP_SIZE * NWARPS_Q8_0, 2).
template <bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE * NWARPS_Q8_0, 2)
#endif
mul_mat_q8_0(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst) {
    // The tile constants are plain macros, so they are used directly as
    // template arguments.
    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0,
              MMQ_X_Q8_0, MMQ_Y_Q8_0, NWARPS_Q8_0,
              allocate_tiles_q8_0<MMQ_Y_Q8_0>,
              load_tiles_q8_0<MMQ_Y_Q8_0, NWARPS_Q8_0, need_check>,
              VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>(
        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
static
void
ggml_mul_mat_q8_0_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
MMQ_X_Q8_0
;
const
int
mmq_y
=
MMQ_Y_Q8_0
;
const
int
nwarps
=
NWARPS_Q8_0
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
mul_mat_q8_0
<
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
);
}
else
{
const
bool
need_check
=
true
;
mul_mat_q8_0
<
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
);
}
}
// Tile configuration for the q2_K MMQ kernel and its launcher:
// MMQ_X_Q2_K / MMQ_Y_Q2_K are the tile extents the launcher divides ncols_y /
// nrows_x by to size the grid, NWARPS_Q2_K is the block's y-dimension.
// ROCm builds select a different configuration than the default build.
#if defined(USE_ROCM)
#define MMQ_X_Q2_K 64
#define MMQ_Y_Q2_K 128
#define NWARPS_Q2_K 8
#else
#define MMQ_X_Q2_K 4
#define MMQ_Y_Q2_K 32
#define NWARPS_Q2_K 4
#endif
// __global__ entry point for the q2_K x q8_1 tiled matrix multiplication.
//
// The kernel body is a single call into the generic mul_mat_q template,
// instantiated with the q2_K block type, the q2_K tile helpers and the
// MMQ_X_Q2_K / MMQ_Y_Q2_K / NWARPS_Q2_K tile configuration. `need_check` is
// only forwarded to load_tiles_q2_K (presumably it enables bounds checks for
// partially filled row tiles -- confirm in load_tiles_q2_K). On ROCm builds the
// kernel additionally carries __launch_bounds__(WARP_SIZE * NWARPS_Q2_K, 2).
template <bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE * NWARPS_Q2_K, 2)
#endif
mul_mat_q2_K(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst) {
    // The tile constants are plain macros, so they are used directly as
    // template arguments.
    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K,
              MMQ_X_Q2_K, MMQ_Y_Q2_K, NWARPS_Q2_K,
              allocate_tiles_q2_K<MMQ_Y_Q2_K>,
              load_tiles_q2_K<MMQ_Y_Q2_K, NWARPS_Q2_K, need_check>,
              VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>(
        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
static
void
ggml_mul_mat_q2_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
MMQ_X_Q2_K
;
const
int
mmq_y
=
MMQ_Y_Q2_K
;
const
int
nwarps
=
NWARPS_Q2_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
mul_mat_q2_K
<
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
);
}
else
{
const
bool
need_check
=
true
;
mul_mat_q2_K
<
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
);
}
}
// Tile configuration for the q3_K MMQ kernel and its launcher:
// MMQ_X_Q3_K / MMQ_Y_Q3_K are the tile extents the launcher divides ncols_y /
// nrows_x by to size the grid, NWARPS_Q3_K is the block's y-dimension.
// ROCm builds select a different configuration than the default build.
#if defined(USE_ROCM)
#define MMQ_X_Q3_K 64
#define MMQ_Y_Q3_K 128
#define NWARPS_Q3_K 8
#else
#define MMQ_X_Q3_K 4
#define MMQ_Y_Q3_K 32
#define NWARPS_Q3_K 4
#endif
// __global__ entry point for the q3_K x q8_1 tiled matrix multiplication.
//
// The kernel body is a single call into the generic mul_mat_q template,
// instantiated with the q3_K block type, the q3_K tile helpers and the
// MMQ_X_Q3_K / MMQ_Y_Q3_K / NWARPS_Q3_K tile configuration. `need_check` is
// only forwarded to load_tiles_q3_K (presumably it enables bounds checks for
// partially filled row tiles -- confirm in load_tiles_q3_K). On ROCm builds the
// kernel additionally carries __launch_bounds__(WARP_SIZE * NWARPS_Q3_K, 2).
template <bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE * NWARPS_Q3_K, 2)
#endif
mul_mat_q3_K(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst) {
    // The tile constants are plain macros, so they are used directly as
    // template arguments.
    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K,
              MMQ_X_Q3_K, MMQ_Y_Q3_K, NWARPS_Q3_K,
              allocate_tiles_q3_K<MMQ_Y_Q3_K>,
              load_tiles_q3_K<MMQ_Y_Q3_K, NWARPS_Q3_K, need_check>,
              VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>(
        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
static
void
ggml_mul_mat_q3_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
MMQ_X_Q3_K
;
const
int
mmq_y
=
MMQ_Y_Q3_K
;
const
int
nwarps
=
NWARPS_Q3_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
mul_mat_q3_K
<
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
);
}
else
{
const
bool
need_check
=
true
;
mul_mat_q3_K
<
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
);
}
}
// Tile configuration for the q4_K MMQ kernel and its launcher:
// MMQ_X_Q4_K / MMQ_Y_Q4_K are the tile extents the launcher divides ncols_y /
// nrows_x by to size the grid, NWARPS_Q4_K is the block's y-dimension.
// ROCm builds select a different configuration than the default build.
#if defined(USE_ROCM)
#define MMQ_X_Q4_K 64
#define MMQ_Y_Q4_K 128
#define NWARPS_Q4_K 8
#else
#define MMQ_X_Q4_K 4
#define MMQ_Y_Q4_K 32
#define NWARPS_Q4_K 4
#endif
// __global__ entry point for the q4_K x q8_1 tiled matrix multiplication.
//
// The kernel body is a single call into the generic mul_mat_q template,
// instantiated with the q4_K block type, the q4_K tile helpers and the
// MMQ_X_Q4_K / MMQ_Y_Q4_K / NWARPS_Q4_K tile configuration. `need_check` is
// only forwarded to load_tiles_q4_K (presumably it enables bounds checks for
// partially filled row tiles -- confirm in load_tiles_q4_K). On ROCm builds the
// kernel additionally carries __launch_bounds__(WARP_SIZE * NWARPS_Q4_K, 2).
template <bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE * NWARPS_Q4_K, 2)
#endif
mul_mat_q4_K(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst) {
    // The tile constants are plain macros, so they are used directly as
    // template arguments.
    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K,
              MMQ_X_Q4_K, MMQ_Y_Q4_K, NWARPS_Q4_K,
              allocate_tiles_q4_K<MMQ_Y_Q4_K>,
              load_tiles_q4_K<MMQ_Y_Q4_K, NWARPS_Q4_K, need_check>,
              VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>(
        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
static
void
ggml_mul_mat_q4_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
MMQ_X_Q4_K
;
const
int
mmq_y
=
MMQ_Y_Q4_K
;
const
int
nwarps
=
NWARPS_Q4_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
mul_mat_q4_K
<
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
);
}
else
{
const
bool
need_check
=
true
;
mul_mat_q4_K
<
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
);
}
}
// Tile configuration for the q5_K MMQ kernel and its launcher:
// MMQ_X_Q5_K / MMQ_Y_Q5_K are the tile extents the launcher divides ncols_y /
// nrows_x by to size the grid, NWARPS_Q5_K is the block's y-dimension.
// ROCm builds select a different configuration than the default build.
#if defined(USE_ROCM)
#define MMQ_X_Q5_K 64
#define MMQ_Y_Q5_K 128
#define NWARPS_Q5_K 8
#else
#define MMQ_X_Q5_K 4
#define MMQ_Y_Q5_K 32
#define NWARPS_Q5_K 4
#endif
// __global__ entry point for the q5_K x q8_1 tiled matrix multiplication.
//
// The kernel body is a single call into the generic mul_mat_q template,
// instantiated with the q5_K block type, the q5_K tile helpers and the
// MMQ_X_Q5_K / MMQ_Y_Q5_K / NWARPS_Q5_K tile configuration. `need_check` is
// only forwarded to load_tiles_q5_K (presumably it enables bounds checks for
// partially filled row tiles -- confirm in load_tiles_q5_K). On ROCm builds the
// kernel additionally carries __launch_bounds__(WARP_SIZE * NWARPS_Q5_K, 2).
template <bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE * NWARPS_Q5_K, 2)
#endif
mul_mat_q5_K(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst) {
    // The tile constants are plain macros, so they are used directly as
    // template arguments.
    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K,
              MMQ_X_Q5_K, MMQ_Y_Q5_K, NWARPS_Q5_K,
              allocate_tiles_q5_K<MMQ_Y_Q5_K>,
              load_tiles_q5_K<MMQ_Y_Q5_K, NWARPS_Q5_K, need_check>,
              VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>(
        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
static
void
ggml_mul_mat_q5_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
MMQ_X_Q5_K
;
const
int
mmq_y
=
MMQ_Y_Q5_K
;
const
int
nwarps
=
NWARPS_Q5_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
mul_mat_q5_K
<
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
);
}
else
{
const
bool
need_check
=
true
;
mul_mat_q5_K
<
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
);
}
}
// Tile configuration for the q6_K MMQ kernel and its launcher:
// MMQ_X_Q6_K / MMQ_Y_Q6_K are the tile extents the launcher divides ncols_y /
// nrows_x by to size the grid, NWARPS_Q6_K is the block's y-dimension.
// ROCm builds select a different configuration than the default build.
#if defined(USE_ROCM)
#define MMQ_X_Q6_K 64
#define MMQ_Y_Q6_K 128
#define NWARPS_Q6_K 8
#else
#define MMQ_X_Q6_K 4
#define MMQ_Y_Q6_K 32
#define NWARPS_Q6_K 4
#endif
// __global__ entry point for the q6_K x q8_1 tiled matrix multiplication.
//
// The kernel body is a single call into the generic mul_mat_q template,
// instantiated with the q6_K block type, the q6_K tile helpers and the
// MMQ_X_Q6_K / MMQ_Y_Q6_K / NWARPS_Q6_K tile configuration. `need_check` is
// only forwarded to load_tiles_q6_K (presumably it enables bounds checks for
// partially filled row tiles -- confirm in load_tiles_q6_K). On ROCm builds the
// kernel additionally carries __launch_bounds__(WARP_SIZE * NWARPS_Q6_K, 2).
template <bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE * NWARPS_Q6_K, 2)
#endif
mul_mat_q6_K(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst) {
    // The tile constants are plain macros, so they are used directly as
    // template arguments.
    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K,
              MMQ_X_Q6_K, MMQ_Y_Q6_K, NWARPS_Q6_K,
              allocate_tiles_q6_K<MMQ_Y_Q6_K>,
              load_tiles_q6_K<MMQ_Y_Q6_K, NWARPS_Q6_K, need_check>,
              VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>(
        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
static
void
ggml_mul_mat_q6_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
MMQ_X_Q6_K
;
const
int
mmq_y
=
MMQ_Y_Q6_K
;
const
int
nwarps
=
NWARPS_Q6_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
ncols_y
+
mmq_x
-
1
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
const
bool
need_check
=
false
;
mul_mat_q6_K
<
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
);
}
else
{
const
bool
need_check
=
true
;
mul_mat_q6_K
<
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
);
}
}
csrc/quantization/gguf/mmvq.cuh
0 → 100644
View file @
ad385667
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
// Generic quantized matrix x q8_1 vector kernel.
//
// Template parameters:
//   qk, qi          - per-format block constants (ncols / qk blocks per row)
//   block_q_t       - quantized block type that `vx` is reinterpreted as
//   vdr             - number of packed ints each vec_dot call handles
//   vec_dot_q_cuda  - per-format dot product of one x block with q8_1 data
//
// Each threadIdx.y slice of a block handles one output row
// (row = blockIdx.x * blockDim.y + threadIdx.y). The threads of that slice
// stride over the row's blocks, accumulate partial sums in float, combine them
// with an XOR shuffle reduction and the thread with threadIdx.x == 0 stores
// the result to dst[row] as half.
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void mul_mat_vec_q(
    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
    const int ncols, const int nrows) {
    const int row = blockIdx.x * blockDim.y + threadIdx.y;
    if (row >= nrows) {
        return;  // this slice lies past the last row
    }

    const block_q_t  * xq = static_cast<const block_q_t  *>(vx);
    const block_q8_1 * yq = static_cast<const block_q8_1 *>(vy);

    const int n_row_blocks = ncols / qk;             // x blocks in one row
    const int block_stride = vdr * WARP_SIZE / qi;   // x blocks covered per iteration

    // x block quant index when casting the quants to int; loop-invariant.
    const int iqs = vdr * (threadIdx.x % (qi / vdr));

    float acc = 0.0f;  // partial sum for this thread
    int blk = threadIdx.x / (qi / vdr);
    while (blk < n_row_blocks) {
        const int ibx = row * n_row_blocks + blk;  // x block index
        const int iby = blk * (qk / QK8_1);        // y block index that aligns with ibx
        acc += vec_dot_q_cuda(&xq[ibx], &yq[iby], iqs);
        blk += block_stride;
    }

    // Sum up the partial sums across 32 lanes.
    // NOTE(review): the lane count (16..1 offsets, width 32) is hard-coded
    // rather than derived from WARP_SIZE -- confirm WARP_SIZE is 32 on every
    // target this is built for.
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        acc += __shfl_xor_sync(0xffffffff, acc, offset, 32);
    }

    if (threadIdx.x == 0) {
        dst[row] = __float2half(acc);
    }
}
static
void
mul_mat_vec_q4_0_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK4_0
,
QI4_0
,
block_q4_0
,
VDR_Q4_0_Q8_1_MMVQ
,
vec_dot_q4_0_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_q4_1_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK4_0
,
QI4_1
,
block_q4_1
,
VDR_Q4_1_Q8_1_MMVQ
,
vec_dot_q4_1_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_q5_0_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK5_0
,
QI5_0
,
block_q5_0
,
VDR_Q5_0_Q8_1_MMVQ
,
vec_dot_q5_0_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_q5_1_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK5_1
,
QI5_1
,
block_q5_1
,
VDR_Q5_1_Q8_1_MMVQ
,
vec_dot_q5_1_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_q8_0_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK8_0
,
QI8_0
,
block_q8_0
,
VDR_Q8_0_Q8_1_MMVQ
,
vec_dot_q8_0_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_q2_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK_K
,
QI2_K
,
block_q2_K
,
VDR_Q2_K_Q8_1_MMVQ
,
vec_dot_q2_K_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_q3_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK_K
,
QI3_K
,
block_q3_K
,
VDR_Q3_K_Q8_1_MMVQ
,
vec_dot_q3_K_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_q4_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK_K
,
QI4_K
,
block_q4_K
,
VDR_Q4_K_Q8_1_MMVQ
,
vec_dot_q4_K_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_q5_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK_K
,
QI5_K
,
block_q5_K
,
VDR_Q5_K_Q8_1_MMVQ
,
vec_dot_q5_K_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_q6_K_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK_K
,
QI6_K
,
block_q6_K
,
VDR_Q6_K_Q8_1_MMVQ
,
vec_dot_q6_K_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_iq2_xxs_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK_K
,
QI2_XXS
,
block_iq2_xxs
,
1
,
vec_dot_iq2_xxs_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_iq2_xs_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK_K
,
QI2_XS
,
block_iq2_xs
,
1
,
vec_dot_iq2_xs_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_iq2_s_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK_K
,
QI2_S
,
block_iq2_s
,
1
,
vec_dot_iq2_s_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_iq3_xxs_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK_K
,
QI3_XXS
,
block_iq3_xxs
,
1
,
vec_dot_iq3_xxs_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_iq1_s_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK_K
,
QI1_S
,
block_iq1_s
,
1
,
vec_dot_iq1_s_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_iq1_m_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK_K
,
QI1_M
,
block_iq1_m
,
1
,
vec_dot_iq1_m_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_iq4_nl_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK4_NL
,
QI4_NL
,
block_iq4_nl
,
VDR_Q4_0_Q8_1_MMVQ
,
vec_dot_iq4_nl_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_iq4_xs_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK_K
,
QI4_XS
,
block_iq4_xs
,
1
,
vec_dot_iq4_xs_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
static
void
mul_mat_vec_iq3_s_q8_1_cuda
(
const
void
*
vx
,
const
void
*
vy
,
half
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
mul_mat_vec_q
<
QK_K
,
QI3_XS
,
block_iq3_s
,
1
,
vec_dot_iq3_s_q8_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
vy
,
dst
,
ncols
,
nrows
);
}
csrc/quantization/gguf/vecdotq.cuh
0 → 100644
View file @
ad385667
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/vecdotq.cuh
// and https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu
// Reads the i32-th 32-bit word of `x` as two 16-bit loads and returns it as an
// int (first half-word in the low 16 bits, second in the high 16 bits).
// Only 2-byte alignment of `x` is assumed.
static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
    // Point at the first of the two half-words that make up word i32.
    const uint16_t * halves = static_cast<const uint16_t *>(x) + 2 * i32;

    int packed = halves[0] << 0;
    packed |= halves[1] << 16;
    return packed;
}
// Returns the i32-th int of `x` with a single 32-bit load.
// At least 4-byte alignment of `x` is assumed.
static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) {
    const int * words = static_cast<const int *>(x);
    return words[i32];
}
// Reads the i32-th group of four int8 values from `x8` as two 16-bit loads and
// returns them packed in an int (first half-word in the low 16 bits).
// Only 2-byte alignment of `x8` is assumed.
static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
    // Byte offset of the requested 4-byte group.
    const int8_t * group = x8 + sizeof(int) * i32;
    const uint16_t * halves = (const uint16_t *) group;

    const int low  = halves[0] << 0;
    const int high = halves[1] << 16;
    return low | high;
}
// Reads the i32-th group of four uint8 values from `x8` as two 16-bit loads and
// returns them packed in an int (first half-word in the low 16 bits).
// Only 2-byte alignment of `x8` is assumed.
static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
    // Byte offset of the requested 4-byte group.
    const uint8_t * group = x8 + sizeof(int) * i32;
    const uint16_t * halves = (const uint16_t *) group;

    const int low  = halves[0] << 0;
    const int high = halves[1] << 16;
    return low | high;
}
// Returns the i32-th group of four int8 values from `x8` with a single 32-bit
// load. At least 4-byte alignment of `x8` is assumed.
static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) {
    const int8_t * group = x8 + sizeof(int) * i32;  // byte offset of the group
    const int * word = (const int *) group;
    return *word;
}
// Returns the i32-th group of four uint8 values from `x8` with a single 32-bit
// load. At least 4-byte alignment of `x8` is assumed.
static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) {
    const uint8_t * group = x8 + sizeof(int) * i32;  // byte offset of the group
    const int * word = (const int *) group;
    return *word;
}
// `vdr` values for q4_0 x q8_1 dot products: the _MMVQ value is the vdr
// template argument of the matrix-vector kernel (mul_mat_vec_q), the _MMQ
// value is the one handed to the tiled kernel (mul_mat_q).
#define VDR_Q4_0_Q8_1_MMVQ 2
#define VDR_Q4_0_Q8_1_MMQ 4
// Partial dot product between packed q4_0 quants and q8_1 quants.
//
//   vdr - number of ints of `v` processed by this call
//   v   - packed 4-bit quants; each int yields a low-nibble and a high-nibble
//         group of four values
//   u   - packed int8 values, two ints consumed per int of `v`
//   d4  - scale applied to the final result
//   ds8 - half2 whose .x scales the integer sum and whose .y feeds the offset
//         correction term
//
// Returns d4 * (sumi * ds8.x - (8 * vdr / QI4_0) * ds8.y).
// The computation relies on __dp4a and is therefore only compiled for
// __CUDA_ARCH__ >= 610; on any other target the function returns 0.0f instead
// of falling off the end of a value-returning function (undefined behaviour).
template <int vdr>
static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
    const int * v, const int * u, const float & d4, const half2 & ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;

        // SIMD dot product of quantized values
        sumi = __dp4a(vi0, u[2 * i + 0], sumi);
        sumi = __dp4a(vi1, u[2 * i + 1], sumi);
    }

    const float2 ds8f = __half22float2(ds8);

    // second part effectively subtracts 8 from each quant value
    return d4 * (sumi * ds8f.x - (8 * vdr / QI4_0) * ds8f.y);
#else
    // No __dp4a on this target: return a defined value rather than garbage.
    return 0.0f;
#endif
}
// `vdr` values for q4_1 x q8_1 dot products: the _MMVQ value is the vdr
// template argument of the matrix-vector kernel (mul_mat_vec_q), the _MMQ
// value is the one handed to the tiled kernel (mul_mat_q).
#define VDR_Q4_1_Q8_1_MMVQ 2
#define VDR_Q4_1_Q8_1_MMQ 4
// Partial dot product between packed q4_1 quants and q8_1 quants.
//
//   vdr - number of ints of `v` processed by this call
//   v   - packed 4-bit quants; each int yields a low-nibble and a high-nibble
//         group of four values
//   u   - packed int8 values, two ints consumed per int of `v`
//   dm4 - half2 multiplied element-wise with ds8
//   ds8 - half2 multiplied element-wise with dm4
//
// With (d4d8, m4s8) = dm4 * ds8, returns
// sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1)).
// The computation relies on __dp4a and is therefore only compiled for
// __CUDA_ARCH__ >= 610; on any other target the function returns 0.0f instead
// of falling off the end of a value-returning function (undefined behaviour).
template <int vdr>
static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl(
    const int * v, const int * u, const half2 & dm4, const half2 & ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;

        // SIMD dot product of quantized values
        sumi = __dp4a(vi0, u[2 * i + 0], sumi);
        sumi = __dp4a(vi1, u[2 * i + 1], sumi);
    }

    const float2 tmp = __half22float2(__hmul2(dm4, ds8));
    const float d4d8 = tmp.x;
    const float m4s8 = tmp.y;

    // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
    return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
#else
    // No __dp4a on this target: return a defined value rather than garbage.
    return 0.0f;
#endif
}
// `vdr` values for q5_0 x q8_1 dot products: the _MMVQ value is the vdr
// template argument of the matrix-vector kernel (mul_mat_vec_q), the _MMQ
// value is the one handed to the tiled kernel (mul_mat_q).
#define VDR_Q5_0_Q8_1_MMVQ 2
#define VDR_Q5_0_Q8_1_MMQ 4
// Partial dot product between packed q5_0 quants and q8_1 quants.
//
//   vdr - number of ints of `vl` / `vh` processed by this call
//   vl  - packed low 4 bits of the quants (low and high nibble groups)
//   vh  - bit field providing the 5th bit of each quant
//   u   - packed int8 values, two ints consumed per int of `vl`
//   d5  - scale applied to the final result
//   ds8 - half2 whose .x scales the integer sum and whose .y feeds the offset
//         correction term
//
// Returns d5 * (sumi * ds8.x - (16 * vdr / QI5_0) * ds8.y).
// The computation relies on __dp4a and is therefore only compiled for
// __CUDA_ARCH__ >= 610; on any other target the function returns 0.0f instead
// of falling off the end of a value-returning function (undefined behaviour).
template <int vdr>
static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(
    const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        int vi0 = (vl[i] >> 0) & 0x0F0F0F0F;  // lower 4 qs bits, still need qh as 5th bits
        vi0 |= (vh[i] << 4) & 0x00000010;     // 0 ->  4
        vi0 |= (vh[i] << 11) & 0x00001000;    // 1 -> 12
        vi0 |= (vh[i] << 18) & 0x00100000;    // 2 -> 20
        vi0 |= (vh[i] << 25) & 0x10000000;    // 3 -> 28
        sumi = __dp4a(vi0, u[2 * i + 0], sumi);  // SIMD dot product of quantized values

        int vi1 = (vl[i] >> 4) & 0x0F0F0F0F;  // upper 4 qs bits, still need qh as 5th bits
        vi1 |= (vh[i] >> 12) & 0x00000010;    // 16 ->  4
        vi1 |= (vh[i] >> 5) & 0x00001000;     // 17 -> 12
        vi1 |= (vh[i] << 2) & 0x00100000;     // 18 -> 20
        vi1 |= (vh[i] << 9) & 0x10000000;     // 19 -> 28
        sumi = __dp4a(vi1, u[2 * i + 1], sumi);  // SIMD dot product of quantized values
    }

    const float2 ds8f = __half22float2(ds8);

    // second part effectively subtracts 16 from each quant value
    return d5 * (sumi * ds8f.x - (16 * vdr / QI5_0) * ds8f.y);
#else
    // No __dp4a on this target: return a defined value rather than garbage.
    return 0.0f;
#endif
}
// `vdr` values for q5_1 x q8_1 dot products: the _MMVQ value is the vdr
// template argument of the matrix-vector kernel (mul_mat_vec_q), the _MMQ
// value is the one handed to the tiled kernel (mul_mat_q).
#define VDR_Q5_1_Q8_1_MMVQ 2
#define VDR_Q5_1_Q8_1_MMQ 4
// Partial dot product between packed q5_1 quants and q8_1 quants.
//
//   vdr - number of ints of `vl` / `vh` processed by this call
//   vl  - packed low 4 bits of the quants (low and high nibble groups)
//   vh  - bit field providing the 5th bit of each quant
//   u   - packed int8 values, two ints consumed per int of `vl`
//   dm5 - half2 multiplied element-wise with ds8
//   ds8 - half2 multiplied element-wise with dm5
//
// With (d5d8, m5s8) = dm5 * ds8, returns sumi * d5d8 + m5s8 / (QI5_1 / vdr).
// The computation relies on __dp4a and is therefore only compiled for
// __CUDA_ARCH__ >= 610; on any other target the function returns 0.0f instead
// of falling off the end of a value-returning function (undefined behaviour).
template <int vdr>
static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl(
    const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        int vi0 = (vl[i] >> 0) & 0x0F0F0F0F;  // lower 4 qs bits, still need qh as 5th bits
        vi0 |= (vh[i] << 4) & 0x00000010;     // 0 ->  4
        vi0 |= (vh[i] << 11) & 0x00001000;    // 1 -> 12
        vi0 |= (vh[i] << 18) & 0x00100000;    // 2 -> 20
        vi0 |= (vh[i] << 25) & 0x10000000;    // 3 -> 28
        sumi = __dp4a(vi0, u[2 * i + 0], sumi);  // SIMD dot product of quantized values

        int vi1 = (vl[i] >> 4) & 0x0F0F0F0F;  // upper 4 qs bits, still need qh as 5th bits
        vi1 |= (vh[i] >> 12) & 0x00000010;    // 16 ->  4
        vi1 |= (vh[i] >> 5) & 0x00001000;     // 17 -> 12
        vi1 |= (vh[i] << 2) & 0x00100000;     // 18 -> 20
        vi1 |= (vh[i] << 9) & 0x10000000;     // 19 -> 28
        sumi = __dp4a(vi1, u[2 * i + 1], sumi);  // SIMD dot product of quantized values
    }

    const float2 tmp = __half22float2(__hmul2(dm5, ds8));
    const float d5d8 = tmp.x;
    const float m5s8 = tmp.y;

    // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
    return sumi * d5d8 + m5s8 / (QI5_1 / vdr);
#else
    // No __dp4a on this target: return a defined value rather than garbage.
    return 0.0f;
#endif
}
#define VDR_Q8_0_Q8_1_MMVQ 2
#define VDR_Q8_0_Q8_1_MMQ 8
// Dot product of `vdr` packed int8 words `v` against `vdr` packed int8 words
// `u`, scaled by both block scales.
//
// Returns d8_0 * d8_1 * sum_i dp4a(v[i], u[i]).
// When the dp4a path is compiled out, 0.0f is returned so that the function
// never flows off the end of a non-void function (undefined behaviour).
template <int vdr>
static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl(
    const int* v, const int* u, const float& d8_0, const float& d8_1) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        // SIMD dot product of quantized values
        sumi = __dp4a(v[i], u[i], sumi);
    }

    return d8_0 * d8_1 * sumi;
#else
    // dp4a path compiled out: return a defined value instead of falling off
    // the end of a non-void function.
    return 0.0f;
#endif
}
// Dot product of `vdr` packed int8 words `v` against `vdr` packed int8 words
// `u`, where both operands carry a half2 pair.
//
// With tmp = dm8 * ds8 (element-wise), returns
//   sumi * tmp.x + tmp.y / (QI8_1 / vdr).
// When the dp4a path is compiled out, 0.0f is returned so that the function
// never flows off the end of a non-void function (undefined behaviour).
template <int vdr>
static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl(
    const int* v, const int* u, const half2& dm8, const half2& ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        // SIMD dot product of quantized values
        sumi = __dp4a(v[i], u[i], sumi);
    }

    const float2 tmp = __half22float2(__hmul2(dm8, ds8));
    const float d8d8 = tmp.x;
    const float m8s8 = tmp.y;

    // scale second part of sum by QI8_1 / vdr to compensate for multiple
    // threads adding it
    return sumi * d8d8 + m8s8 / (QI8_1 / vdr);
#else
    // dp4a path compiled out: return a defined value instead of falling off
    // the end of a non-void function.
    return 0.0f;
#endif
}
#define VDR_Q2_K_Q8_1_MMVQ 1
#define VDR_Q2_K_Q8_1_MMQ 2
// contiguous v/x values
// q2_K x q8_1 dot product (MMVQ variant).
//
//   v      : one int packing 2-bit quants; iteration i uses bits
//            (v >> 2*i) & 0x03 of every byte.
//   u      : QR2_K ints of packed int8 values.
//   scales : bytes read at index 2*i; low nibble is the multiplier for the
//            quant dot product, high nibble is replicated into all four bytes
//            and dotted with u[i].
//   dm2    : half2; .x scales the quant sum, .y scales the high-nibble sum.
//   d8     : QR2_K floats, one per u[i].
//
// Returns dm2.x * sumf_d - dm2.y * sumf_m.
// When the dp4a path is compiled out, 0.0f is returned so that the function
// never flows off the end of a non-void function (undefined behaviour).
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
    const int& v, const int* __restrict__ u, const uint8_t* __restrict__ scales,
    const half2& dm2, const float* __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR2_K; ++i) {
        const int sc = scales[2 * i];
        const int vi = (v >> (2 * i)) & 0x03030303;

        sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF));  // SIMD dot product

        // fill int with 4x m
        int m = sc >> 4;
        m |= m << 8;
        m |= m << 16;
        // multiply constant q2_K part with sum of q8_1 values
        sumf_m += d8[i] * __dp4a(m, u[i], 0);
    }

    const float2 dm2f = __half22float2(dm2);
    return dm2f.x * sumf_d - dm2f.y * sumf_m;
#else
    // dp4a path compiled out: return a defined value instead of falling off
    // the end of a non-void function.
    return 0.0f;
#endif
}
// q2_K x q8_1 dot product (MMQ variant).
//
// Processes QI8_1 ints of `v`/`u` in two halves of QI8_1/2 ints. Half h uses
// scales[h]: its low nibble multiplies the integer dot product of that half,
// its high nibble is replicated into four bytes and dotted with u.
//
// Returns d8 * (dm2.x * sumi_d - dm2.y * sumi_m).
// When the dp4a path is compiled out, 0.0f is returned so that the function
// never flows off the end of a non-void function (undefined behaviour).
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
    const int* __restrict__ v, const int* __restrict__ u,
    const uint8_t* __restrict__ scales, const half2& dm2, const float& d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
    int sumi_d = 0;
    int sumi_m = 0;

#pragma unroll
    for (int i0 = 0; i0 < QI8_1; i0 += QI8_1 / 2) {
        int sumi_d_sc = 0;

        const int sc = scales[i0 / (QI8_1 / 2)];

        // fill int with 4x m
        int m = sc >> 4;
        m |= m << 8;
        m |= m << 16;

#pragma unroll
        for (int i = i0; i < i0 + QI8_1 / 2; ++i) {
            sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc);  // SIMD dot product
            sumi_m = __dp4a(m, u[i], sumi_m);  // multiply sum of q8_1 values with m
        }

        sumi_d += sumi_d_sc * (sc & 0xF);
    }

    const float2 dm2f = __half22float2(dm2);
    return d8 * (dm2f.x * sumi_d - dm2f.y * sumi_m);
#else
    // dp4a path compiled out: return a defined value instead of falling off
    // the end of a non-void function.
    return 0.0f;
#endif
}
#define VDR_Q3_K_Q8_1_MMVQ 1
#define VDR_Q3_K_Q8_1_MMQ 2
// contiguous v/x values
// q3_K x q8_1 dot product (MMVQ variant).
//
//   vl           : int packing the low 2 bits of the quants; iteration i uses
//                  (vl >> 2*i) & 0x03 of every byte.
//   vh           : int whose bit i of every byte (shifted to bit 2) is
//                  subtracted, with saturation, from the low bits.
//   u, d8        : QR3_K packed int8 words and their float scales.
//   scales       : packed 6-bit scales; the low 4 bits come from
//                  scales[isc % (QK_K/32)], the high 2 bits from
//                  scales[QK_K/32 + isc % (QK_K/64)]; 32 is subtracted.
//   scale_offset : base scale index; iteration i uses scale_offset + 2*i.
//   d3           : float scale applied to the whole result.
//
// Returns d3 * sum_i d8[i] * dp4a(vi, u[i]) * sc.
// When the dp4a path is compiled out, 0.0f is returned so that the function
// never flows off the end of a non-void function (undefined behaviour).
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
    const int& vl, const int& vh, const int* __restrict__ u,
    const uint8_t* __restrict__ scales, const int& scale_offset,
    const float& d3, const float* __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
    float sumf = 0.0f;

#pragma unroll
    for (int i = 0; i < QR3_K; ++i) {
        const int isc = scale_offset + 2 * i;

        const int isc_low = isc % (QK_K / 32);
        const int sc_shift_low = 4 * (isc / (QK_K / 32));
        const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF;

        const int isc_high = isc % (QK_K / 64);
        const int sc_shift_high = 2 * (isc / (QK_K / 64));
        const int sc_high = ((scales[(QK_K / 32) + isc_high] >> sc_shift_high) & 3) << 4;

        const int sc = (sc_low | sc_high) - 32;

        const int vil = (vl >> (2 * i)) & 0x03030303;
        const int vih = ((vh >> i) << 2) & 0x04040404;
        const int vi = __vsubss4(vil, vih);

        sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc);  // SIMD dot product
    }

    return d3 * sumf;
#else
    // dp4a path compiled out: return a defined value instead of falling off
    // the end of a non-void function.
    return 0.0f;
#endif
}
// q3_K x q8_1 dot product (MMQ variant).
//
// Walks QR3_K * VDR_Q3_K_Q8_1_MMQ ints of `v`/`u` in groups of QI8_1/2 ints;
// the integer dot product of group g is multiplied by scales[g].
//
// Returns d3 * d8 * sumi.
// When the dp4a path is compiled out, 0.0f is returned so that the function
// never flows off the end of a non-void function (undefined behaviour).
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
    const int* __restrict__ v, const int* __restrict__ u,
    const int8_t* __restrict__ scales, const float& d3, const float& d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
    int sumi = 0;

#pragma unroll
    for (int i0 = 0; i0 < QR3_K * VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1 / 2) {
        int sumi_sc = 0;

        for (int i = i0; i < i0 + QI8_1 / 2; ++i) {
            sumi_sc = __dp4a(v[i], u[i], sumi_sc);  // SIMD dot product
        }

        sumi += sumi_sc * scales[i0 / (QI8_1 / 2)];
    }

    return d3 * d8 * sumi;
#else
    // dp4a path compiled out: return a defined value instead of falling off
    // the end of a non-void function.
    return 0.0f;
#endif
}
#define VDR_Q4_K_Q8_1_MMVQ 2
#define VDR_Q4_K_Q8_1_MMQ 8
// contiguous v/x values
// q4_K x q8_1 dot product (vector variant).
//
//   v     : two ints of packed 4-bit quants; iteration i uses nibble i of
//           every byte of v[0] and v[1].
//   u     : 2*QR4_K packed int8 words; u[2*i] pairs with v[0], u[2*i+1]
//           with v[1].
//   sc, m : per-iteration multipliers for the quant dot product and for the
//           plain sum of the u values, respectively.
//   dm4   : half2; .x scales the quant sum, .y scales the m-weighted sum.
//   d8    : QR4_K floats, one per iteration.
//
// Returns dm4.x * sumf_d - dm4.y * sumf_m.
// When the dp4a path is compiled out, 0.0f is returned so that the function
// never flows off the end of a non-void function (undefined behaviour).
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
    const int* __restrict__ v, const int* __restrict__ u,
    const uint8_t* __restrict__ sc, const uint8_t* __restrict__ m,
    const half2& dm4, const float* __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR4_K; ++i) {
        const int v0i = (v[0] >> (4 * i)) & 0x0F0F0F0F;
        const int v1i = (v[1] >> (4 * i)) & 0x0F0F0F0F;

        // SIMD dot product
        const int dot1 = __dp4a(v1i, u[2 * i + 1], __dp4a(v0i, u[2 * i + 0], 0));
        // sum of u
        const int dot2 = __dp4a(0x01010101, u[2 * i + 1], __dp4a(0x01010101, u[2 * i + 0], 0));

        sumf_d += d8[i] * (dot1 * sc[i]);
        // multiply constant part of q4_K with sum of q8_1 values
        sumf_m += d8[i] * (dot2 * m[i]);
    }

    const float2 dm4f = __half22float2(dm4);
    return dm4f.x * sumf_d - dm4f.y * sumf_m;
#else
    // dp4a path compiled out: return a defined value instead of falling off
    // the end of a non-void function.
    return 0.0f;
#endif
}
// q4_K x q8_1 dot product (MMQ variant).
//
// For each of QR4_K * VDR_Q4_K_Q8_1_MMQ / QI8_1 outer steps i, nibble i of
// every byte of v[0..QI8_1) is dotted with u[i*QI8_1 .. i*QI8_1 + QI8_1).
// ds8[i].x scales that dot product (times sc[i]); ds8[i].y is multiplied by
// m[i] and accumulated separately.
//
// Returns dm4.x * sumf_d - dm4.y * sumf_m.
// When the dp4a path is compiled out, 0.0f is returned so that the function
// never flows off the end of a non-void function (undefined behaviour).
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
    const int* __restrict__ v, const int* __restrict__ u,
    const uint8_t* __restrict__ sc, const uint8_t* __restrict__ m,
    const half2& dm4, const half2* __restrict__ ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR4_K * VDR_Q4_K_Q8_1_MMQ / QI8_1; ++i) {
        int sumi_d = 0;

#pragma unroll
        for (int j = 0; j < QI8_1; ++j) {
            // SIMD dot product
            sumi_d = __dp4a((v[j] >> (4 * i)) & 0x0F0F0F0F, u[i * QI8_1 + j], sumi_d);
        }

        const float2 ds8f = __half22float2(ds8[i]);

        sumf_d += ds8f.x * (sc[i] * sumi_d);
        sumf_m += ds8f.y * m[i];  // sum of q8_1 block * q4_K min val
    }

    const float2 dm4f = __half22float2(dm4);
    return dm4f.x * sumf_d - dm4f.y * sumf_m;
#else
    // dp4a path compiled out: return a defined value instead of falling off
    // the end of a non-void function.
    return 0.0f;
#endif
}
#define VDR_Q5_K_Q8_1_MMVQ 2
#define VDR_Q5_K_Q8_1_MMQ 8
// q5_K x q8_1 dot product (vector variant).
//
//   vl    : two ints of packed low 4-bit quants; iteration i uses nibble i.
//   vh    : two ints of high bits; bit i of every byte becomes bit 4 of the
//           reconstructed quant.
//   u     : 2*QR5_K packed int8 words; u[2*i] pairs with vl[0]/vh[0],
//           u[2*i+1] with vl[1]/vh[1].
//   sc, m : per-iteration multipliers for the quant dot product and for the
//           plain sum of the u values, respectively.
//   dm5   : half2; .x scales the quant sum, .y scales the m-weighted sum.
//   d8    : QR5_K floats, one per iteration.
//
// Returns dm5.x * sumf_d - dm5.y * sumf_m.
// When the dp4a path is compiled out, 0.0f is returned so that the function
// never flows off the end of a non-void function (undefined behaviour).
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
    const int* __restrict__ vl, const int* __restrict__ vh,
    const int* __restrict__ u, const uint8_t* __restrict__ sc,
    const uint8_t* __restrict__ m, const half2& dm5,
    const float* __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR5_K; ++i) {
        const int vl0i = (vl[0] >> (4 * i)) & 0x0F0F0F0F;
        const int vl1i = (vl[1] >> (4 * i)) & 0x0F0F0F0F;

        const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
        const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;

        const int v0i = vl0i | vh0i;
        const int v1i = vl1i | vh1i;

        // SIMD dot product
        const int dot1 = __dp4a(v0i, u[2 * i + 0], __dp4a(v1i, u[2 * i + 1], 0));
        // sum of u
        const int dot2 = __dp4a(0x01010101, u[2 * i + 0], __dp4a(0x01010101, u[2 * i + 1], 0));

        sumf_d += d8[i] * (dot1 * sc[i]);
        sumf_m += d8[i] * (dot2 * m[i]);
    }

    const float2 dm5f = __half22float2(dm5);
    return dm5f.x * sumf_d - dm5f.y * sumf_m;
#else
    // dp4a path compiled out: return a defined value instead of falling off
    // the end of a non-void function.
    return 0.0f;
#endif
}
// q5_K x q8_1 dot product (MMQ variant).
//
// For each of QR5_K * VDR_Q5_K_Q8_1_MMQ / QI8_1 outer steps i, the ints
// v[i*QI8_1 .. i*QI8_1 + QI8_1) are dotted with the matching ints of u.
// ds8[i].x scales that dot product (times sc[i]); ds8[i].y is multiplied by
// m[i] and accumulated separately.
//
// The scale parameter keeps its historical name `dm4`; here it is the half2
// pair supplied for the q5_K operand.
//
// Returns dm4.x * sumf_d - dm4.y * sumf_m.
// When the dp4a path is compiled out, 0.0f is returned so that the function
// never flows off the end of a non-void function (undefined behaviour).
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
    const int* __restrict__ v, const int* __restrict__ u,
    const uint8_t* __restrict__ sc, const uint8_t* __restrict__ m,
    const half2& dm4, const half2* __restrict__ ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR5_K * VDR_Q5_K_Q8_1_MMQ / QI8_1; ++i) {
        int sumi_d = 0;

#pragma unroll
        for (int j = 0; j < QI8_1; ++j) {
            // SIMD dot product
            sumi_d = __dp4a(v[i * QI8_1 + j], u[i * QI8_1 + j], sumi_d);
        }

        const float2 ds8f = __half22float2(ds8[i]);

        sumf_d += ds8f.x * (sc[i] * sumi_d);
        sumf_m += ds8f.y * m[i];  // sum of q8_1 block * q5_K min val
    }

    const float2 dm4f = __half22float2(dm4);
    return dm4f.x * sumf_d - dm4f.y * sumf_m;
#else
    // dp4a path compiled out: return a defined value instead of falling off
    // the end of a non-void function.
    return 0.0f;
#endif
}
#define VDR_Q6_K_Q8_1_MMVQ 1
#define VDR_Q6_K_Q8_1_MMQ 8
// contiguous v/x values
// q6_K x q8_1 dot product (MMVQ variant).
//
//   vl     : int packing low 4-bit parts; iteration i uses nibble i of every
//            byte.
//   vh     : int packing high 2-bit parts; iteration i uses
//            ((vh >> 4*i) << 4) & 0x30 of every byte.
//   u, d8  : QR6_K packed int8 words and their float scales.
//   scales : signed scales, read at index 4*i.
//   d      : float scale applied to the whole result.
//
// Each reconstructed 6-bit value has 32 subtracted (saturating, per byte)
// before the dot product. Returns d * sum_i d8[i] * dp4a(vi, u[i]) * sc.
// When the dp4a path is compiled out, 0.0f is returned so that the function
// never flows off the end of a non-void function (undefined behaviour).
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
    const int& vl, const int& vh, const int* __restrict__ u,
    const int8_t* __restrict__ scales, const float& d,
    const float* __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
    float sumf = 0.0f;

#pragma unroll
    for (int i = 0; i < QR6_K; ++i) {
        const int sc = scales[4 * i];

        const int vil = (vl >> (4 * i)) & 0x0F0F0F0F;
        const int vih = ((vh >> (4 * i)) << 4) & 0x30303030;

        const int vi = __vsubss4((vil | vih), 0x20202020);  // vi = (vil | vih) - 32

        sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc);  // SIMD dot product
    }

    return d * sumf;
#else
    // dp4a path compiled out: return a defined value instead of falling off
    // the end of a non-void function.
    return 0.0f;
#endif
}
// q6_K x q8_1 dot product (MMQ variant).
//
// Steps i0 through [0, VDR_Q6_K_Q8_1_MMQ) in strides of 4. For each step two
// integer partial sums are built (sumi_d.x from v/u[2*i], v/u[2*i+1];
// sumi_d.y from v/u[2*i+4], v/u[2*i+5], for i in {i0, i0+1}), weighted by
// sc[i0/2] and sc[i0/2 + 1], and scaled by d8[i0/4].
//
// Returns d6 * sumf_d.
// When the dp4a path is compiled out, 0.0f is returned so that the function
// never flows off the end of a non-void function (undefined behaviour).
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
    const int* __restrict__ v, const int* __restrict__ u,
    const int8_t* __restrict__ sc, const float& d6,
    const float* __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
    float sumf_d = 0.0f;

#pragma unroll
    for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
        int2 sumi_d = {0, 0};  // 2 q6_K scales per q8_1 scale

#pragma unroll
        for (int i = i0; i < i0 + 2; ++i) {
            sumi_d.x = __dp4a(v[2 * i + 0], u[2 * i + 0], sumi_d.x);  // SIMD dot product
            sumi_d.x = __dp4a(v[2 * i + 1], u[2 * i + 1], sumi_d.x);  // SIMD dot product

            sumi_d.y = __dp4a(v[2 * i + 4], u[2 * i + 4], sumi_d.y);  // SIMD dot product
            sumi_d.y = __dp4a(v[2 * i + 5], u[2 * i + 5], sumi_d.y);  // SIMD dot product
        }

        sumf_d += d8[i0 / 4] * (sc[i0 / 2 + 0] * sumi_d.x + sc[i0 / 2 + 1] * sumi_d.y);
    }

    return d6 * sumf_d;
#else
    // dp4a path compiled out: return a defined value instead of falling off
    // the end of a non-void function.
    return 0.0f;
#endif
}
// Loads VDR_Q4_0_Q8_1_MMVQ packed words from a q4_0 block (starting at int
// index `iqs`) together with the two matching q8_1 words per q4_0 word
// (offsets iqs+n and iqs+n+QI4_0), then forwards them to
// vec_dot_q4_0_q8_1_impl with the q4_0 scale converted to float.
static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
    const block_q4_0* bq4_0 = static_cast<const block_q4_0*>(vbq);

    int v[VDR_Q4_0_Q8_1_MMVQ];
    int u[2 * VDR_Q4_0_Q8_1_MMVQ];

#pragma unroll
    for (int n = 0; n < VDR_Q4_0_Q8_1_MMVQ; ++n) {
        const int pos = iqs + n;
        v[n] = get_int_from_uint8(bq4_0->qs, pos);
        u[2 * n + 0] = get_int_from_int8_aligned(bq8_1->qs, pos);
        u[2 * n + 1] = get_int_from_int8_aligned(bq8_1->qs, pos + QI4_0);
    }

    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(
        v, u, __half2float(bq4_0->d), bq8_1->ds);
}
// Declares the __shared__ tiles used for q4_0 and hands their addresses back
// through x_ql (quant ints) and x_dm (float scales, exposed through the
// generic half2* slot). x_qh and x_sc are left untouched.
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q4_0(
    int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
    // mmq_y rows of WARP_SIZE ints plus one extra int per row.
    __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y];
    // One float scale per q4_0 block, plus mmq_y / QI4_0 extra entries.
    __shared__ float tile_x_d[mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0];

    *x_ql = tile_x_qs;
    *x_dm = reinterpret_cast<half2*>(tile_x_d);
}
// Copies q4_0 data for tile column `k` into the tiles.
//
// Pass 1: for rows i_offset, i_offset + nwarps, ... (< mmq_y, clamped to
//         i_max when need_check) store one packed quant int at
//         x_ql[row * (WARP_SIZE + 1) + k].
// Pass 2: store the block scale as float at
//         x_dm[row * (WARP_SIZE/QI4_0) + row/QI4_0 + kbxd], where x_dm is
//         viewed as a float array.
// x_qh and x_sc are not used by this format.
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q4_0(
    const void* __restrict__ vx, int* __restrict__ x_ql,
    half2* __restrict__ x_dm, int* __restrict__ x_qh, int* __restrict__ x_sc,
    const int& i_offset, const int& i_max, const int& k,
    const int& blocks_per_row) {
    const int kbx = k / QI4_0;   // block index within the row
    const int kqsx = k % QI4_0;  // int index within the block

    const block_q4_0* bx0 = static_cast<const block_q4_0*>(vx);
    float* x_dmf = reinterpret_cast<float*>(x_dm);

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int row = i0 + i_offset;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_q4_0* blk = bx0 + row * blocks_per_row + kbx;
        x_ql[row * (WARP_SIZE + 1) + k] = get_int_from_uint8(blk->qs, kqsx);
        // The scale is written by the second pass below.
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
    const int kbxd = k % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
        int row = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_q4_0* blk = bx0 + row * blocks_per_row + kbxd;
        x_dmf[row * (WARP_SIZE / QI4_0) + row / QI4_0 + kbxd] = __half2float(blk->d);
    }
}
// Tile-level q4_0 x q8_1 dot product for tile row `i`, tile column `j` and
// position `k`. Gathers 2 * VDR_Q4_0_Q8_1_MMQ ints from y_qs (wrapping the
// offset modulo WARP_SIZE) and calls vec_dot_q4_0_q8_1_impl with the x quants
// at x_ql[i * (WARP_SIZE + 1) + k], the float x scale and the half2 y scale.
// x_qh and x_sc are unused.
static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm,
    const int* __restrict__ x_qh, const int* __restrict__ x_sc,
    const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
    (void)x_qh;
    (void)x_sc;

    const int kyqs = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2));
    const float* x_dmf = reinterpret_cast<const float*>(x_dm);

    int u[2 * VDR_Q4_0_Q8_1_MMQ];

#pragma unroll
    for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
        const int y_row = j * WARP_SIZE;
        u[2 * l + 0] = y_qs[y_row + (kyqs + l) % WARP_SIZE];
        u[2 * l + 1] = y_qs[y_row + (kyqs + l + QI4_0) % WARP_SIZE];
    }

    const int x_q_idx = i * (WARP_SIZE + 1) + k;
    const int x_d_idx = i * (WARP_SIZE / QI4_0) + i / QI4_0 + k / QI4_0;
    const int y_d_idx = j * (WARP_SIZE / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE / QI8_1);

    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>(
        &x_ql[x_q_idx], u, x_dmf[x_d_idx], y_ds[y_d_idx]);
}
// Loads VDR_Q4_1_Q8_1_MMVQ packed words from a q4_1 block (starting at int
// index `iqs`) together with the two matching q8_1 words per q4_1 word
// (offsets iqs+n and iqs+n+QI4_1), then forwards them to
// vec_dot_q4_1_q8_1_impl with the blocks' half2 `dm` / `ds` pairs.
static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
    const block_q4_1* bq4_1 = static_cast<const block_q4_1*>(vbq);

    int v[VDR_Q4_1_Q8_1_MMVQ];
    int u[2 * VDR_Q4_1_Q8_1_MMVQ];

#pragma unroll
    for (int n = 0; n < VDR_Q4_1_Q8_1_MMVQ; ++n) {
        const int pos = iqs + n;
        v[n] = get_int_from_uint8_aligned(bq4_1->qs, pos);
        u[2 * n + 0] = get_int_from_int8_aligned(bq8_1->qs, pos);
        u[2 * n + 1] = get_int_from_int8_aligned(bq8_1->qs, pos + QI4_1);
    }

    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
}
// Declares the __shared__ tiles used for q4_1 and hands their addresses back
// through x_ql (quant ints) and x_dm (half2 scale/min pairs). x_qh and x_sc
// are left untouched.
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q4_1(
    int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
    // mmq_y rows of WARP_SIZE ints plus one extra int per row.
    __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y];
    // One half2 per q4_1 block, plus mmq_y / QI4_1 extra entries.
    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1];

    *x_ql = tile_x_qs;
    *x_dm = tile_x_dm;
}
// Copies q4_1 data for tile column `k` into the tiles.
//
// Pass 1: for rows i_offset, i_offset + nwarps, ... (< mmq_y, clamped to
//         i_max when need_check) store one packed quant int at
//         x_ql[row * (WARP_SIZE + 1) + k].
// Pass 2: store the block's half2 `dm` at
//         x_dm[row * (WARP_SIZE/QI4_1) + row/QI4_1 + kbxd].
// x_qh and x_sc are not used by this format.
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q4_1(
    const void* __restrict__ vx, int* __restrict__ x_ql,
    half2* __restrict__ x_dm, int* __restrict__ x_qh, int* __restrict__ x_sc,
    const int& i_offset, const int& i_max, const int& k,
    const int& blocks_per_row) {
    const int kbx = k / QI4_1;   // block index within the row
    const int kqsx = k % QI4_1;  // int index within the block

    const block_q4_1* bx0 = static_cast<const block_q4_1*>(vx);

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int row = i0 + i_offset;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_q4_1* blk = bx0 + row * blocks_per_row + kbx;
        x_ql[row * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(blk->qs, kqsx);
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
    const int kbxd = k % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
        int row = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_q4_1* blk = bx0 + row * blocks_per_row + kbxd;
        x_dm[row * (WARP_SIZE / QI4_1) + row / QI4_1 + kbxd] = blk->dm;
    }
}
// Tile-level q4_1 x q8_1 dot product for tile row `i`, tile column `j` and
// position `k`. Gathers 2 * VDR_Q4_1_Q8_1_MMQ ints from y_qs (wrapping the
// offset modulo WARP_SIZE) and calls vec_dot_q4_1_q8_1_impl with the x quants
// at x_ql[i * (WARP_SIZE + 1) + k] and the half2 x / y scale pairs.
static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm,
    const int* __restrict__ x_qh, const int* __restrict__ x_sc,
    const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
    const int kyqs = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2));

    int u[2 * VDR_Q4_1_Q8_1_MMQ];

#pragma unroll
    for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
        const int y_row = j * WARP_SIZE;
        u[2 * l + 0] = y_qs[y_row + (kyqs + l) % WARP_SIZE];
        u[2 * l + 1] = y_qs[y_row + (kyqs + l + QI4_1) % WARP_SIZE];
    }

    const int x_q_idx = i * (WARP_SIZE + 1) + k;
    const int x_d_idx = i * (WARP_SIZE / QI4_1) + i / QI4_1 + k / QI4_1;
    const int y_d_idx = j * (WARP_SIZE / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE / QI8_1);

    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>(
        &x_ql[x_q_idx], u, x_dm[x_d_idx], y_ds[y_d_idx]);
}
// Loads VDR_Q5_0_Q8_1_MMVQ low-bit words and the matching high-bit words from
// a q5_0 block (starting at int index `iqs`), plus two q8_1 words per q5_0
// word (offsets iqs+n and iqs+n+QI5_0), then forwards them to
// vec_dot_q5_0_q8_1_impl with the q5_0 scale converted to float.
// The high bits come from the first int of `qh`, shifted right by
// 4 * (iqs + n).
static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
    const block_q5_0* bq5_0 = static_cast<const block_q5_0*>(vbq);

    int vl[VDR_Q5_0_Q8_1_MMVQ];
    int vh[VDR_Q5_0_Q8_1_MMVQ];
    int u[2 * VDR_Q5_0_Q8_1_MMVQ];

#pragma unroll
    for (int n = 0; n < VDR_Q5_0_Q8_1_MMVQ; ++n) {
        const int pos = iqs + n;
        vl[n] = get_int_from_uint8(bq5_0->qs, pos);
        vh[n] = get_int_from_uint8(bq5_0->qh, 0) >> (4 * pos);
        u[2 * n + 0] = get_int_from_int8_aligned(bq8_1->qs, pos);
        u[2 * n + 1] = get_int_from_int8_aligned(bq8_1->qs, pos + QI5_0);
    }

    return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(
        vl, vh, u, __half2float(bq5_0->d), bq8_1->ds);
}
// Declares the __shared__ tiles used for q5_0 and hands their addresses back
// through x_ql (quant ints, 2 * WARP_SIZE per row) and x_dm (float scales,
// exposed through the generic half2* slot). x_qh and x_sc are left untouched.
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q5_0(
    int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
    // mmq_y rows of 2 * WARP_SIZE ints plus one extra int per row.
    __shared__ int tile_x_ql[mmq_y * (2 * WARP_SIZE) + mmq_y];
    // One float scale per q5_0 block, plus mmq_y / QI5_0 extra entries.
    __shared__ float tile_x_d[mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0];

    *x_ql = tile_x_ql;
    *x_dm = reinterpret_cast<half2*>(tile_x_d);
}
// Copies q5_0 data for tile column `k` into the tiles, expanding the 5-bit
// quants into bytes while loading.
//
// Pass 1: for rows i_offset, i_offset + nwarps, ... (< mmq_y, clamped to
//         i_max when need_check) combine the low nibbles of one packed int
//         with the matching high bits, subtract 16 from every byte with
//         saturation, and store the two resulting ints at
//         x_ql[row * (2*WARP_SIZE + 1) + 2*k + {0, 1}].
// Pass 2: store the block scale as float at
//         x_dm[row * (WARP_SIZE/QI5_0) + row/QI5_0 + kbxd], where x_dm is
//         viewed as a float array.
// x_qh and x_sc are not used by this format.
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q5_0(
    const void* __restrict__ vx, int* __restrict__ x_ql,
    half2* __restrict__ x_dm, int* __restrict__ x_qh, int* __restrict__ x_sc,
    const int& i_offset, const int& i_max, const int& k,
    const int& blocks_per_row) {
    const int kbx = k / QI5_0;   // block index within the row
    const int kqsx = k % QI5_0;  // int index within the block

    const block_q5_0* bx0 = static_cast<const block_q5_0*>(vx);

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int row = i0 + i_offset;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_q5_0* blk = bx0 + row * blocks_per_row + kbx;

        const int ql = get_int_from_uint8(blk->qs, kqsx);
        const int qh = get_int_from_uint8(blk->qh, 0) >> (4 * (k % QI5_0));

        int* dst = &x_ql[row * (2 * WARP_SIZE + 1) + 2 * k];

        int qs0 = (ql >> 0) & 0x0F0F0F0F;
        qs0 |= (qh << 4) & 0x00000010;       // 0 -> 4
        qs0 |= (qh << 11) & 0x00001000;      // 1 -> 12
        qs0 |= (qh << 18) & 0x00100000;      // 2 -> 20
        qs0 |= (qh << 25) & 0x10000000;      // 3 -> 28
        qs0 = __vsubss4(qs0, 0x10101010);    // subtract 16
        dst[0] = qs0;

        int qs1 = (ql >> 4) & 0x0F0F0F0F;
        qs1 |= (qh >> 12) & 0x00000010;      // 16 -> 4
        qs1 |= (qh >> 5) & 0x00001000;       // 17 -> 12
        qs1 |= (qh << 2) & 0x00100000;       // 18 -> 20
        qs1 |= (qh << 9) & 0x10000000;       // 19 -> 28
        qs1 = __vsubss4(qs1, 0x10101010);    // subtract 16
        dst[1] = qs1;
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
    const int kbxd = k % blocks_per_tile_x_row;
    float* x_dmf = reinterpret_cast<float*>(x_dm);

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
        int row = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_q5_0* blk = bx0 + row * blocks_per_row + kbxd;
        x_dmf[row * (WARP_SIZE / QI5_0) + row / QI5_0 + kbxd] = __half2float(blk->d);
    }
}
// Tile-level q5_0 x q8_1 dot product for tile row `i`, tile column `j` and
// position `k`. The x tile already holds byte-expanded quants, so the work is
// delegated to vec_dot_q8_0_q8_1_impl with QR5_0 * VDR_Q5_0_Q8_1_MMQ ints.
// Both scale arrays are read through float views of x_dm / y_ds.
static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm,
    const int* __restrict__ x_qh, const int* __restrict__ x_sc,
    const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
    const int kyqs = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2));
    const int index_bx = i * (WARP_SIZE / QI5_0) + i / QI5_0 + k / QI5_0;
    const float* x_dmf = reinterpret_cast<const float*>(x_dm);
    const float* y_df = reinterpret_cast<const float*>(y_ds);

    int u[2 * VDR_Q5_0_Q8_1_MMQ];

#pragma unroll
    for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
        const int y_row = j * WARP_SIZE;
        u[2 * l + 0] = y_qs[y_row + (kyqs + l) % WARP_SIZE];
        u[2 * l + 1] = y_qs[y_row + (kyqs + l + QI5_0) % WARP_SIZE];
    }

    const int x_q_idx = i * (2 * WARP_SIZE + 1) + 2 * k;
    const int y_d_idx = j * (WARP_SIZE / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE / QI8_1);

    return vec_dot_q8_0_q8_1_impl<QR5_0 * VDR_Q5_0_Q8_1_MMQ>(
        &x_ql[x_q_idx], u, x_dmf[index_bx], y_df[y_d_idx]);
}
// Loads VDR_Q5_1_Q8_1_MMVQ low-bit words and the matching high-bit words from
// a q5_1 block (starting at int index `iqs`), plus two q8_1 words per q5_1
// word (offsets iqs+n and iqs+n+QI5_1), then forwards them to
// vec_dot_q5_1_q8_1_impl with the blocks' half2 `dm` / `ds` pairs.
// The high bits come from the first int of `qh`, shifted right by
// 4 * (iqs + n).
static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
    const block_q5_1* bq5_1 = static_cast<const block_q5_1*>(vbq);

    int vl[VDR_Q5_1_Q8_1_MMVQ];
    int vh[VDR_Q5_1_Q8_1_MMVQ];
    int u[2 * VDR_Q5_1_Q8_1_MMVQ];

#pragma unroll
    for (int n = 0; n < VDR_Q5_1_Q8_1_MMVQ; ++n) {
        const int pos = iqs + n;
        vl[n] = get_int_from_uint8_aligned(bq5_1->qs, pos);
        vh[n] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * pos);
        u[2 * n + 0] = get_int_from_int8_aligned(bq8_1->qs, pos);
        u[2 * n + 1] = get_int_from_int8_aligned(bq8_1->qs, pos + QI5_1);
    }

    return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(
        vl, vh, u, bq5_1->dm, bq8_1->ds);
}
// Declares the __shared__ tiles used for q5_1 and hands their addresses back
// through x_ql (quant ints, 2 * WARP_SIZE per row) and x_dm (half2 pairs).
// x_qh and x_sc are left untouched.
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q5_1(
    int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
    // mmq_y rows of 2 * WARP_SIZE ints plus one extra int per row.
    __shared__ int tile_x_ql[mmq_y * (2 * WARP_SIZE) + mmq_y];
    // One half2 per q5_1 block, plus mmq_y / QI5_1 extra entries.
    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1];

    *x_ql = tile_x_ql;
    *x_dm = tile_x_dm;
}
// Copies q5_1 data for tile column `k` into the tiles, expanding the 5-bit
// quants into bytes while loading.
//
// Pass 1: for rows i_offset, i_offset + nwarps, ... (< mmq_y, clamped to
//         i_max when need_check) combine the low nibbles of one packed int
//         with the matching high bits and store the two resulting ints at
//         x_ql[row * (2*WARP_SIZE + 1) + 2*k + {0, 1}].
// Pass 2: store the block's half2 `dm` at
//         x_dm[row * (WARP_SIZE/QI5_1) + row/QI5_1 + kbxd].
// x_qh and x_sc are not used by this format.
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q5_1(
    const void* __restrict__ vx, int* __restrict__ x_ql,
    half2* __restrict__ x_dm, int* __restrict__ x_qh, int* __restrict__ x_sc,
    const int& i_offset, const int& i_max, const int& k,
    const int& blocks_per_row) {
    const int kbx = k / QI5_1;   // block index within the row
    const int kqsx = k % QI5_1;  // int index within the block

    const block_q5_1* bx0 = static_cast<const block_q5_1*>(vx);

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int row = i0 + i_offset;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_q5_1* blk = bx0 + row * blocks_per_row + kbx;

        const int ql = get_int_from_uint8_aligned(blk->qs, kqsx);
        const int qh = get_int_from_uint8_aligned(blk->qh, 0) >> (4 * (k % QI5_1));

        int* dst = &x_ql[row * (2 * WARP_SIZE + 1) + 2 * k];

        int qs0 = (ql >> 0) & 0x0F0F0F0F;
        qs0 |= (qh << 4) & 0x00000010;   // 0 -> 4
        qs0 |= (qh << 11) & 0x00001000;  // 1 -> 12
        qs0 |= (qh << 18) & 0x00100000;  // 2 -> 20
        qs0 |= (qh << 25) & 0x10000000;  // 3 -> 28
        dst[0] = qs0;

        int qs1 = (ql >> 4) & 0x0F0F0F0F;
        qs1 |= (qh >> 12) & 0x00000010;  // 16 -> 4
        qs1 |= (qh >> 5) & 0x00001000;   // 17 -> 12
        qs1 |= (qh << 2) & 0x00100000;   // 18 -> 20
        qs1 |= (qh << 9) & 0x10000000;   // 19 -> 28
        dst[1] = qs1;
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
    const int kbxd = k % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
        int row = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_q5_1* blk = bx0 + row * blocks_per_row + kbxd;
        x_dm[row * (WARP_SIZE / QI5_1) + row / QI5_1 + kbxd] = blk->dm;
    }
}
// Tile-level q5_1 x q8_1 dot product for tile row `i`, tile column `j` and
// position `k`. The x tile already holds byte-expanded quants, so the work is
// delegated to vec_dot_q8_1_q8_1_impl with QR5_1 * VDR_Q5_1_Q8_1_MMQ ints and
// the half2 x / y scale pairs.
static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm,
    const int* __restrict__ x_qh, const int* __restrict__ x_sc,
    const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
    const int kyqs = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2));
    const int index_bx = i * (WARP_SIZE / QI5_1) + i / QI5_1 + k / QI5_1;

    int u[2 * VDR_Q5_1_Q8_1_MMQ];

#pragma unroll
    for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
        const int y_row = j * WARP_SIZE;
        u[2 * l + 0] = y_qs[y_row + (kyqs + l) % WARP_SIZE];
        u[2 * l + 1] = y_qs[y_row + (kyqs + l + QI5_1) % WARP_SIZE];
    }

    const int x_q_idx = i * (2 * WARP_SIZE + 1) + 2 * k;
    const int y_d_idx = j * (WARP_SIZE / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE / QI8_1);

    return vec_dot_q8_1_q8_1_impl<QR5_1 * VDR_Q5_1_Q8_1_MMQ>(
        &x_ql[x_q_idx], u, x_dm[index_bx], y_ds[y_d_idx]);
}
// Loads VDR_Q8_0_Q8_1_MMVQ packed int8 words from a q8_0 block and from a
// q8_1 block (both starting at int index `iqs`) and forwards them to
// vec_dot_q8_0_q8_1_impl with the q8_0 scale and the low half of the q8_1
// `ds` pair, both converted to float.
static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
    const block_q8_0* bq8_0 = static_cast<const block_q8_0*>(vbq);

    int v[VDR_Q8_0_Q8_1_MMVQ];
    int u[VDR_Q8_0_Q8_1_MMVQ];

#pragma unroll
    for (int n = 0; n < VDR_Q8_0_Q8_1_MMVQ; ++n) {
        const int pos = iqs + n;
        v[n] = get_int_from_int8(bq8_0->qs, pos);
        u[n] = get_int_from_int8_aligned(bq8_1->qs, pos);
    }

    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(
        v, u, __half2float(bq8_0->d), __low2float(bq8_1->ds));
}
// Declares the __shared__ tiles used for q8_0 and hands their addresses back
// through x_ql (quant ints) and x_dm (float scales, exposed through the
// generic half2* slot). x_qh and x_sc are left untouched.
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q8_0(
    int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
    // mmq_y rows of WARP_SIZE ints plus one extra int per row.
    __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y];
    // One float scale per q8_0 block, plus mmq_y / QI8_0 extra entries.
    __shared__ float tile_x_d[mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0];

    *x_ql = tile_x_qs;
    *x_dm = reinterpret_cast<half2*>(tile_x_d);
}
// Copies q8_0 data for tile column `k` into the tiles.
//
// Pass 1: for rows i_offset, i_offset + nwarps, ... (< mmq_y, clamped to
//         i_max when need_check) store one packed int8 word at
//         x_ql[row * (WARP_SIZE + 1) + k].
// Pass 2: store the block scale as float at
//         x_dm[row * (WARP_SIZE/QI8_0) + row/QI8_0 + kbxd], where x_dm is
//         viewed as a float array.
// x_qh and x_sc are not used by this format.
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q8_0(
    const void* __restrict__ vx, int* __restrict__ x_ql,
    half2* __restrict__ x_dm, int* __restrict__ x_qh, int* __restrict__ x_sc,
    const int& i_offset, const int& i_max, const int& k,
    const int& blocks_per_row) {
    const int kbx = k / QI8_0;   // block index within the row
    const int kqsx = k % QI8_0;  // int index within the block
    float* x_dmf = reinterpret_cast<float*>(x_dm);

    const block_q8_0* bx0 = static_cast<const block_q8_0*>(vx);

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int row = i0 + i_offset;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_q8_0* blk = bx0 + row * blocks_per_row + kbx;
        x_ql[row * (WARP_SIZE + 1) + k] = get_int_from_int8(blk->qs, kqsx);
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
    const int kbxd = k % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
        int row = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_q8_0* blk = bx0 + row * blocks_per_row + kbxd;
        x_dmf[row * (WARP_SIZE / QI8_0) + row / QI8_0 + kbxd] = __half2float(blk->d);
    }
}
// Tile-level q8_0 x q8_1 dot product for tile row `i`, tile column `j` and
// position `k`. Passes the x ints at x_ql[i * (WARP_SIZE + 1) + k] and the
// y ints at y_qs[j * WARP_SIZE + k] straight to vec_dot_q8_0_q8_1_impl;
// both scales are read through float views of x_dm / y_ds.
static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm,
    const int* __restrict__ x_qh, const int* __restrict__ x_sc,
    const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
    const float* x_dmf = reinterpret_cast<const float*>(x_dm);
    const float* y_df = reinterpret_cast<const float*>(y_ds);

    const int x_q_idx = i * (WARP_SIZE + 1) + k;
    const int y_q_idx = j * WARP_SIZE + k;
    const int x_d_idx = i * (WARP_SIZE / QI8_0) + i / QI8_0 + k / QI8_0;
    const int y_d_idx = j * (WARP_SIZE / QI8_1) + k / QI8_1;

    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>(
        &x_ql[x_q_idx], &y_qs[y_q_idx], x_dmf[x_d_idx], y_df[y_d_idx]);
}
// q2_K x q8_1 dot product for int index `iqs` of a q2_K block.
//
// Reads one packed int of 2-bit quants, derives the starting q8_1 block
// (QR2_K * (iqs / QI8_1)) and the scale offset from `iqs`, collects QR2_K
// q8_1 ints (all at int index iqs % QI8_1 of consecutive q8_1 blocks) with
// their float scales, and calls vec_dot_q2_K_q8_1_impl_mmvq.
static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
    const block_q2_K* bq2_K = static_cast<const block_q2_K*>(vbq);

    const int iqs_in_q8 = iqs % QI8_1;
    const int bq8_offset = QR2_K * (iqs / QI8_1);
    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1 / 2);

    const uint8_t* scales = bq2_K->scales + scale_offset;
    const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs);

    int u[QR2_K];
    float d8[QR2_K];

#pragma unroll
    for (int n = 0; n < QR2_K; ++n) {
        const block_q8_1& yb = bq8_1[bq8_offset + n];
        u[n] = get_int_from_int8_aligned(yb.qs, iqs_in_q8);
        d8[n] = __low2float(yb.ds);
    }

    return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
}
// Declares the __shared__ tiles used for q2_K and hands their addresses back
// through x_ql (quant ints), x_dm (half2 pairs) and x_sc (packed scales).
// x_qh is left untouched.
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q2_K(
    int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
    // mmq_y rows of WARP_SIZE ints plus one extra int per row.
    __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y];
    // One half2 per q2_K block, plus mmq_y / QI2_K extra entries.
    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K];
    // WARP_SIZE / 4 scale ints per row, plus mmq_y / 4 extra entries.
    __shared__ int tile_x_sc[mmq_y * (WARP_SIZE / 4) + mmq_y / 4];

    *x_ql = tile_x_ql;
    *x_dm = tile_x_dm;
    *x_sc = tile_x_sc;
}
// Copies q2_K data for tile column `k` into the tiles.
//
// Pass 1: one packed quant int per row into x_ql[row * (WARP_SIZE + 1) + k].
// Pass 2: the block's half2 `dm` into
//         x_dm[row * (WARP_SIZE/QI2_K) + row/QI2_K + kbxd]; the row index is
//         wrapped modulo mmq_y before the optional clamp.
// Pass 3: one int of packed scales into
//         x_sc[row * (WARP_SIZE/4) + row/4 + k % (WARP_SIZE/4)].
// In every pass the row is clamped to i_max when need_check is set.
// x_qh is not used by this format.
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q2_K(
    const void* __restrict__ vx, int* __restrict__ x_ql,
    half2* __restrict__ x_dm, int* __restrict__ x_qh, int* __restrict__ x_sc,
    const int& i_offset, const int& i_max, const int& k,
    const int& blocks_per_row) {
    const int kbx = k / QI2_K;   // block index within the row
    const int kqsx = k % QI2_K;  // int index within the block

    const block_q2_K* bx0 = static_cast<const block_q2_K*>(vx);

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int row = i0 + i_offset;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_q2_K* blk = bx0 + row * blocks_per_row + kbx;
        x_ql[row * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(blk->qs, kqsx);
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
    const int kbxd = k % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
        int row = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_q2_K* blk = bx0 + row * blocks_per_row + kbxd;
        x_dm[row * (WARP_SIZE / QI2_K) + row / QI2_K + kbxd] = blk->dm;
    }

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
        int row = i0 + i_offset * 4 + k / (WARP_SIZE / 4);
        if (need_check) {
            row = min(row, i_max);
        }

        const int ksc = k % (WARP_SIZE / 4);
        const block_q2_K* blk = bx0 + row * blocks_per_row + ksc / (QI2_K / 4);
        x_sc[row * (WARP_SIZE / 4) + row / 4 + ksc] =
            get_int_from_uint8_aligned(blk->scales, k % (QI2_K / 4));
    }
}
// Tile-level q2_K x q8_1 dot product for tile row `i`, tile column `j` and
// position `k`. Extracts QR2_K * VDR_Q2_K_Q8_1_MMQ ints of 2-bit quants from
// x_ql (shifting by a `k`-dependent amount and masking with 0x03 per byte),
// locates the matching scale bytes inside x_sc, and calls
// vec_dot_q2_K_q8_1_impl_mmq. The y scale is read through a float view of
// y_ds.
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm,
    const int* __restrict__ x_qh, const int* __restrict__ x_sc,
    const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
    const int kbx = k / QI2_K;
    const int ky = (k % QI2_K) * QR2_K;
    const float* y_df = reinterpret_cast<const float*>(y_ds);

    int v[QR2_K * VDR_Q2_K_Q8_1_MMQ];

    const int kqsx = i * (WARP_SIZE + 1) + kbx * QI2_K +
                     (QI2_K / 2) * (ky / (2 * QI2_K)) + ky % (QI2_K / 2);
    const int shift = 2 * ((ky % (2 * QI2_K)) / (QI2_K / 2));

#pragma unroll
    for (int l = 0; l < QR2_K * VDR_Q2_K_Q8_1_MMQ; ++l) {
        v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
    }

    // Byte view of the packed scale ints belonging to this block.
    const uint8_t* sc_base = reinterpret_cast<const uint8_t*>(
        &x_sc[i * (WARP_SIZE / 4) + i / 4 + kbx * 4]);
    const uint8_t* scales = sc_base + ky / 4;

    const int index_y = j * WARP_SIZE + (QR2_K * k) % WARP_SIZE;
    const int x_d_idx = i * (WARP_SIZE / QI2_K) + i / QI2_K + kbx;

    return vec_dot_q2_K_q8_1_impl_mmq(
        v, &y_qs[index_y], scales, x_dm[x_d_idx], y_df[index_y / QI8_1]);
}
// q3_K x q8_1 dot product for int index `iqs` of a q3_K block.
//
// Reads the packed low bits from `qs` and the inverted high-bit mask from
// `hmask` (shifted right by the q8_1 block offset), collects QR3_K q8_1 ints
// (all at int index iqs % QI8_1 of consecutive q8_1 blocks) with their float
// scales, and calls vec_dot_q3_K_q8_1_impl_mmvq.
static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
    const block_q3_K* bq3_K = static_cast<const block_q3_K*>(vbq);

    const int iqs_in_q8 = iqs % QI8_1;
    const int bq8_offset = QR3_K * (iqs / (QI3_K / 2));
    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1 / 2);

    const float d = __half2float(bq3_K->d);

    const int vl = get_int_from_uint8(bq3_K->qs, iqs);

    // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
    const int vh = (~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K / 2))) >> bq8_offset;

    int u[QR3_K];
    float d8[QR3_K];

#pragma unroll
    for (int n = 0; n < QR3_K; ++n) {
        const block_q8_1& yb = bq8_1[bq8_offset + n];
        u[n] = get_int_from_int8_aligned(yb.qs, iqs_in_q8);
        d8[n] = __low2float(yb.ds);
    }

    return vec_dot_q3_K_q8_1_impl_mmvq(
        vl, vh, u, bq3_K->scales, scale_offset, d, d8);
}
// Declares the __shared__ tile buffers for q3_K (quant low bits, per-block
// scale, high-bit mask and scales) sized from mmq_y, and returns their
// addresses through the four out-pointers.
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q3_K(int** x_ql,
                                                           half2** x_dm,
                                                           int** x_qh,
                                                           int** x_sc) {
  constexpr int ql_len = mmq_y * (WARP_SIZE) + mmq_y;
  constexpr int dm_len = mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K;
  constexpr int qh_len = mmq_y * (WARP_SIZE / 2) + mmq_y / 2;
  constexpr int sc_len = mmq_y * (WARP_SIZE / 4) + mmq_y / 4;

  __shared__ int ql_tile[ql_len];
  __shared__ half2 dm_tile[dm_len];
  __shared__ int qh_tile[qh_len];
  __shared__ int sc_tile[sc_len];

  *x_ql = ql_tile;
  *x_dm = dm_tile;
  *x_qh = qh_tile;
  *x_sc = sc_tile;
}
// Fills the q3_K tile buffers from the block_q3_K array at vx.
//
// Four passes over the tile rows, each with its own row stride:
//   1. x_ql : packed low bits of the quants (one word per row and k).
//   2. x_dm : per-block scale d, stored through a float view of x_dm.
//   3. x_qh : bitwise-inverted high-bit mask words.
//   4. x_sc : scales rebuilt as 4 low bits | 2 high bits per byte, then
//             reduced by 0x20 per byte with __vsubss4.
// When need_check is set, every computed row index is clamped to i_max.
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q3_K(
    const void* __restrict__ vx, int* __restrict__ x_ql,
    half2* __restrict__ x_dm, int* __restrict__ x_qh, int* __restrict__ x_sc,
    const int& i_offset, const int& i_max, const int& k,
    const int& blocks_per_row) {
  const block_q3_K* blocks = static_cast<const block_q3_K*>(vx);

  // ---- pass 1: low bits ----
  const int blk_in_row = k / QI3_K;
  const int word_in_blk = k % QI3_K;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
    const int raw = i0 + i_offset;
    const int row = need_check ? min(raw, i_max) : raw;

    const block_q3_K* src = blocks + row * blocks_per_row + blk_in_row;

    x_ql[row * (WARP_SIZE + 1) + k] = get_int_from_uint8(src->qs, word_in_blk);
  }

  // ---- pass 2: per-block scale d ----
  const int tile_blocks_per_row = WARP_SIZE / QI3_K;
  const int blk_d = k % tile_blocks_per_row;
  float* d_out = reinterpret_cast<float*>(x_dm);

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
    const int raw = (i0 + i_offset * QI3_K + k / tile_blocks_per_row) % mmq_y;
    const int row = need_check ? min(raw, i_max) : raw;

    const block_q3_K* src = blocks + row * blocks_per_row + blk_d;

    d_out[row * (WARP_SIZE / QI3_K) + row / QI3_K + blk_d] =
        __half2float(src->d);
  }

  // ---- pass 3: high-bit mask ----
  const int qh_col = k % (WARP_SIZE / 2);

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
    const int raw = i0 + i_offset * 2 + k / (WARP_SIZE / 2);
    const int row = need_check ? min(raw, i_max) : raw;

    const block_q3_K* src =
        blocks + row * blocks_per_row + qh_col / (QI3_K / 2);

    // The mask is inverted with ~ so that a 0/1 bit results in 4/0 being
    // subtracted later on.
    x_qh[row * (WARP_SIZE / 2) + row / 2 + qh_col] =
        ~get_int_from_uint8(src->hmask, k % (QI3_K / 2));
  }

  // ---- pass 4: scales ----
  // All of these depend on k only, so they are computed once.
  const int sc_col = k % (WARP_SIZE / 4);
  const int ksc = k % (QI3_K / 4);
  const int low_word = ksc % (QI3_K / 8);
  const int low_shift = 4 * (ksc / (QI3_K / 8));
  const int high_word = QI3_K / 8;
  const int high_shift = 2 * ksc;

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
    const int raw = i0 + i_offset * 4 + k / (WARP_SIZE / 4);
    const int row = need_check ? min(raw, i_max) : raw;

    const block_q3_K* src =
        blocks + row * blocks_per_row + sc_col / (QI3_K / 4);

    const int low4 =
        (get_int_from_uint8(src->scales, low_word) >> low_shift) & 0x0F0F0F0F;
    const int high2 =
        ((get_int_from_uint8(src->scales, high_word) >> high_shift) << 4) &
        0x30303030;

    x_sc[row * (WARP_SIZE / 4) + row / 4 + sc_col] =
        __vsubss4(low4 | high2, 0x20202020);
  }
}
// Tile dot product for q3_K x q8_1.
//
// For tile row `i` / position `k`, rebuilds QR3_K * VDR_Q3_K_Q8_1_MMQ packed
// words of quants from the low bits in x_ql and the mask words in x_qh, picks
// the signed scale bytes from x_sc, and forwards everything together with
// column `j` of the y tile to vec_dot_q3_K_q8_1_impl_mmq. x_dm and y_ds are
// read through float views.
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm,
    const int* __restrict__ x_qh, const int* __restrict__ x_sc,
    const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
  constexpr int n_packed = QR3_K * VDR_Q3_K_Q8_1_MMQ;

  const int blk = k / QI3_K;
  const int y_off = (k % QI3_K) * QR3_K;

  const float* x_scale = reinterpret_cast<const float*>(x_dm);
  const float* y_scale = reinterpret_cast<const float*>(y_ds);

  const int8_t* scales =
      reinterpret_cast<const int8_t*>(x_sc + i * (WARP_SIZE / 4) + i / 4 +
                                      blk * 4) +
      y_off / 4;

  // Neither of these depends on the unroll index, so they live outside the
  // loop.
  const int ql_base = i * (WARP_SIZE + 1) + blk * QI3_K +
                      (QI3_K / 2) * (y_off / (2 * QI3_K)) + y_off % (QI3_K / 2);
  const int ql_shift = 2 * ((y_off % 32) / 8);

  int q3[n_packed];
#pragma unroll
  for (int n = 0; n < n_packed; ++n) {
    const int low2 = (x_ql[ql_base + n] >> ql_shift) & 0x03030303;

    const int mask_word =
        x_qh[i * (WARP_SIZE / 2) + i / 2 + blk * (QI3_K / 2) + (y_off + n) % 8] >>
        ((y_off + n) / 8);
    const int high = (mask_word << 2) & 0x04040404;

    // per-byte signed saturating subtraction
    q3[n] = __vsubss4(low2, high);
  }

  const int y_idx = j * WARP_SIZE + (k * QR3_K) % WARP_SIZE;

  return vec_dot_q3_K_q8_1_impl_mmq(
      q3, &y_qs[y_idx], scales,
      x_scale[i * (WARP_SIZE / QI3_K) + i / QI3_K + blk],
      y_scale[y_idx / QI8_1]);
}
// Vector dot product of a slice of one q4_K block with q8_1 blocks.
//
// vbq   : pointer to a block_q4_K (passed type-erased).
// bq8_1 : q8_1 blocks; QR4_K consecutive blocks starting at an iqs-derived
//         offset are read.
// iqs   : selects the slice. Per the original notes, iqs takes even values
//         0, 2, ..., 30, which makes the q8 block offset 0, 2, 4 or 6.
// Two packed quant words, four q8 words, the unpacked scale bytes (sc) and
// the two bytes following them (m) are passed to vec_dot_q4_K_q8_1_impl_vmmq.
static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
  const block_q4_K* blk = static_cast<const block_q4_K*>(vbq);

  const int word_sel = (iqs / 2) % 4;
  const int q8_first = QR4_K * ((iqs / 2) / (QI8_1 / 2));

  // Byte offset into qs is 16 * q8_first + 4 * word_sel; the second word
  // lies four ints further on.
  const int* q4_words =
      reinterpret_cast<const int*>(blk->qs + 16 * q8_first + 4 * word_sel);
  int v[2] = {q4_words[0], q4_words[4]};

  // Unpack the scale bytes for this slice into aux.
  const uint16_t* packed = reinterpret_cast<const uint16_t*>(blk->scales);
  const int j = q8_first / 2;
  uint16_t aux[2];
  if (j >= 2) {
    aux[0] = ((packed[j + 2] >> 0) & 0x0f0f) | ((packed[j - 2] & 0xc0c0) >> 2);
    aux[1] = ((packed[j + 2] >> 4) & 0x0f0f) | ((packed[j - 0] & 0xc0c0) >> 2);
  } else {
    aux[0] = packed[j + 0] & 0x3f3f;
    aux[1] = packed[j + 2] & 0x3f3f;
  }
  const uint8_t* sc = reinterpret_cast<const uint8_t*>(aux);
  const uint8_t* m = sc + 2;

  int u[2 * QR4_K];
  float d8[QR4_K];
  for (int n = 0; n < QR4_K; ++n) {
    const block_q8_1* yb = bq8_1 + q8_first + n;
    d8[n] = __low2float(yb->ds);

    const int* q8_words = reinterpret_cast<const int*>(yb->qs) + word_sel;
    u[2 * n + 0] = q8_words[0];
    u[2 * n + 1] = q8_words[4];
  }

  return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, blk->dm, d8);
}
// Declares the __shared__ tile buffers for q4_K (packed quants, per-block dm
// and scales) sized from mmq_y, and returns their addresses through x_ql,
// x_dm and x_sc. x_qh is accepted but never written.
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q4_K(int** x_ql,
                                                           half2** x_dm,
                                                           int** x_qh,
                                                           int** x_sc) {
  constexpr int ql_len = mmq_y * (WARP_SIZE) + mmq_y;
  constexpr int dm_len = mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K;
  constexpr int sc_len = mmq_y * (WARP_SIZE / 8) + mmq_y / 8;

  __shared__ int ql_tile[ql_len];
  __shared__ half2 dm_tile[dm_len];
  __shared__ int sc_tile[sc_len];

  *x_ql = ql_tile;
  *x_dm = dm_tile;
  *x_sc = sc_tile;
}
// Fills the q4_K tile buffers from the block_q4_K array at vx.
//
// Three passes over the tile rows, each with its own row stride:
//   1. x_ql : packed quant words (one word per row and k).
//   2. x_dm : the block's dm value, copied as-is.
//   3. x_sc : scale words rebuilt as (4 low bits | 2 high bits) per byte.
// When need_check is set, every computed row index is clamped to i_max.
// x_qh is not touched.
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q4_K(
    const void* __restrict__ vx, int* __restrict__ x_ql,
    half2* __restrict__ x_dm, int* __restrict__ x_qh, int* __restrict__ x_sc,
    const int& i_offset, const int& i_max, const int& k,
    const int& blocks_per_row) {
  const block_q4_K* blocks = static_cast<const block_q4_K*>(vx);

  // ---- pass 1: packed quants ----
  const int blk_in_row = k / QI4_K;   // == 0 if QK_K == 256
  const int word_in_blk = k % QI4_K;  // == k if QK_K == 256

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
    const int raw = i0 + i_offset;
    const int row = need_check ? min(raw, i_max) : raw;

    const block_q4_K* src = blocks + row * blocks_per_row + blk_in_row;

    x_ql[row * (WARP_SIZE + 1) + k] =
        get_int_from_uint8_aligned(src->qs, word_in_blk);
  }

  // ---- pass 2: per-block dm ----
  const int tile_blocks_per_row = WARP_SIZE / QI4_K;  // == 1 if QK_K == 256
  const int blk_d = k % tile_blocks_per_row;          // == 0 if QK_K == 256

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
    const int raw = (i0 + i_offset * QI4_K + k / tile_blocks_per_row) % mmq_y;
    const int row = need_check ? min(raw, i_max) : raw;

    const block_q4_K* src = blocks + row * blocks_per_row + blk_d;

    x_dm[row * (WARP_SIZE / QI4_K) + row / QI4_K + blk_d] = src->dm;
  }

  // ---- pass 3: scales ----
  // Word/shift selection depends on k only.
  const int ksc = k % (WARP_SIZE / 8);
  const int low_word = (ksc % 2) + (ksc != 0);
  const int low_shift = 4 * (ksc & (ksc / 2));
  const int high_word = ksc / 2;
  const int high_shift = 2 * (ksc % 2);

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
    const int raw = (i0 + i_offset * 8 + k / (WARP_SIZE / 8)) % mmq_y;
    const int row = need_check ? min(raw, i_max) : raw;

    const block_q4_K* src = blocks + row * blocks_per_row + ksc / (QI4_K / 8);
    const int* packed = reinterpret_cast<const int*>(src->scales);

    // Resulting arrangement per the original note: sc0..sc3, sc4..sc7,
    // m0..m3, m4..m7 (the original comment ends in "m8", presumably a typo).
    const int low4 = (packed[low_word] >> low_shift) & 0x0F0F0F0F;     // lower 4 bits
    const int high2 = (packed[high_word] >> high_shift) & 0x30303030;  // upper 2 bits

    x_sc[row * (WARP_SIZE / 8) + row / 8 + ksc] = low4 | high2;
  }
}
// Tile dot product for q4_K x q8_1.
//
// Computes the addresses of the operands for tile row `i`, column `j` and
// position `k`, then delegates to vec_dot_q4_K_q8_1_impl_mmq. The second byte
// pointer handed over is the scale pointer advanced by 8 bytes. x_qh is
// deliberately unused.
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm,
    const int* __restrict__ x_qh, const int* __restrict__ x_sc,
    const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
  (void)x_qh;

  // Scale word for this (i, k), refined to a byte offset of 0 or 2.
  const int sc_word = i * (WARP_SIZE / 8) + i / 8 + k / 16;
  const int sc_byte = 2 * ((k % 16) / 8);
  const uint8_t* sc =
      reinterpret_cast<const uint8_t*>(&x_sc[sc_word]) + sc_byte;
  const uint8_t* m_bytes = sc + 8;

  const int x_idx = i * (WARP_SIZE + 1) + k;
  const int y_idx = j * WARP_SIZE + (QR4_K * k) % WARP_SIZE;
  const int dm_idx = i * (WARP_SIZE / QI4_K) + i / QI4_K;

  return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[x_idx], &y_qs[y_idx], sc, m_bytes,
                                    x_dm[dm_idx], &y_ds[y_idx / QI8_1]);
}
// Vector dot product of a slice of one q5_K block with q8_1 blocks.
//
// vbq   : pointer to a block_q5_K (passed type-erased).
// bq8_1 : q8_1 blocks; QR5_K consecutive blocks starting at an iqs-derived
//         offset are read.
// iqs   : selects the slice.
// Two low-bit words, two high-bit words (shifted right by the q8 block
// offset), four q8 words and the unpacked scale bytes are passed to
// vec_dot_q5_K_q8_1_impl_vmmq.
static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
  const block_q5_K* blk = static_cast<const block_q5_K*>(vbq);

  const int word_sel = (iqs / 2) % 4;
  const int q8_first = QR5_K * ((iqs / 2) / (QI8_1 / 2));

  const int* low_words =
      reinterpret_cast<const int*>(blk->qs + 16 * q8_first + 4 * word_sel);
  const int* high_words = reinterpret_cast<const int*>(blk->qh + 4 * word_sel);

  int vl[2] = {low_words[0], low_words[4]};
  int vh[2] = {high_words[0] >> q8_first, high_words[4] >> q8_first};

  // Unpack the scale bytes for this slice into aux.
  const uint16_t* packed = reinterpret_cast<const uint16_t*>(blk->scales);
  const int j = q8_first / 2;
  uint16_t aux[2];
  if (j >= 2) {
    aux[0] = ((packed[j + 2] >> 0) & 0x0f0f) | ((packed[j - 2] & 0xc0c0) >> 2);
    aux[1] = ((packed[j + 2] >> 4) & 0x0f0f) | ((packed[j - 0] & 0xc0c0) >> 2);
  } else {
    aux[0] = packed[j + 0] & 0x3f3f;
    aux[1] = packed[j + 2] & 0x3f3f;
  }
  const uint8_t* sc = reinterpret_cast<const uint8_t*>(aux);
  const uint8_t* m = sc + 2;

  int u[2 * QR5_K];
  float d8[QR5_K];
#pragma unroll
  for (int n = 0; n < QR5_K; ++n) {
    const block_q8_1* yb = bq8_1 + q8_first + n;
    d8[n] = __low2float(yb->ds);

    const int* q8_words = reinterpret_cast<const int*>(yb->qs) + word_sel;
    u[2 * n + 0] = q8_words[0];
    u[2 * n + 1] = q8_words[4];
  }

  return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, blk->dm, d8);
}
// Declares the __shared__ tile buffers for q5_K (quants with 2 * WARP_SIZE
// words per row, per-block dm and scales) sized from mmq_y, and returns their
// addresses through x_ql, x_dm and x_sc. x_qh is accepted but never written.
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q5_K(int** x_ql,
                                                           half2** x_dm,
                                                           int** x_qh,
                                                           int** x_sc) {
  constexpr int ql_len = mmq_y * (2 * WARP_SIZE) + mmq_y;
  constexpr int dm_len = mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K;
  constexpr int sc_len = mmq_y * (WARP_SIZE / 8) + mmq_y / 8;

  __shared__ int ql_tile[ql_len];
  __shared__ half2 dm_tile[dm_len];
  __shared__ int sc_tile[sc_len];

  *x_ql = ql_tile;
  *x_dm = dm_tile;
  *x_sc = sc_tile;
}
// Fills the q5_K tile buffers from the block_q5_K array at vx.
//
// Three passes over the tile rows, each with its own row stride:
//   1. x_ql : every source word of 4-bit quants is split into its low and
//             high nibbles, each merged with one bit taken from qh (placed at
//             bit 4 of every byte), and stored as two words.
//   2. x_dm : the block's dm value, copied as-is.
//   3. x_sc : scale words rebuilt as (4 low bits | 2 high bits) per byte.
// When need_check is set, every computed row index is clamped to i_max.
// x_qh is not touched.
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q5_K(
    const void* __restrict__ vx, int* __restrict__ x_ql,
    half2* __restrict__ x_dm, int* __restrict__ x_qh, int* __restrict__ x_sc,
    const int& i_offset, const int& i_max, const int& k,
    const int& blocks_per_row) {
  const block_q5_K* blocks = static_cast<const block_q5_K*>(vx);

  // ---- pass 1: quants ----
  const int blk_in_row = k / QI5_K;   // == 0 if QK_K == 256
  const int word_in_blk = k % QI5_K;  // == k if QK_K == 256

  // Destination columns and qh selection depend on k only.
  const int ky = QR5_K * word_in_blk;
  const int kq0 = ky - ky % (QI5_K / 2) + k % (QI5_K / 4) + 0;
  const int kq1 = ky - ky % (QI5_K / 2) + k % (QI5_K / 4) + (QI5_K / 4);
  const int qh_word = word_in_blk % (QI5_K / 4);
  const int qh_shift = 2 * (word_in_blk / (QI5_K / 4));

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
    const int raw = i0 + i_offset;
    const int row = need_check ? min(raw, i_max) : raw;

    const block_q5_K* src = blocks + row * blocks_per_row + blk_in_row;

    const int nibbles = get_int_from_uint8_aligned(src->qs, word_in_blk);
    const int low_nib = (nibbles >> 0) & 0x0F0F0F0F;
    const int high_nib = (nibbles >> 4) & 0x0F0F0F0F;

    const int fifth = get_int_from_uint8_aligned(src->qh, qh_word);
    const int fifth0 = ((fifth >> (qh_shift + 0)) << 4) & 0x10101010;
    const int fifth1 = ((fifth >> (qh_shift + 1)) << 4) & 0x10101010;

    x_ql[row * (2 * WARP_SIZE + 1) + kq0] = low_nib | fifth0;
    x_ql[row * (2 * WARP_SIZE + 1) + kq1] = high_nib | fifth1;
  }

  // ---- pass 2: per-block dm ----
  const int tile_blocks_per_row = WARP_SIZE / QI5_K;  // == 1 if QK_K == 256
  const int blk_d = k % tile_blocks_per_row;          // == 0 if QK_K == 256

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
    const int raw = (i0 + i_offset * QI5_K + k / tile_blocks_per_row) % mmq_y;
    const int row = need_check ? min(raw, i_max) : raw;

    const block_q5_K* src = blocks + row * blocks_per_row + blk_d;

    x_dm[row * (WARP_SIZE / QI5_K) + row / QI5_K + blk_d] = src->dm;
  }

  // ---- pass 3: scales ----
  const int ksc = k % (WARP_SIZE / 8);
  const int low_word = (ksc % 2) + (ksc != 0);
  const int low_shift = 4 * (ksc & (ksc / 2));
  const int high_word = ksc / 2;
  const int high_shift = 2 * (ksc % 2);

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
    const int raw = (i0 + i_offset * 8 + k / (WARP_SIZE / 8)) % mmq_y;
    const int row = need_check ? min(raw, i_max) : raw;

    const block_q5_K* src = blocks + row * blocks_per_row + ksc / (QI5_K / 8);
    const int* packed = reinterpret_cast<const int*>(src->scales);

    // Resulting arrangement per the original note: sc0..sc3, sc4..sc7,
    // m0..m3, m4..m7 (the original comment ends in "m8", presumably a typo).
    const int low4 = (packed[low_word] >> low_shift) & 0x0F0F0F0F;     // lower 4 bits
    const int high2 = (packed[high_word] >> high_shift) & 0x30303030;  // upper 2 bits

    x_sc[row * (WARP_SIZE / 8) + row / 8 + ksc] = low4 | high2;
  }
}
// Tile dot product for q5_K x q8_1.
//
// Computes the addresses of the operands for tile row `i`, column `j` and
// position `k`, then delegates to vec_dot_q5_K_q8_1_impl_mmq. The second byte
// pointer handed over is the scale pointer advanced by 8 bytes. x_qh is not
// read.
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm,
    const int* __restrict__ x_qh, const int* __restrict__ x_sc,
    const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
  // Scale word for this (i, k), refined to a byte offset of 0 or 2.
  const int sc_word = i * (WARP_SIZE / 8) + i / 8 + k / 16;
  const int sc_byte = 2 * ((k % 16) / 8);
  const uint8_t* sc =
      reinterpret_cast<const uint8_t*>(&x_sc[sc_word]) + sc_byte;
  const uint8_t* m_bytes = sc + 8;

  // The x tile holds QR5_K words per k, the y tile wraps at WARP_SIZE.
  const int x_idx = i * (QR5_K * WARP_SIZE + 1) + QR5_K * k;
  const int y_idx = j * WARP_SIZE + (QR5_K * k) % WARP_SIZE;
  const int dm_idx = i * (WARP_SIZE / QI5_K) + i / QI5_K;

  return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[x_idx], &y_qs[y_idx], sc, m_bytes,
                                    x_dm[dm_idx], &y_ds[y_idx / QI8_1]);
}
// Vector dot product of a slice of one q6_K block with q8_1 blocks.
//
// vbq   : pointer to a block_q6_K (passed type-erased).
// bq8_1 : q8_1 blocks; QR6_K blocks are read, two apart, starting at an
//         iqs-derived offset.
// iqs   : selects which packed word of the q6_K block is processed.
// The low-bit word, the shifted high-bit word, the q8 operands and the
// scale pointer are passed to vec_dot_q6_K_q8_1_impl_mmvq.
static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
  const block_q6_K* blk = static_cast<const block_q6_K*>(vbq);

  // Position of iqs: which half of the block, and where inside that half.
  const int half_idx = iqs / (QI6_K / 2);
  const int in_half = iqs % (QI6_K / 2);

  const int q8_first = 2 * QR6_K * half_idx + in_half / (QI6_K / 4);
  const int scale_offset = (QI6_K / 4) * half_idx + in_half / (QI6_K / 8);
  const int high_shift = 2 * (in_half / (QI6_K / 4));

  const int low_bits = get_int_from_uint8(blk->ql, iqs);
  const int high_word = get_int_from_uint8(
      blk->qh, (QI6_K / 4) * half_idx + iqs % (QI6_K / 4));
  const int high_bits = high_word >> high_shift;

  const int8_t* scales = blk->scales + scale_offset;

  int q8_vals[QR6_K];
  float q8_scales[QR6_K];

#pragma unroll
  for (int n = 0; n < QR6_K; ++n) {
    q8_vals[n] =
        get_int_from_int8_aligned(bq8_1[q8_first + 2 * n].qs, iqs % QI8_1);
    q8_scales[n] = __low2float(bq8_1[q8_first + 2 * n].ds);
  }

  return vec_dot_q6_K_q8_1_impl_mmvq(low_bits, high_bits, q8_vals, scales,
                                     __half2float(blk->d), q8_scales);
}
// Declares the __shared__ tile buffers for q6_K (quants with 2 * WARP_SIZE
// words per row, per-block scale and scales) sized from mmq_y, and returns
// their addresses through x_ql, x_dm and x_sc. x_qh is accepted but never
// written.
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q6_K(int** x_ql,
                                                           half2** x_dm,
                                                           int** x_qh,
                                                           int** x_sc) {
  constexpr int ql_len = mmq_y * (2 * WARP_SIZE) + mmq_y;
  constexpr int dm_len = mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K;
  constexpr int sc_len = mmq_y * (WARP_SIZE / 8) + mmq_y / 8;

  __shared__ int ql_tile[ql_len];
  __shared__ half2 dm_tile[dm_len];
  __shared__ int sc_tile[sc_len];

  *x_ql = ql_tile;
  *x_dm = dm_tile;
  *x_sc = sc_tile;
}
// Fills the q6_K tile buffers from the block_q6_K array at vx.
//
// Three passes over the tile rows, each with its own row stride:
//   1. x_ql : every source word of 4-bit quants is split into its low and
//             high nibbles, each merged with two bits taken from qh (placed
//             at bits 4-5 of every byte) and reduced by 0x20 per byte with
//             __vsubss4, then stored as two words.
//   2. x_dm : per-block scale d, stored through a float view of x_dm.
//   3. x_sc : one word of scales per entry, read with get_int_from_int8.
// When need_check is set, every computed row index is clamped to i_max.
// x_qh is not touched.
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q6_K(
    const void* __restrict__ vx, int* __restrict__ x_ql,
    half2* __restrict__ x_dm, int* __restrict__ x_qh, int* __restrict__ x_sc,
    const int& i_offset, const int& i_max, const int& k,
    const int& blocks_per_row) {
  const block_q6_K* blocks = static_cast<const block_q6_K*>(vx);

  // ---- pass 1: quants ----
  const int blk_in_row = k / QI6_K;   // == 0 if QK_K == 256
  const int word_in_blk = k % QI6_K;  // == k if QK_K == 256

  // Destination columns and qh selection depend on k only.
  const int ky = QR6_K * word_in_blk;
  const int kq0 = ky - ky % QI6_K + k % (QI6_K / 2) + 0;
  const int kq1 = ky - ky % QI6_K + k % (QI6_K / 2) + (QI6_K / 2);
  const int qh_word =
      (QI6_K / 4) * (word_in_blk / (QI6_K / 2)) + word_in_blk % (QI6_K / 4);
  const int qh_shift = 2 * ((word_in_blk % (QI6_K / 2)) / (QI6_K / 4));

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
    const int raw = i0 + i_offset;
    const int row = need_check ? min(raw, i_max) : raw;

    const block_q6_K* src = blocks + row * blocks_per_row + blk_in_row;

    const int nibbles = get_int_from_uint8(src->ql, word_in_blk);
    const int low_nib = (nibbles >> 0) & 0x0F0F0F0F;
    const int high_nib = (nibbles >> 4) & 0x0F0F0F0F;

    const int upper = get_int_from_uint8(src->qh, qh_word) >> qh_shift;
    const int upper0 = (upper << 4) & 0x30303030;
    const int upper1 = upper & 0x30303030;

    x_ql[row * (2 * WARP_SIZE + 1) + kq0] =
        __vsubss4(low_nib | upper0, 0x20202020);
    x_ql[row * (2 * WARP_SIZE + 1) + kq1] =
        __vsubss4(high_nib | upper1, 0x20202020);
  }

  // ---- pass 2: per-block scale d ----
  const int tile_blocks_per_row = WARP_SIZE / QI6_K;  // == 1 if QK_K == 256
  const int blk_d = k % tile_blocks_per_row;          // == 0 if QK_K == 256
  float* d_out = reinterpret_cast<float*>(x_dm);

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
    const int raw = (i0 + i_offset * QI6_K + k / tile_blocks_per_row) % mmq_y;
    const int row = need_check ? min(raw, i_max) : raw;

    const block_q6_K* src = blocks + row * blocks_per_row + blk_d;

    d_out[row * (WARP_SIZE / QI6_K) + row / QI6_K + blk_d] =
        __half2float(src->d);
  }

  // ---- pass 3: scales ----
  const int sc_col = k % (WARP_SIZE / 8);

#pragma unroll
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
    const int raw = (i0 + i_offset * 8 + k / (WARP_SIZE / 8)) % mmq_y;
    const int row = need_check ? min(raw, i_max) : raw;

    const block_q6_K* src = blocks + row * blocks_per_row + sc_col / 4;

    x_sc[row * (WARP_SIZE / 8) + row / 8 + sc_col] =
        get_int_from_int8(src->scales, k % (QI6_K / 8));
  }
}
// Tile dot product for q6_K x q8_1.
//
// Computes the addresses of the operands for tile row `i`, column `j` and
// position `k`, then delegates to vec_dot_q6_K_q8_1_impl_mmq. x_dm and y_ds
// are read through float views; x_qh is not read.
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
    const int* __restrict__ x_ql, const half2* __restrict__ x_dm,
    const int* __restrict__ x_qh, const int* __restrict__ x_sc,
    const int* __restrict__ y_qs, const half2* __restrict__ y_ds,
    const int& i, const int& j, const int& k) {
  const float* x_scale = reinterpret_cast<const float*>(x_dm);
  const float* y_scale = reinterpret_cast<const float*>(y_ds);

  // Scale bytes are signed here, unlike in the q4_K/q5_K variants.
  const int sc_word = i * (WARP_SIZE / 8) + i / 8 + k / 8;
  const int8_t* sc = reinterpret_cast<const int8_t*>(&x_sc[sc_word]);

  // The x tile holds QR6_K words per k, the y tile wraps at WARP_SIZE.
  const int x_idx = i * (QR6_K * WARP_SIZE + 1) + QR6_K * k;
  const int y_idx = j * WARP_SIZE + (QR6_K * k) % WARP_SIZE;
  const int d_idx = i * (WARP_SIZE / QI6_K) + i / QI6_K;

  return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[x_idx], &y_qs[y_idx], sc,
                                    x_scale[d_idx], &y_scale[y_idx / QI8_1]);
}
// Vector dot product of one 32-value group of an iq2_xxs block with the
// matching q8_1 block.
//
// iqs doubles as the group index. Each group occupies four uint16 words of
// qs: the first four bytes are grid indices into iq2xxs_grid, the last two
// words form a 32-bit value holding four 7-bit sign selectors (indices into
// ksigns_iq2xs) with the group scale in the top four bits.
static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
  const block_iq2_xxs* blk = static_cast<const block_iq2_xxs*>(vbq);

  const int ib32 = iqs;
  const uint16_t* q2 = blk->qs + 4 * ib32;
  const uint8_t* grid_idx = reinterpret_cast<const uint8_t*>(q2);
  const int8_t* q8 = bq8_1[ib32].qs;

  const uint32_t signs_and_scale = q2[2] | (q2[3] << 16);

  int acc = 0;
  for (int l = 0; l < 4; ++l) {
    const uint8_t* grid =
        reinterpret_cast<const uint8_t*>(iq2xxs_grid + grid_idx[l]);
    const uint8_t signs = ksigns_iq2xs[(signs_and_scale >> (7 * l)) & 127];
    const int8_t* y = q8 + 8 * l;

    for (int j = 0; j < 8; ++j) {
      // a set bit in the sign byte flips the sign of the grid value
      acc += y[j] * grid[j] * ((signs & kmask_iq2xs[j]) ? -1 : 1);
    }
  }

  // What is left after dropping the four 7-bit selectors is the scale.
  const uint32_t scale_bits = signs_and_scale >> 28;

  const float d = __half2float(blk->d) * (0.5f + scale_bits) *
                  __half2float(bq8_1[ib32].ds.x) * 0.25f;
  return d * acc;
}
// Vector dot product of one 32-value group of an iq2_xs block with the
// matching q8_1 block.
//
// iqs doubles as the group index. Each of the group's four uint16 words
// carries a 9-bit index into iq2xs_grid and a 7-bit index into ksigns_iq2xs.
// The first two words are weighted by the low nibble of the group's scale
// byte, the last two by its high nibble.
static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
  const block_iq2_xs* blk = static_cast<const block_iq2_xs*>(vbq);

  const int ib32 = iqs;
  const uint16_t* q2 = blk->qs + 4 * ib32;
  const int8_t* q8 = bq8_1[ib32].qs;

  const uint8_t ls1 = blk->scales[ib32] & 0xf;
  const uint8_t ls2 = blk->scales[ib32] >> 4;

  // half_sums[0] covers words 0-1, half_sums[1] covers words 2-3.
  int half_sums[2] = {0, 0};
  for (int l = 0; l < 4; ++l) {
    const uint8_t* grid =
        reinterpret_cast<const uint8_t*>(iq2xs_grid + (q2[l] & 511));
    const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
    const int8_t* y = q8 + 8 * l;

    for (int j = 0; j < 8; ++j) {
      // a set bit in the sign byte flips the sign of the grid value
      half_sums[l / 2] += y[j] * grid[j] * ((signs & kmask_iq2xs[j]) ? -1 : 1);
    }
  }

  const float d =
      __half2float(blk->d) * __half2float(bq8_1[ib32].ds.x) * 0.25f;
  return d * ((0.5f + ls1) * half_sums[0] + (0.5f + ls2) * half_sums[1]);
}
// Vector dot product of one 32-value group of an iq2_s block with the
// matching q8_1 block.
//
// iqs doubles as the group index. Grid indices are formed from a byte of qs
// plus two bits of qh; sign bytes are read from qs starting at offset
// QK_K / 8. The first two sub-blocks are weighted by the low nibble of the
// group's scale byte, the last two by its high nibble.
//
// The computation is compiled only when __CUDA_ARCH__ >= 610 (it uses
// __dp4a). On every other compilation pass the function now returns 0.0f
// instead of flowing off the end of a value-returning function.
static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
  const block_iq2_s* bq2 = (const block_iq2_s*)vbq;

  const int ib32 = iqs;
  const int8_t* q8 = bq8_1[ib32].qs;
  const uint8_t* signs = bq2->qs + QK_K / 8 + 4 * ib32;
  const uint8_t ls1 = bq2->scales[ib32] & 0xf;
  const uint8_t ls2 = bq2->scales[ib32] >> 4;

  int sumi1 = 0;
  for (int l = 0; l < 2; ++l) {
    const uint32_t* grid = (const uint32_t*)(iq2s_grid + (bq2->qs[4 * ib32 + l] | ((bq2->qh[ib32] << (8 - 2 * l)) & 0x300)));
    // Expand each sign bit into a full 0x00/0xFF byte mask.
    const uint32_t signs0 = __vcmpeq4(((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
    const uint32_t signs1 = __vcmpeq4(((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
    // (x ^ mask) - mask negates the bytes whose mask is 0xFF.
    const int grid_l = __vsub4(grid[0] ^ signs0, signs0);
    const int grid_h = __vsub4(grid[1] ^ signs1, signs1);
    sumi1 = __dp4a(grid_l, *((const int*)q8 + 0), sumi1);
    sumi1 = __dp4a(grid_h, *((const int*)q8 + 1), sumi1);
    q8 += 8;
  }

  int sumi2 = 0;
  for (int l = 2; l < 4; ++l) {
    const uint32_t* grid = (const uint32_t*)(iq2s_grid + (bq2->qs[4 * ib32 + l] | ((bq2->qh[ib32] << (8 - 2 * l)) & 0x300)));
    const uint32_t signs0 = __vcmpeq4(((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
    const uint32_t signs1 = __vcmpeq4(((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
    const int grid_l = __vsub4(grid[0] ^ signs0, signs0);
    const int grid_h = __vsub4(grid[1] ^ signs1, signs1);
    sumi2 = __dp4a(grid_l, *((const int*)q8 + 0), sumi2);
    sumi2 = __dp4a(grid_h, *((const int*)q8 + 1), sumi2);
    q8 += 8;
  }

  const float d = __half2float(bq2->d) * __low2float(bq8_1[ib32].ds) * 0.25f;
  return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
#else
  // Guarded-out path: the kernel above is not compiled here. Return a defined
  // value so that control never falls off the end of the function.
  return 0.0f;
#endif
}
// Vector dot product of one 32-value group of an iq3_xxs block with the
// matching q8_1 block.
//
// iqs doubles as the group index. Eight bytes of qs index iq3xxs_grid; a
// 32-bit value read from qs at byte offset QK_K / 4 holds four 7-bit indices
// into ksigns64 with the group scale in the top four bits.
//
// The computation is compiled only when __CUDA_ARCH__ >= 610 (it uses
// __dp4a). On every other compilation pass the function now returns 0.0f
// instead of flowing off the end of a value-returning function.
static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
  const block_iq3_xxs* bq2 = (const block_iq3_xxs*)vbq;

  const int ib32 = iqs;
  const uint8_t* q3 = bq2->qs + 8 * ib32;
  const uint16_t* gas = (const uint16_t*)(bq2->qs + QK_K / 4) + 2 * ib32;
  const int8_t* q8 = bq8_1[ib32].qs;
  uint32_t aux32 = gas[0] | (gas[1] << 16);

  int sumi = 0;
  for (int l = 0; l < 4; ++l) {
    const uint32_t* grid1 = iq3xxs_grid + q3[2 * l + 0];
    const uint32_t* grid2 = iq3xxs_grid + q3[2 * l + 1];
    const uint32_t* signs = (const uint32_t*)(ksigns64 + (aux32 & 127));
    // (x ^ mask) - mask negates the bytes whose mask is 0xFF.
    const int grid_l = __vsub4(grid1[0] ^ signs[0], signs[0]);
    const int grid_h = __vsub4(grid2[0] ^ signs[1], signs[1]);
    // q8 points to const data; keep the qualifier when reading it as int.
    sumi = __dp4a(grid_l, *((const int*)q8 + 0), sumi);
    sumi = __dp4a(grid_h, *((const int*)q8 + 1), sumi);
    q8 += 8;
    aux32 >>= 7;  // next 7-bit sign index; the scale remains after 4 rounds
  }

  const float d = __half2float(bq2->d) * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.5f;
  return d * sumi;
#else
  // Guarded-out path: the kernel above is not compiled here. Return a defined
  // value so that control never falls off the end of the function.
  return 0.0f;
#endif
}
// Vector dot product of one 32-value group of an iq3_s block with the
// matching q8_1 block.
//
// iqs doubles as the group index. Grid indices into iq3xs_grid are formed
// from a byte of qs plus one bit of qh; sign bytes come from the block's
// signs array; the group scale is a nibble of scales[ib32 / 2].
//
// The computation is compiled only when __CUDA_ARCH__ >= 610 (it uses
// __dp4a). On every other compilation pass the function now returns 0.0f
// instead of flowing off the end of a value-returning function.
static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
  const block_iq3_s* bq2 = (const block_iq3_s*)vbq;

  const int ib32 = iqs;
  const uint8_t* qs = bq2->qs + 8 * ib32;
  const int8_t* q8 = bq8_1[ib32].qs;

  int sumi = 0;
  for (int l = 0; l < 4; ++l) {
    const uint32_t* grid1 = iq3xs_grid + (qs[2 * l + 0] | ((bq2->qh[ib32] << (8 - 2 * l)) & 256));
    const uint32_t* grid2 = iq3xs_grid + (qs[2 * l + 1] | ((bq2->qh[ib32] << (7 - 2 * l)) & 256));
    // Expand each sign bit into a full 0x00/0xFF byte mask.
    uint32_t signs0 = __vcmpeq4(((bq2->signs[4 * ib32 + l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
    uint32_t signs1 = __vcmpeq4(((bq2->signs[4 * ib32 + l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
    // (x ^ mask) - mask negates the bytes whose mask is 0xFF.
    const int grid_l = __vsub4(grid1[0] ^ signs0, signs0);
    const int grid_h = __vsub4(grid2[0] ^ signs1, signs1);
    // q8 points to const data; keep the qualifier when reading it as int.
    sumi = __dp4a(grid_l, *((const int*)q8 + 0), sumi);
    sumi = __dp4a(grid_h, *((const int*)q8 + 1), sumi);
    q8 += 8;
  }

  const float d = __half2float(bq2->d) * (0.5f + ((bq2->scales[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf)) * __low2float(bq8_1[ib32].ds) * 0.5f;
  return d * sumi;
#else
  // Guarded-out path: the kernel above is not compiled here. Return a defined
  // value so that control never falls off the end of the function.
  return 0.0f;
#endif
}
// Vector dot product of one group of an iq1_s block with q8_1 block
// bq8_1[iqs].
//
// Four grid entries are looked up in iq1s_grid_gpu, each indexed by a byte of
// qs combined with three bits of qh[iqs]; each entry is split into two words
// of nibbles that are multiplied with q8 words via __dp4a. Bits 12-14 of qh
// give the scale factor and bit 15 selects the sign of the IQ1S_DELTA term.
//
// The computation is compiled only when __CUDA_ARCH__ >= 610 (it uses
// __dp4a). On every other compilation pass the function now returns 0.0f
// instead of flowing off the end of a value-returning function.
static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
  const block_iq1_s* bq1 = (const block_iq1_s*)vbq;

  const int qs_packed = get_int_b2(bq1->qs, iqs);
  const uint8_t* qs = (const uint8_t*)&qs_packed;

  const int qh = bq1->qh[iqs];

  int sumi = 0;
#pragma unroll
  for (int l0 = 0; l0 < 8; l0 += 2) {
    const int grid = iq1s_grid_gpu[qs[l0 / 2] | (((qh >> 3 * (l0 / 2)) & 0x07) << 8)];

    const int grid0 = (grid >> 0) & 0x0F0F0F0F;
    const int grid1 = (grid >> 4) & 0x0F0F0F0F;

    const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
    const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);

    sumi = __dp4a(grid0, u0, sumi);
    sumi = __dp4a(grid1, u1, sumi);
  }

  const float d1q = __half2float(bq1->d) * (((qh >> 11) & 0x0E) + 1);
  const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f * IQ1S_DELTA / 0x8000);
  const float2 ds = __half22float2(bq8_1[iqs].ds);
  return d1q * (ds.x * sumi + ds.y * delta);
#else
  // Guarded-out path: the kernel above is not compiled here. Return a defined
  // value so that control never falls off the end of the function.
  return 0.0f;
#endif
}
// Vector dot product of one group of an iq1_m block with q8_1 block
// bq8_1[iqs].
//
// Four grid entries are looked up in iq1s_grid_gpu, each indexed by a byte of
// qs combined with three bits from a qh nibble; bit 3 of that nibble selects
// the sign of the IQ1M_DELTA term. Integer and delta contributions are
// accumulated separately for the two halves of the group, which carry their
// own 3-bit scales. The block-wide fp16 scale is reassembled from the top
// nibbles of the four 16-bit scale words.
//
// The computation is compiled only when __CUDA_ARCH__ >= 610 (it uses
// __dp4a). On every other compilation pass the function now returns 0.0f
// instead of flowing off the end of a value-returning function.
static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
  const block_iq1_m* bq1 = (const block_iq1_m*)vbq;

  const int qs_packed = get_int_b4(bq1->qs, iqs);
  const uint8_t* qs = (const uint8_t*)&qs_packed;

  int sumi[2] = {0};
  float sumf[2] = {0.0f};
#pragma unroll
  for (int l0 = 0; l0 < 8; l0 += 2) {
    const int qhl = bq1->qh[2 * iqs + l0 / 4] >> (4 * ((l0 / 2) % 2));

    const int grid = iq1s_grid_gpu[qs[l0 / 2] | ((qhl & 0x07) << 8)];

    const int grid0 = (grid >> 0) & 0x0F0F0F0F;
    const int grid1 = (grid >> 4) & 0x0F0F0F0F;

    const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
    const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);

    sumi[l0 / 4] = __dp4a(grid0, u0, sumi[l0 / 4]);
    sumi[l0 / 4] = __dp4a(grid1, u1, sumi[l0 / 4]);

    const float delta = -1.0f + IQ1M_DELTA - (qhl & 0x08) * (2.0f * IQ1M_DELTA / 0x08);
    // Sum of the eight q8 values touched in this step (dot with all-ones).
    int sumy = 0;
    sumy = __dp4a(u0, 0x01010101, sumy);
    sumy = __dp4a(u1, 0x01010101, sumy);
    sumf[l0 / 4] += delta * sumy;
  }

  const uint16_t* sc = (const uint16_t*)bq1->scales;

  iq1m_scale_t scale;
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00F0) | ((sc[2] >> 4) & 0x0F00) | (sc[3] & 0xF000);
  const float d = __half2float(scale.f16) * __low2float(bq8_1[iqs].ds);

  const int tmp = sc[iqs / 2] >> (6 * (iqs % 2));
  const int sc0 = 2 * ((tmp >> 0) & 0x07) + 1;
  const int sc1 = 2 * ((tmp >> 3) & 0x07) + 1;
  return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
#else
  // Guarded-out path: the kernel above is not compiled here. Return a defined
  // value so that control never falls off the end of the function.
  return 0.0f;
#endif
}
// Translates the eight 4-bit fields of q4 through a 16-entry byte table.
//
// q4     : four bytes, each holding two 4-bit table indices.
// values : lookup table; only entries 0..15 are read.
// val1   : receives the table bytes for the low nibbles of the four bytes.
// val2   : receives the table bytes for the high nibbles of the four bytes.
// Bytes of the masked word are addressed through a uint8_t view, so the byte
// order of the outputs follows the in-memory byte order of q4.
static __device__ __forceinline__ void get_int_from_table_16(
    const uint32_t& q4, const uint8_t* values, int& val1, int& val2) {
  uint32_t nibbles;
  const uint8_t* idx = reinterpret_cast<const uint8_t*>(&nibbles);

  // low nibble of every byte
  nibbles = q4 & 0x0f0f0f0f;
  const uint16_t low_pair_a = values[idx[0]] | (values[idx[1]] << 8);
  const uint16_t low_pair_b = values[idx[2]] | (values[idx[3]] << 8);
  val1 = low_pair_a | (low_pair_b << 16);

  // high nibble of every byte
  nibbles = (q4 >> 4) & 0x0f0f0f0f;
  const uint16_t high_pair_a = values[idx[0]] | (values[idx[1]] << 8);
  const uint16_t high_pair_b = values[idx[2]] | (values[idx[3]] << 8);
  val2 = high_pair_a | (high_pair_b << 16);
}
// Vector dot product of part of an iq4_nl block with the q8_1 block at
// bq8_1.
//
// Starting at iqs, VDR_Q4_0_Q8_1_MMVQ 32-bit words of 4-bit indices are
// translated through kvalues_iq4nl with get_int_from_table_16; the low-nibble
// results are multiplied with q8 words l and the high-nibble results with q8
// words l + 4.
//
// The computation is compiled only when __CUDA_ARCH__ >= 610 (it uses
// __dp4a). On every other compilation pass the function now returns 0.0f
// instead of flowing off the end of a value-returning function.
static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
  const block_iq4_nl* bq = (const block_iq4_nl*)vbq;

  const uint16_t* q4 = (const uint16_t*)bq->qs + 2 * iqs;
  const int32_t* q8 = (const int32_t*)bq8_1->qs + iqs;

  const uint8_t* values = (const uint8_t*)kvalues_iq4nl;

  int v1, v2;
  int sumi1 = 0, sumi2 = 0;
  for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
    // The 32-bit index word is assembled from two 16-bit reads.
    const uint32_t aux = q4[2 * l] | (q4[2 * l + 1] << 16);
    get_int_from_table_16(aux, values, v1, v2);
    sumi1 = __dp4a(v1, q8[l + 0], sumi1);
    sumi2 = __dp4a(v2, q8[l + 4], sumi2);
  }

  const float d = __half2float(bq->d) * __low2float(bq8_1->ds);
  return d * (sumi1 + sumi2);
#else
  // Guarded-out path: the kernel above is not compiled here. Return a defined
  // value so that control never falls off the end of the function.
  return 0.0f;
#endif
}
// Vector dot product of one 32-value group of an iq4_xs block with the
// matching q8_1 block.
//
// iqs doubles as the group index (0...7 per the original note). The group
// scale is built from a nibble of scales_l and two bits of scales_h, then
// offset by -32. Four 32-bit words of 4-bit indices are translated through
// kvalues_iq4nl with get_int_from_table_16 and multiplied with the q8 words.
//
// The computation is compiled only when __CUDA_ARCH__ >= 610 (it uses
// __dp4a). On every other compilation pass the function now returns 0.0f
// instead of flowing off the end of a value-returning function.
static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
    const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1,
    const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
  const block_iq4_xs* bq4 = (const block_iq4_xs*)vbq;
  const uint8_t* values = (const uint8_t*)kvalues_iq4nl;

  // iqs is 0...7
  const int ib32 = iqs;
  const int32_t* q8 = (const int*)bq8_1[ib32].qs;
  const uint32_t* q4 = (const uint32_t*)bq4->qs + 4 * ib32;

  // 4 low bits from scales_l, 2 high bits from scales_h.
  const int8_t ls = ((bq4->scales_l[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf) | (((bq4->scales_h >> 2 * ib32) & 3) << 4);
  const float d = __half2float(bq4->d) * (ls - 32) * __low2float(bq8_1[ib32].ds);

  int v1, v2;
  int sumi1 = 0, sumi2 = 0;
  for (int j = 0; j < 4; ++j) {
    get_int_from_table_16(q4[j], values, v1, v2);
    sumi1 = __dp4a(v1, q8[j + 0], sumi1);
    sumi2 = __dp4a(v2, q8[j + 4], sumi2);
  }
  return d * (sumi1 + sumi2);
#else
  // Guarded-out path: the kernel above is not compiled here. Return a defined
  // value so that control never falls off the end of the function.
  return 0.0f;
#endif
}
\ No newline at end of file
Prev
1
…
4
5
6
7
8
9
10
11
12
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment