OpenDAS / ktransformers — commit 7e1fe256

optimize GPU

Authored Feb 21, 2025 by Atream
Parent: cf4da5fd

Showing 8 changed files with 686 additions and 165 deletions.
ktransformers/ktransformers_ext/cuda/binding.cpp                +7    -7
ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp    +11   -11
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu     +599  -100
ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h          +7    -7
ktransformers/local_chat.py                                     +1    -4
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml     +12   -0
ktransformers/server/backend/interfaces/ktransformers.py        +1    -1
ktransformers/util/custom_gguf.py                               +48   -35
ktransformers/ktransformers_ext/cuda/binding.cpp

@@ -20,19 +20,19 @@
 PYBIND11_MODULE(KTransformersOps, m) {
     m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
           py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
           py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
 ...
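Every dequantize binding now takes the raw GGUF byte count and the requested output dtype in addition to the data, block size, and device. As a rough illustration of what a caller of the KTransformersOps module has to supply after this change, here is a minimal Python sketch. The helper name, the ctypes-style pointer hand-off, and the 34-byte Q8_0 block size are illustrative assumptions, not code from this commit; how the raw data pointer is actually marshalled from Python is not shown in this hunk.

import numpy as np
import torch
import KTransformersOps  # extension module defined in binding.cpp above

def dequantize_q8_0_gpu(raw: bytes, device="cuda:0", target_dtype=torch.float16):
    # Q8_0: 34-byte blocks (fp16 scale + 32 int8 quants) -- assumed layout for this sketch.
    blk_size = 34
    buf = np.frombuffer(raw, dtype=np.int8)
    data_ptr = buf.ctypes.data  # hypothetical raw-pointer hand-off
    return KTransformersOps.dequantize_q8_0(
        data=data_ptr,
        num_bytes=buf.size,         # new argument in this commit
        blk_size=blk_size,
        device=torch.device(device),
        target_dtype=target_dtype,  # new argument in this commit
    )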
ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp

@@ -17,19 +17,19 @@ torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device de
 PYBIND11_MODULE(cudaops, m) {
     m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("test", &test, "Function to test.");
 }
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu

@@ -10,19 +10,53 @@
  * Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
  */
 #include <cuda_runtime.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
 #include <torch/library.h>
 #include <torch/extension.h>
 #include <torch/torch.h>
 #include <cstdint>
 #include <c10/cuda/CUDAGuard.h>
 
-__global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) {
-    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
-        for(int i = 0; i < blk_size; i++){
-            float scale = scales[block_id];
-            output[block_id * blk_size + i] = scale * qs[block_id * blk_size + i];
-        }
-    }
-}
+__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) {
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+        float* __restrict__ output_blk = (float*)(output + block_id * 256);
+        const int8_t* cur_block = data + block_id * blk_size;
+        float scale = __half2float(*((half*)cur_block));
+        cur_block += 2;
+        for (int i = 0; i < 32; i++){
+            output_blk[i] = scale * cur_block[i];
+        }
+        output_blk += 32;
+    }
+}
+
+__global__ void dequantize_q8_0_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) {
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
+        __half* __restrict__ output_blk = (__half*)(output + block_id * 256);
+        const int8_t* cur_block = data + block_id * blk_size;
+        float scale = __half2float(*((half*)cur_block));
+        cur_block += 2;
+        for (int i = 0; i < 32; i++) {
+            output_blk[i] = __float2half(scale * cur_block[i]);
+        }
+        output_blk += 32;
+    }
+}
+
+__global__ void dequantize_q8_0_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) {
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
+        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * 256);
+        const int8_t* cur_block = data + block_id * blk_size;
+        float scale = __half2float(*((half*)cur_block));
+        cur_block += 2;
+        for (int i = 0; i < 32; i++) {
+            output_blk[i] = __float2bfloat16(scale * cur_block[i]);
+        }
+        output_blk += 32;
+    }
+}
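For reference, the Q8_0 layout these kernels read is a 34-byte block: a little-endian fp16 scale followed by 32 signed int8 quants, each value recovered as scale * q[i]. The snippet below is a minimal NumPy sketch of that same arithmetic on the CPU; it is an illustrative reference, not part of this commit, and the 34-byte block size and 32-value output width are taken from the kernel bodies above.

import numpy as np

def dequantize_q8_0_ref(raw: bytes) -> np.ndarray:
    """CPU reference for GGUF Q8_0: each 34-byte block is an fp16 scale + 32 int8 quants."""
    blk_size = 34  # 2-byte scale + 32 quants (assumed, matches the kernels above)
    data = np.frombuffer(raw, dtype=np.uint8)
    num_blocks = data.size // blk_size
    blocks = data.reshape(num_blocks, blk_size)
    scales = blocks[:, :2].copy().view(np.float16).astype(np.float32)  # (num_blocks, 1)
    quants = blocks[:, 2:].copy().view(np.int8).astype(np.float32)     # (num_blocks, 32)
    return scales * quants                                             # broadcasts to (num_blocks, 32)

# Example: one block with scale 0.5 and quants 0..31
blk = np.float16(0.5).tobytes() + np.arange(32, dtype=np.int8).tobytes()
print(dequantize_q8_0_ref(blk)[0, :4])  # -> [0.  0.5 1.  1.5]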
@@ -36,13 +70,13 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
     }
 }
 
-__global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+__global__ void dequantize_q2_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) {
     long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
-        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 80)));
-        const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 82)));
+        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));
+        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));
         const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);
 ...
@@ -70,7 +104,75 @@ __global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size
     }
 }
 
-__global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+__global__ void dequantize_q2_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) {
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+        __half* __restrict__ output_blk = (__half*)(output + block_id * 256);
+        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));
+        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));
+        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);
+
+        int is = 0;
+        float dl, ml;
+        for (int n = 0; n < 256; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+                uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));
+                uint8_t sc = *scales;
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l)
+                    *output_blk++ = __float2half(dl * ((int8_t)((q[l] >> shift) & 3)) - ml);
+
+                scales = (uint8_t*)(data + block_id * blk_size + (is++));
+                sc = *scales;
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l)
+                    *output_blk++ = __float2half(dl * ((int8_t)((q[l + 16] >> shift) & 3)) - ml);
+
+                shift += 2;
+            }
+            q += 32;
+        }
+    }
+}
+
+ ... [dequantize_q2_k_bf16_kernel is also added here; its body matches the fp16 variant with __float2bfloat16() in place of __float2half()] ...
+
+__global__ void dequantize_q3_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) {
     long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     const uint32_t kmask1 = 0x03030303;
 ...
@@ -80,7 +182,7 @@ __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size
         uint32_t aux[4];
         const int8_t * scales = (const int8_t*)aux;
-        const float d_all = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 108)));
+        const float d_all = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 108)));
         const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32);
         const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);
 ...
@@ -126,16 +228,128 @@ __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size
     }
 }
 
+__global__ void dequantize_q3_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) {
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+        __half* __restrict__ output_blk = (__half*)(output + block_id * 256);
+        uint32_t aux[4];
+        const int8_t * scales = (const int8_t*)aux;
+        const float d_all = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 108)));
+        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32);
+        const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);
+        uint8_t m = 1;
+        uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);
+
+        for (int i = 0; i < 3; i++) {
+            aux[i] = 0;
+            for (int j = 0; j < 4; j++) {
+                aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);
+            }
+        }
+        uint32_t tmp = aux[2];
+        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+
+        int is = 0;
+        float dl;
+        for (int n = 0; n < 256; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *output_blk++ = __float2half(dl * ((int8_t)((q[l +  0] >> shift) & 3) - ((hm[l +  0] & m) ? 0 : 4)));
+                }
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *output_blk++ = __float2half(dl * ((int8_t)((q[l + 16] >> shift) & 3) - ((hm[l + 16] & m) ? 0 : 4)));
+                }
+                shift += 2;
+                m <<= 1;
+            }
+            q += 32;
+        }
+    }
+}
+
+ ... [dequantize_q3_k_bf16_kernel is also added, identical to the fp16 variant but writing with __float2bfloat16()] ...
-__global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+__global__ void dequantize_q4_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) {
     long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
         // const uint8_t * q = data[i].qs;
         const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);
-        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * 144 + 0)));
-        const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * 144 + 2)));
+        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 0)));
+        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2)));
         int is = 0;
         uint8_t sc, m;
         for (int j = 0; j < blk_size; j += 64) {
 ...
@@ -151,13 +365,61 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size
     }
 }
 
-__global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+__global__ void dequantize_q4_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) {
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+        __half* __restrict__ output_blk = (__half*)(output + block_id * 256);
+        // const uint8_t * q = data[i].qs;
+        const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);
+        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 0)));
+        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2)));
+        int is = 0;
+        uint8_t sc, m;
+        for (int j = 0; j < blk_size; j += 64) {
+            uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4);
+            get_scale_min_k4(is + 0, scales, &sc, &m);
+            const float d1 = d * sc; const float m1 = min * m;
+            get_scale_min_k4(is + 1, scales, &sc, &m);
+            const float d2 = d * sc; const float m2 = min * m;
+            for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d1 * (q[l] & 0xF) - m1);
+            for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d2 * (q[l] >> 4) - m2);
+            q += 32;
+            is += 2;
+        }
+    }
+}
+
+ ... [dequantize_q4_k_bf16_kernel is also added, identical to the fp16 variant but writing with __float2bfloat16()] ...
+
+__global__ void dequantize_q5_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) {
     long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
-        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 0)));
-        const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 2)));
+        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));
+        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));
         const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);
         const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);
 ...
@@ -180,11 +442,69 @@ __global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size
     }
 }
 
-__global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+__global__ void dequantize_q5_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) {
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+        __half* __restrict__ output_blk = (__half*)(output + block_id * 256);
+        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));
+        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));
+        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);
+        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);
+        int is = 0;
+        uint8_t sc, m;
+        uint8_t u1 = 1, u2 = 2;
+        uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);
+        for (int j = 0; j < 256; j += 64) {
+            get_scale_min_k4(is + 0, scales, &sc, &m);
+            const float d1 = d * sc; const float m1 = min * m;
+            get_scale_min_k4(is + 1, scales, &sc, &m);
+            const float d2 = d * sc; const float m2 = min * m;
+            for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1);
+            for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2);
+            ql += 32;
+            is += 2;
+            u1 <<= 2;
+            u2 <<= 2;
+        }
+    }
+}
+
+ ... [dequantize_q5_k_bf16_kernel is also added, identical to the fp16 variant but writing with __float2bfloat16()] ...
+
+__global__ void dequantize_q6_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) {
     long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
-        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 208)));
+        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208)));
         const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);
         const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128);
 ...
@@ -212,14 +532,78 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
     }
 }
 
+__global__ void dequantize_q6_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) {
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+        __half* __restrict__ output_blk = (__half*)(output + block_id * 256);
+        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208)));
+        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);
+        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128);
+        const int8_t  * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192);
+        //if (blk_size == 256){
+        for (int n = 0; n < blk_size; n += 128) {
+            for (int l = 0; l < 32; ++l) {
+                int is = l/16;
+                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                const int8_t q3 = (int8_t)((ql[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                const int8_t q4 = (int8_t)((ql[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+                output_blk[l +  0] = __float2half(d * sc[is + 0] * q1);
+                output_blk[l + 32] = __float2half(d * sc[is + 2] * q2);
+                output_blk[l + 64] = __float2half(d * sc[is + 4] * q3);
+                output_blk[l + 96] = __float2half(d * sc[is + 6] * q4);
+            }
+            output_blk += 128;
+            ql += 64;
+            qh += 32;
+            sc += 8;
+        }
+    }
+}
+
+ ... [dequantize_q6_k_bf16_kernel is also added, identical to the fp16 variant but writing with __float2bfloat16()] ...
+
 static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
-__global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+__global__ void dequantize_iq4_xs_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) {
     long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
-        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size)));
-        const uint16_t scales_h = *(reinterpret_cast<uint16_t*>(data + block_id * blk_size + 2));
+        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));
+        const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));
         const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);
         const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);
 ...
@@ -236,152 +620,267 @@ __global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_si
     }
 }
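IQ4_XS is not a linear 4-bit code: each 4-bit index selects an entry from the kvalues_iq4nl table above and is then scaled by a per-sub-block factor dl; the kernels write the low-nibble values to outputs 0-15 and the high-nibble values to outputs 16-31 of each 32-wide group. A tiny NumPy illustration of that lookup step follows; the table is copied from the diff, while the packed byte and the scale value are made up for the example.

import numpy as np

# Non-linear 4-bit codebook used by the IQ4 kernels above.
KVALUES_IQ4NL = np.array(
    [-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113],
    dtype=np.int8,
)

def decode_iq4_pair(qs_byte: int, dl: float) -> tuple[float, float]:
    """Decode the two 4-bit indices packed in one byte (low nibble, then high nibble)."""
    lo = KVALUES_IQ4NL[qs_byte & 0xF]
    hi = KVALUES_IQ4NL[qs_byte >> 4]
    return dl * float(lo), dl * float(hi)

print(decode_iq4_pair(0x80, dl=0.01))  # -> (-1.27, 0.01): index 0 -> -127, index 8 -> 1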
+__global__ void dequantize_iq4_xs_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) {
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
+        __half* __restrict__ output_blk = (__half*)(output + block_id * 256);
+        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));
+        const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));
+        const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);
+        const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);
+        for (int ib = 0; ib < 8; ++ib) {
+            const int ls = ((scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((scales_h >> 2*ib) & 3) << 4);
+            const float dl = d * (ls - 32);
+            for (int j = 0; j < 16; ++j) {
+                output_blk[j +  0] = __float2half(dl * kvalues_iq4nl[qs[j] & 0xf]);
+                output_blk[j + 16] = __float2half(dl * kvalues_iq4nl[qs[j] >>  4]);
+            }
+            output_blk += 32;
+            qs += 16;
+        }
+    }
+}
+
+ ... [dequantize_iq4_xs_bf16_kernel is also added, identical to the fp16 variant but writing with __float2bfloat16()] ...
+
-torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device) {
-    int num_blocks = data.numel() / blk_size;
-    const at::cuda::OptionalCUDAGuard device_guard(device);
-    // create gpu
-    auto options_scales = torch::TensorOptions().dtype(torch::kFloat32).device(device).memory_format(torch::MemoryFormat::Contiguous);
-    auto options_qs = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
-    auto scales_gpu = torch::empty({{num_blocks, 1}}, options_scales);
-    auto qs_gpu = torch::empty({num_blocks, 32}, options_qs);
-
-    // read on cpu
-    options_scales = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCPU);
-    options_qs = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
-
-    // // reinterpret
-    auto scales = torch::from_blob(data.data_ptr(), {num_blocks, 1 + 16}, options_scales).slice(1, 0, 1);
-    auto qs = torch::from_blob(data.data_ptr(), {num_blocks, 2 + 32}, options_qs).slice(1, 2);
-
-    auto scales_f32 = scales.to(torch::kFloat32);
-    scales_gpu.copy_(scales_f32, false);
-    qs_gpu.copy_(qs, false);
-
-    // Create output tensor
-    auto output = torch::zeros_like(qs, torch::dtype(torch::kFloat32).device(device));
-
-    // Launch kernel
-    dequantize_q8_0_kernel<<< 512, 256 >>>(
-        output.data_ptr<float>(), scales_gpu.data_ptr<float>(), qs_gpu.data_ptr<int8_t>(), num_blocks, 32);
-
-    cudaDeviceSynchronize();
-    return output;
-}
+torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) {
+    int num_blocks = num_bytes / blk_size;
+    const at::cuda::OptionalCUDAGuard device_guard(device);
+
+    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
+    auto data_gpu = torch::empty({num_bytes}, options);
+
+    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);
+    //data_gpu.copy_(data, false);
+
+    // Create output tensor
+    auto output = torch::zeros({num_blocks, 32}, torch::dtype(target_dtype).device(device));
+
+    switch (target_dtype) {
+        case torch::kFloat16:
+            dequantize_q8_0_fp16_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, num_blocks);
+            break;
+        case torch::kBFloat16:
+            dequantize_q8_0_bf16_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<nv_bfloat16>(), blk_size, num_blocks);
+            break;
+        case torch::kFloat32:
+            dequantize_q8_0_fp32_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+            break;
+        default:
+            printf("target type not support\n");
+            exit(0);
+    }
+    cudaDeviceSynchronize();
+    return output;
+}
 
-torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device) {
+torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) {
     // data.numel%blk_size should be 0, else raise err
-    int num_blocks = data.numel() / blk_size;
+    int num_blocks = num_bytes / blk_size;
     const at::cuda::OptionalCUDAGuard device_guard(device);
 
     auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
-    auto data_gpu = torch::empty({data.numel()}, options);
+    auto data_gpu = torch::empty({num_bytes}, options);
 
-    data_gpu.copy_(data, false);
+    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);
+    //data_gpu.copy_(data, false);
 
     // Create output tensor
-    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
+    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
 
-    // Launch kernel
-    dequantize_q6_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+    switch (target_dtype) {
+        case torch::kFloat16:
+            // dequantize_q6_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), 256, num_blocks);
+            dequantize_q6_k_fp16_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, num_blocks);
+            break;
+        case torch::kBFloat16:
+            dequantize_q6_k_bf16_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<nv_bfloat16>(), blk_size, num_blocks);
+            break;
+        case torch::kFloat32:
+            dequantize_q6_k_fp32_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+            break;
+        default:
+            printf("target type not support\n");
+            exit(0);
+    }
     cudaDeviceSynchronize();
     return output;
 }
 
 ... [dequantize_q5_k, dequantize_q4_k, dequantize_q3_k, dequantize_q2_k, and dequantize_iq4_xs receive the
 ...  same treatment in this hunk: the (const int8_t* data, const int num_bytes, const int blk_size,
 ...  const torch::Device device, const torch::ScalarType target_dtype) signature, a cudaMemcpy host-to-device
 ...  upload into data_gpu, an output tensor allocated as zeros({num_blocks, 256}) with torch::dtype(target_dtype),
 ...  and a switch (target_dtype) dispatching to the matching *_fp16 / *_bf16 / *_fp32 kernel, with
 ...  printf("target type not support\n"); exit(0); as the default case] ...
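Every host wrapper now derives the block count from the raw byte count (num_blocks = num_bytes / blk_size) and allocates the output directly in the requested dtype, so the GGUF buffer is uploaded once and dequantized straight into fp16, bf16, or fp32. A quick back-of-the-envelope check of those shapes, in Python for convenience; the 256-weight super-block and the 210-byte Q6_K block are the values implied by the kernels above and should be treated as assumptions of this sketch.

def dequantized_shape(num_bytes: int, blk_size: int, ele_per_blk: int = 256):
    """Mirror the allocation done by the C++ wrappers: zeros({num_blocks, ele_per_blk})."""
    assert num_bytes % blk_size == 0, "data size must be a whole number of GGUF blocks"
    num_blocks = num_bytes // blk_size
    return (num_blocks, ele_per_blk)

# e.g. a Q6_K tensor: 210-byte blocks, 256 weights per block (assumed values)
print(dequantized_shape(210 * 4, 210))                  # -> (4, 256)
# Q8_0 is the exception: 34-byte blocks, 32 weights per block
print(dequantized_shape(34 * 8, 34, ele_per_blk=32))    # -> (8, 32)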
ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
View file @
7e1fe256
...
@@ -13,10 +13,10 @@
...
@@ -13,10 +13,10 @@
#include <torch/extension.h>
#include <torch/extension.h>
#include <torch/torch.h>
#include <torch/torch.h>
torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
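The header now passes the buffer length explicitly, so block counts are derived from bytes rather than from a tensor's element count. A minimal sketch of that shape arithmetic, assuming the usual 256 weights per K-quant block; the helper and the example numbers are illustrative, not part of the commit.

    def kquant_output_shape(num_bytes: int, blk_size: int, elements_per_block: int = 256):
        # Each blk_size-byte block expands to elements_per_block dequantized weights.
        assert num_bytes % blk_size == 0, "buffer must hold a whole number of blocks"
        num_blocks = num_bytes // blk_size
        return (num_blocks, elements_per_block)

    # e.g. a Q4_K buffer of 144-byte blocks: 1,474,560 bytes -> (10240, 256)
    print(kquant_output_shape(1_474_560, 144))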
ktransformers/local_chat.py
View file @ 7e1fe256
...
@@ -168,9 +168,6 @@ def local_chat(
    if mode == 'long_context':
        assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
        "please change max_seq_len in ~/.ktransformers/config.yaml"
    torch.set_default_dtype(torch.bfloat16)
    # TODO: Remove this, replace dtype using config
    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled:
        generated = prefill_and_generate(
...
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
View file @ 7e1fe256
...
@@ -5,6 +5,18 @@
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
...
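The layer-matching pattern above is easy to misread, so here is a quick check of what it selects: every torch.nn.Linear under model.layers except the self_attn.kv_b_proj projections. The module names in the list are illustrative examples, not taken from the commit.

    import re

    pattern = re.compile(r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$")

    for name in [
        "model.layers.0.mlp.gate_proj",        # matches -> replaced by KTransformersLinear
        "model.layers.3.self_attn.q_a_proj",   # matches -> replaced
        "model.layers.3.self_attn.kv_b_proj",  # excluded by the negative lookahead
    ]:
        print(name, bool(pattern.match(name)))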
ktransformers/server/backend/interfaces/ktransformers.py
View file @ 7e1fe256
...
@@ -25,10 +25,10 @@ class KTransformersThreadContext(TransformersThreadContext):
class KTransformersInterface(TransformersInterface):
    def __init__(self, args: ConfigArgs = default_args):
        self.args = args
        torch.set_default_dtype(torch.bfloat16)
        torch.set_grad_enabled(False)
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
        torch.set_default_dtype(config.torch_dtype)
        if config.architectures[0] == "Qwen2MoeForCausalLM":
            config._attn_implementation = "flash_attention_2"
...
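The interface now takes its working dtype from the checkpoint's config instead of hard-coding bfloat16, and the loader changes below default their target_dtype to torch.get_default_dtype(), so the two changes are meant to compose. A small illustration of that interaction; the config value is a placeholder, not read from a real checkpoint.

    import torch

    config_torch_dtype = torch.float16      # stand-in for config.torch_dtype
    torch.set_default_dtype(config_torch_dtype)
    print(torch.get_default_dtype())        # torch.float16

    # The loader functions also accept target_dtype explicitly, so a caller can
    # always override the ambient default, e.g.:
    # loader.load_gguf_tensor(name, device="cuda", target_dtype=torch.bfloat16)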
ktransformers/util/custom_gguf.py
View file @ 7e1fe256
...
@@ -285,7 +285,7 @@ class GGUFLoader:
        itemsize = int(np.empty([], dtype=item_type).itemsize)
        return mmap_data[offset : offset + itemsize * item_count]

    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "gpu") -> torch.Tensor:
    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "cuda", target_dtype = torch.get_default_dtype()) -> torch.Tensor:
        t = self.tensor_info[name]
        if device.lower() == "cpu":
            print(f"loading expert {expert_id} of {name} with CPU")
...
@@ -304,7 +304,7 @@ class GGUFLoader:
        data = data[offset: offset + block_size * blocks_per_experts]

        if "cuda" in device.lower():
            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype)
        else:
            values = GGML_DEQUANTIZE[ggml_name](data)
            values = torch.from_numpy(values)
...
@@ -313,7 +313,7 @@ class GGUFLoader:
        return values

    def load_gguf_tensor(self, name: str, device: str = "cpu") -> torch.Tensor:
    def load_gguf_tensor(self, name: str, device: str = "cpu", target_dtype = torch.get_default_dtype()) -> torch.Tensor:
        t = self.tensor_info[name]
        if device.lower() == "cpu":
            print(f"loading {name} with CPU")
...
@@ -328,16 +328,36 @@ class GGUFLoader:
        data = self.get_mmap_tensor(name)

        block_size = GGML_BLOCK_SIZES[ggml_name]
        elements_per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name]
        num_elements = int(np.prod(shape))
        num_blocks = num_elements // elements_per_block

        blocks_per_iter = 16384
        if num_blocks > blocks_per_iter: # dequant large tensor
            values = torch.empty((num_blocks, elements_per_block), dtype=torch.float, device=device)
            for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter):
                blocks_begin = i * blocks_per_iter
                blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)
                if "cuda" in device.lower():
                    cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)
                else:
                    cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])
                    cur_values = torch.from_numpy(cur_values)
                cur_values = cur_values.view(-1, elements_per_block)
                values[blocks_begin : blocks_end] = cur_values
        else:
            if "cuda" in device.lower():
                values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
                #values = GGML_DEQUANTIZE[ggml_name](data)
                #print("load_gguf_tensor")
                #values = torch.from_numpy(values).to(device = device)
            else:
                values = GGML_DEQUANTIZE[ggml_name](data)
                values = torch.from_numpy(values)

        if ggml_name == "BF16":
            values = values.view(torch.bfloat16)
        values = values.view(shape[::-1])

        if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
            n_head = self.gguf_file_meta['llama.attention.head_count']
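To make the chunked path concrete, a worked example for a Q4_K tensor (144-byte blocks, 256 elements per block); the matrix shape is an arbitrary assumption, not a value from the commit.

    elements_per_block = 256      # K-quant block width
    block_size = 144              # bytes per Q4_K block
    blocks_per_iter = 16384

    num_elements = 7168 * 16384                               # 117,440,512 weights
    num_blocks = num_elements // elements_per_block           # 458,752 blocks
    iterations = -(-num_blocks // blocks_per_iter)            # ceil division -> 28 passes
    bytes_per_iter = blocks_per_iter * block_size             # 2,359,296 bytes staged per pass
    print(num_blocks, iterations, bytes_per_iter)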
...
@@ -433,14 +453,13 @@ def dequantize_q2_k(data):
    return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)

def dequantize_q2_k_gpu(data, device:str ="cuda"):
def dequantize_q2_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
    block_size = GGML_BLOCK_SIZES["Q2_K"]
    data = np.frombuffer(data, dtype=data.dtype)
    device = torch.device(device)
    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
    # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
    data = torch.from_numpy(data)
    return KTransformersOps.dequantize_q2_k(data, block_size, device)
    return KTransformersOps.dequantize_q2_k(data.data, data.size, block_size, device, target_dtype)

def dequantize_q3_k(data):
    # C implementation
...
@@ -484,14 +503,13 @@ def dequantize_q3_k(data):
    (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
    ], axis=1)

def dequantize_q3_k_gpu(data, device:str ="cuda"):
def dequantize_q3_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
    block_size = GGML_BLOCK_SIZES["Q3_K"]
    data = np.frombuffer(data, dtype=data.dtype)
    device = torch.device(device)
    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
    # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
    data = torch.from_numpy(data)
    return KTransformersOps.dequantize_q3_k(data, block_size, device)
    return KTransformersOps.dequantize_q3_k(data.data, data.size, block_size, device, target_dtype)

def dequantize_q4_k(data):
    # C implementation
...
@@ -515,13 +533,12 @@ def dequantize_q4_k(data):
    # Dequantize final weights using scales and offsets
    return factors * qs2 - offsets

def dequantize_q4_k_gpu(data, device:str ="cuda"):
def dequantize_q4_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
    data = np.frombuffer(data, dtype=data.dtype)
    device = torch.device(device)
    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
    # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
    data = torch.from_numpy(data)
    return KTransformersOps.dequantize_q4_k(data, 144, device)
    return KTransformersOps.dequantize_q4_k(data.data, data.size, 144, device, target_dtype)

def dequantize_q5_k(data):
    # C implementation
...
@@ -579,14 +596,13 @@ def dequantize_q5_k(data):
    d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
    ], axis=1)

def dequantize_q5_k_gpu(data, device:str ="cuda"):
def dequantize_q5_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
    block_size = GGML_BLOCK_SIZES["Q5_K"]
    data = np.frombuffer(data, dtype=data.dtype)
    device = torch.device(device)
    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
    # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
    data = torch.from_numpy(data)
    return KTransformersOps.dequantize_q5_k(data, block_size, device)
    return KTransformersOps.dequantize_q5_k(data.data, data.size, block_size, device, target_dtype)

def dequantize_q6_k(data):
    # C implementation
...
@@ -637,13 +653,12 @@ def dequantize_q6_k(data):
    ], axis=1)

# @torch.jit.script
def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda"):
def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()):
    block_size = GGML_BLOCK_SIZES["Q6_K"]
    device = torch.device(device)
    num_blocks = len(data) // block_size
    data = np.frombuffer(data, dtype=data.dtype)
    data = torch.from_numpy(data)
    return KTransformersOps.dequantize_q6_k(data, block_size, device)
    return KTransformersOps.dequantize_q6_k(data.data, data.size, block_size, device, target_dtype)
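All of these GPU wrappers now share one shape: view the mmap'd bytes as a NumPy array, then pass the underlying buffer, its byte count, the block size, the device, and the target dtype to the extension. A hypothetical consolidation of that pattern; the helper, the dispatch dict, and the non-Q4_K block sizes are assumptions, not code from the commit.

    import numpy as np
    import torch
    import KTransformersOps

    _GPU_OPS = {  # (extension op, bytes per block) -- block sizes illustrative
        "Q2_K": (KTransformersOps.dequantize_q2_k, 84),
        "Q3_K": (KTransformersOps.dequantize_q3_k, 110),
        "Q4_K": (KTransformersOps.dequantize_q4_k, 144),
        "Q5_K": (KTransformersOps.dequantize_q5_k, 176),
        "Q6_K": (KTransformersOps.dequantize_q6_k, 210),
    }

    def dequantize_kquant_gpu(raw, ggml_name, device="cuda", target_dtype=torch.get_default_dtype()):
        op, block_size = _GPU_OPS[ggml_name]
        data = np.frombuffer(raw, dtype=np.int8)
        # Same calling convention as the per-type wrappers above.
        return op(data.data, data.size, block_size, torch.device(device), target_dtype)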
kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)
...
@@ -677,13 +692,12 @@ def dequantize_iq4_xs(data):
    return y.flatten()

def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda"):
def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()):
    block_size = GGML_BLOCK_SIZES["IQ4_XS"]
    device = torch.device(device)
    num_blocks = len(data) // block_size
    data = np.frombuffer(data, dtype=data.dtype)
    data = torch.from_numpy(data)
    return KTransformersOps.dequantize_iq4_xs(data, block_size, device)
    return KTransformersOps.dequantize_iq4_xs(data.data, data.size, block_size, device, target_dtype)

def dequantize_q4_0(data):
    # C implementation
...
@@ -700,7 +714,7 @@ def dequantize_q4_0(data):
    scales * ((qs >> 4).astype(np.int8) - 8),
    ], axis=1)

def dequantize_q4_0_gpu(data):
def dequantize_q4_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
    raise NotImplementedError()

def dequantize_q5_0(data):
...
@@ -724,7 +738,7 @@ def dequantize_q5_0(data):
    scales * x1,
    ], axis=1)

def dequantize_q5_0_gpu(data):
def dequantize_q5_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
    raise NotImplementedError()

def dequantize_q8_0(data):
...
@@ -736,20 +750,19 @@ def dequantize_q8_0(data):
    qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:]
    return scales * qs

def dequantize_q8_0_gpu(data, device:str = "cuda"):
def dequantize_q8_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
    # C struct definition
    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43
    num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"]
    device = torch.device(device)
    data = np.frombuffer(data, dtype=data.dtype)
    data = torch.from_numpy(data)
    return KTransformersOps.dequantize_q8_0(data, 34, device)
    return KTransformersOps.dequantize_q8_0(data.data, data.size, 34, device, target_dtype)
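For reference, the Q8_0 constants used above decompose as a 2-byte fp16 scale followed by 32 signed 8-bit quants per 34-byte block, so the byte count alone fixes the output size. A quick check; the buffer length is an arbitrary example.

    block_bytes = 34                                   # 2-byte scale + 32 int8 quants
    elements_per_block = 32

    buffer_bytes = 3_400_000                           # arbitrary example
    num_blocks = buffer_bytes // block_bytes           # 100,000
    num_elements = num_blocks * elements_per_block     # 3,200,000 dequantized weights
    fp16_output_bytes = num_elements * 2               # 6,400,000 bytes for a float16 target
    print(num_blocks, num_elements, fp16_output_bytes)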
def dequantize_f32(data):
    return np.frombuffer(data, dtype=np.float32)

def dequantize_f32_gpu(data, device):
def dequantize_f32_gpu(data, device, target_dtype = torch.get_default_dtype()):
    data = np.frombuffer(data, dtype=np.float32)
    res = torch.from_numpy(data)
    res_gpu = torch.empty_like(res, device=device)
...
@@ -759,7 +772,7 @@ def dequantize_f32_gpu(data, device):
def dequantize_f16(data):
    return np.frombuffer(data, dtype=np.float16)

def dequantize_f16_gpu(data, device):
def dequantize_f16_gpu(data, device, target_dtype = torch.get_default_dtype()):
    data = np.frombuffer(data, dtype=np.float16)
    res = torch.from_numpy(data)
    res_gpu = torch.empty_like(res, device=device)
...