ox696c / ktransformers · Commits

Commit 7c4cb520, authored Aug 12, 2024 by BITcyman
parent 650c368c

    [feature] support q2_k & q3_k dequantize on gpu

Showing 5 changed files with 160 additions and 11 deletions (+160 -11)
Files changed:
  ktransformers/ktransformers_ext/cuda/binding.cpp               +5   -1
  ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp   +5   -0
  ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu    +129 -3
  ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h         +4   -2
  ktransformers/util/custom_gguf.py                              +17  -5

ktransformers/ktransformers_ext/cuda/binding.cpp

@@ -4,7 +4,7 @@
  * @Date : 2024-07-25 13:38:30
  * @Version : 1.0.0
  * @LastEditors : kkk1nak0
- * @LastEditTime : 2024-08-09 01:45:02
+ * @LastEditTime : 2024-08-12 03:05:04
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
@@ -27,6 +27,10 @@ PYBIND11_MODULE(KTransformersOps, m) {
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
+          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
+          py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
           py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
           py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
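
As a quick smoke test of the newly exposed bindings, the sketch below calls them directly from Python. It is only a sketch: it assumes the KTransformersOps extension built from this commit is importable and a CUDA device is present, and it feeds random bytes with the ggml block sizes (84 bytes per Q2_K super-block, 110 per Q3_K) purely to check output shapes, so the numerical results are meaningless.

# Hedged smoke test for the new bindings; assumes the KTransformersOps extension
# from this commit is importable and a CUDA device is available. Block sizes
# follow the ggml layouts: 84 bytes (Q2_K) and 110 bytes (Q3_K) per 256-element
# super-block.
import torch
import KTransformersOps

num_blocks = 8
q2k_bytes = torch.randint(0, 256, (num_blocks * 84,), dtype=torch.uint8)   # stand-in for real GGUF data
q3k_bytes = torch.randint(0, 256, (num_blocks * 110,), dtype=torch.uint8)

out2 = KTransformersOps.dequantize_q2_k(q2k_bytes, 84, torch.device("cuda:0"))
out3 = KTransformersOps.dequantize_q3_k(q3k_bytes, 110, torch.device("cuda:0"))
assert out2.shape == (num_blocks, 256) and out3.shape == (num_blocks, 256)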

ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp

@@ -13,6 +13,7 @@ int test(){
 torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
 PYBIND11_MODULE(cudaops, m) {
     m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
@@ -23,6 +24,10 @@ PYBIND11_MODULE(cudaops, m) {
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
+          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
+          py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("test", &test, "Function to test.");
 }

ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu

@@ -4,7 +4,7 @@
  * @Date : 2024-07-25 13:38:30
  * @Version : 1.0.0
  * @LastEditors : kkk1nak0
- * @LastEditTime : 2024-08-09 07:57:06
+ * @LastEditTime : 2024-08-12 04:18:04
  * Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c
  * Copyright (c) 2023-2024 The ggml authors
  * Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
@@ -36,6 +36,97 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
     }
 }
 
+__global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
+        float* __restrict__ output_blk = (float*)(output + block_id * 256);
+        const float d   = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 80)));
+        const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 82)));
+        const uint8_t* __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);
+
+        int is = 0;
+        float dl, ml;
+
+        for (int n = 0; n < 256; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+                uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));
+                uint8_t sc = *scales;
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
+
+                scales = (uint8_t*)(data + block_id * blk_size + (is++));
+                sc = *scales;
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l + 16] >> shift) & 3)) - ml;
+
+                shift += 2;
+            }
+            q += 32;
+        }
+    }
+}
+
+__global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
+        float* __restrict__ output_blk = (float*)(output + block_id * 256);
+
+        uint32_t aux[4];
+        const int8_t* scales = (const int8_t*)aux;
+
+        const float d_all = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 108)));
+
+        const uint8_t* __restrict__ q  = (uint8_t*)(data + block_id * blk_size + 32);
+        const uint8_t* __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);
+        uint8_t m = 1;
+
+        uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);
+
+        for (int i = 0; i < 3; i++) {
+            aux[i] = 0;
+            for (int j = 0; j < 4; j++) {
+                aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);
+            }
+        }
+
+        uint32_t tmp = aux[2];
+        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+
+        int is = 0;
+        float dl;
+        for (int n = 0; n < 256; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *output_blk++ = dl * ((int8_t)((q[l + 0] >> shift) & 3) - ((hm[l + 0] & m) ? 0 : 4));
+                }
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *output_blk++ = dl * ((int8_t)((q[l + 16] >> shift) & 3) - ((hm[l + 16] & m) ? 0 : 4));
+                }
+
+                shift += 2;
+                m <<= 1;
+            }
+            q += 32;
+        }
+    }
+}
 
 __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
     int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
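
A note on the hard-coded offsets in the two new kernels: they follow the ggml super-block layouts for QK_K = 256. A Q2_K block is 84 bytes (scales[16] at offset 0, qs[64] at 16, fp16 d at 80 and fp16 dmin at 82), and a Q3_K block is 110 bytes (hmask[32] at 0, qs[64] at 32, the packed 6-bit scales[12] at 96 and fp16 d at 108). The NumPy sketch below only slices one raw block at those assumed offsets to make the mapping explicit; it is illustrative and not part of the commit.

# Illustrative only: slice one raw super-block at the offsets the kernels assume.
import numpy as np

def split_q2_k_block(buf: np.ndarray):
    # buf: 84 uint8 bytes of a single Q2_K super-block
    scales = buf[0:16]                       # one 4-bit scale and 4-bit min per 16-value group
    qs     = buf[16:80]                      # 64 bytes -> 256 two-bit quants
    d      = buf[80:82].view(np.float16)[0]  # super-block scale (read at offset 80 in the kernel)
    dmin   = buf[82:84].view(np.float16)[0]  # super-block min (offset 82)
    return scales, qs, d, dmin

def split_q3_k_block(buf: np.ndarray):
    # buf: 110 uint8 bytes of a single Q3_K super-block
    hmask  = buf[0:32]                         # high bit of each 3-bit quant (offset 0)
    qs     = buf[32:96]                        # low 2 bits of each quant (offset 32)
    scales = buf[96:108]                       # packed 6-bit scales (offset 96)
    d      = buf[108:110].view(np.float16)[0]  # super-block scale (offset 108)
    return hmask, qs, scales, d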
@@ -176,6 +267,24 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
     return output;
 }
 
+torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) {
+    int num_blocks = data.numel() / blk_size;
+    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
+    auto data_gpu = torch::empty({data.numel()}, options);
+    data_gpu.copy_(data, false);
+    // Create output tensor
+    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
+    // Launch kernel
+    dequantize_q5_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+    cudaDeviceSynchronize();
+    return output;
+}
+
 torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device) {
     // data.numel%blk_size should be 0, else raise err
     int num_blocks = data.numel() / blk_size;
@@ -196,8 +305,25 @@ torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device de
     return output;
 }
 
+torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device) {
+    int num_blocks = data.numel() / blk_size;
-torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) {
     auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
     auto data_gpu = torch::empty({data.numel()}, options);
     data_gpu.copy_(data, false);
     // Create output tensor
     auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
+    // Launch kernel
+    dequantize_q3_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+    cudaDeviceSynchronize();
+    return output;
+}
+
+torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device) {
+    int num_blocks = data.numel() / blk_size;
+    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
@@ -209,7 +335,7 @@ torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device de
     auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
     // Launch kernel
-    dequantize_q5_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+    dequantize_q2_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
     cudaDeviceSynchronize();
     return output;
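
Each of the host wrappers above follows the same pattern: copy the raw bytes into an int8 tensor on the target device, allocate a (num_blocks, 256) float32 output, launch the kernel with a fixed <<< 512, 256 >>> grid-stride configuration, and synchronize, so data.numel() is expected to be an exact multiple of blk_size. One way to sanity-check the new GPU path is to compare it against the existing NumPy dequantizers in ktransformers/util/custom_gguf.py; the sketch below assumes those functions produce num_blocks * 256 float values for the same input and is only illustrative.

# Hedged sanity check: compare the new CUDA path against the existing NumPy
# dequantizer. Assumes `raw` is a uint8 NumPy view over whole Q2_K blocks
# (84 bytes each), as produced by the GGUF loader.
import numpy as np
from ktransformers.util.custom_gguf import dequantize_q2_k, dequantize_q2_k_gpu

def check_q2_k(raw: np.ndarray, atol: float = 1e-3):
    num_blocks = raw.nbytes // 84
    cpu = np.asarray(dequantize_q2_k(raw), dtype=np.float32).reshape(num_blocks, 256)
    gpu = dequantize_q2_k_gpu(raw, device="cuda").cpu().numpy().reshape(num_blocks, 256)
    np.testing.assert_allclose(cpu, gpu, atol=atol)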

ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h

@@ -4,7 +4,7 @@
  * @Date : 2024-07-22 09:27:55
  * @Version : 1.0.0
  * @LastEditors : kkk1nak0
- * @LastEditTime : 2024-08-09 01:44:21
+ * @LastEditTime : 2024-08-12 03:48:46
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
 #pragma once
@@ -16,4 +16,6 @@
 torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
\ No newline at end of file
+torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
\ No newline at end of file

ktransformers/util/custom_gguf.py

@@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang, chenht2022
 Date : 2024-07-26 08:48:54
 Version : 1.0.0
 LastEditors : kkk1nak0
-LastEditTime : 2024-08-09 08:03:44
+LastEditTime : 2024-08-12 07:21:55
 Adapted from https://github.com/99991/pygguf/blob/main/gguf.py
 Copyright (c) 2023-2024 The ggml authors
 Copyright (c) 2024 Thomas Germer
@@ -390,8 +390,14 @@ def dequantize_q2_k(data):
     return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
 
-def dequantize_q2_k_gpu(data):
-    raise NotImplementedError()
+def dequantize_q2_k_gpu(data, device:str = "cuda"):
+    block_size = GGML_BLOCK_SIZES["Q2_K"]
+    data = np.frombuffer(data, dtype=data.dtype)
+    device = torch.device(device)
+    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
+    # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
+    data = torch.from_numpy(data)
+    return KTransformersOps.dequantize_q2_k(data, block_size, device)
 
 def dequantize_q3_k(data):
     # C implementation
@@ -435,8 +441,14 @@ def dequantize_q3_k(data):
         (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
     ], axis=1)
 
-def dequantize_q3_k_gpu(data):
-    raise NotImplementedError()
+def dequantize_q3_k_gpu(data, device:str = "cuda"):
+    block_size = GGML_BLOCK_SIZES["Q3_K"]
+    data = np.frombuffer(data, dtype=data.dtype)
+    device = torch.device(device)
+    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
+    # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
+    data = torch.from_numpy(data)
+    return KTransformersOps.dequantize_q3_k(data, block_size, device)
 
 def dequantize_q4_k(data):
     # C implementation
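
The TODO repeated in both new functions refers to the UserWarning that torch.from_numpy emits because np.frombuffer returns a read-only view. The commit keeps the zero-copy view and tolerates the warning, suggesting that a raw pointer be passed to KTransformersOps in the future; the sketch below only illustrates two generic workarounds and is not what the commit does.

# Illustrative workarounds for the "non-writable NumPy array" warning noted in
# the TODO above; neither is part of this commit.
import warnings
import numpy as np
import torch

def to_tensor_with_copy(data):
    # Option 1: copy the buffer so the resulting array is writable (extra host copy).
    return torch.from_numpy(np.frombuffer(data, dtype=np.uint8).copy())

def to_tensor_quiet(data):
    # Option 2: keep the zero-copy view and silence the warning; acceptable here
    # because the dequantize kernels only read from the buffer.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        return torch.from_numpy(np.frombuffer(data, dtype=np.uint8))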