OpenDAS / ktransformers

Commit be356c1b authored Sep 02, 2024 by Yap Sok Ann

Support IQ4_XS dequantize

parent 022b8938
Showing 5 changed files with 93 additions and 2 deletions:

ktransformers/ktransformers_ext/cuda/binding.cpp (+2, -0)
ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp (+2, -0)
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu (+42, -1)
ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h (+2, -1)
ktransformers/util/custom_gguf.py (+45, -0)
ktransformers/ktransformers_ext/cuda/binding.cpp

@@ -31,6 +31,8 @@ PYBIND11_MODULE(KTransformersOps, m) {
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
+          py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
           py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
           py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
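For context, a minimal sketch of calling the newly exported binding from Python, assuming the KTransformersOps extension has been built; the file name is a hypothetical stand-in for a buffer of whole IQ4_XS blocks read out of a GGUF file:

    import numpy as np
    import torch
    import KTransformersOps

    blk_size = 2 + 2 + 256 // 2 + 256 // 64     # 136 bytes per 256-element IQ4_XS block
    raw = open("iq4_xs_blocks.bin", "rb").read()  # hypothetical dump of whole blocks
    data = torch.from_numpy(np.frombuffer(raw, dtype=np.int8).copy())  # copy: frombuffer is read-only
    out = KTransformersOps.dequantize_iq4_xs(data, blk_size, torch.device("cuda:0"))
    print(out.shape, out.dtype)  # torch.Size([num_blocks, 256]) torch.float32

Passing a torch.device mirrors the dequantize_iq4_xs_gpu wrapper added in custom_gguf.py below; the binding copies the CPU buffer to the target device internally.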
ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp

@@ -28,6 +28,8 @@ PYBIND11_MODULE(cudaops, m) {
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
           py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
+          py::arg("data"), py::arg("blk_size"), py::arg("device"));
     m.def("test", &test, "Function to test.");
 }
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu

@@ -212,6 +212,29 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
     }
 }
+
+static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+__global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
+    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
+        float* __restrict__ output_blk = (float*)(output + block_id * 256);
+        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size)));
+        const uint16_t scales_h = *(reinterpret_cast<uint16_t*>(data + block_id * blk_size + 2));
+        const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);
+        const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);
+        for (int ib = 0; ib < 8; ++ib) {
+            const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4);
+            const float dl = d * (ls - 32);
+            for (int j = 0; j < 16; ++j) {
+                output_blk[j +  0] = dl * kvalues_iq4nl[qs[j] & 0xf];
+                output_blk[j + 16] = dl * kvalues_iq4nl[qs[j] >>  4];
+            }
+            output_blk += 32;
+            qs += 16;
+        }
+    }
+}
+
 torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device) {
     int num_blocks = data.numel() / blk_size;

@@ -339,4 +362,22 @@ torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device de
     cudaDeviceSynchronize();
     return output;
-}
\ No newline at end of file
+}
+
+torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device) {
+    int num_blocks = data.numel() / blk_size;
+    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
+    auto data_gpu = torch::empty({data.numel()}, options);
+
+    data_gpu.copy_(data, false);
+
+    // Create output tensor
+    auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
+
+    // Launch kernel
+    dequantize_iq4_xs_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
+
+    cudaDeviceSynchronize();
+    return output;
+}
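The kernel walks blocks with a grid-stride loop; within each 256-element block it decodes eight 32-element sub-blocks, reassembling each 6-bit sub-block scale from four low bits in scales_l and two high bits in scales_h, then subtracting a bias of 32. A host-side sketch of that bit arithmetic, with hypothetical packed values:

    # Same decode as the kernel's inner loop, in plain Python.
    scales_l = bytes([0xA3, 0x51, 0x7C, 0x28])  # hypothetical: 8 packed nibbles, one per sub-block
    scales_h = 0x9C4B                           # hypothetical: 8 packed 2-bit fields

    for ib in range(8):
        ls = ((scales_l[ib // 2] >> 4 * (ib % 2)) & 0xF) | (((scales_h >> 2 * ib) & 3) << 4)
        print(f"sub-block {ib}: effective scale = d * {ls - 32}")  # signed range [-32, 31]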
ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h

@@ -18,4 +18,5 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
 torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
 torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
\ No newline at end of file
+torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device);
ktransformers/util/custom_gguf.py

@@ -108,6 +108,7 @@ GGML_TYPES = {
     "Q4_K": 12,
     "Q5_K": 13,
     "Q6_K": 14,
+    "IQ4_XS": 23,
 }

 GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}

@@ -123,6 +124,7 @@ GGML_BLOCK_SIZES = {
     "Q4_K": 2 + 2 + 12 + 256 // 2,
     "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
     "Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2,
+    "IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64,
 }

 GGML_ELEMENTS_PER_BLOCK = {

@@ -136,6 +138,7 @@ GGML_ELEMENTS_PER_BLOCK = {
     "Q4_K": 256,
     "Q5_K": 256,
     "Q6_K": 256,
+    "IQ4_XS": 256,
 }

 DATA_TYPES = {

@@ -601,6 +604,46 @@ def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda"):
     data = torch.from_numpy(data)
     return KTransformersOps.dequantize_q6_k(data, block_size, device)

+kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)
+
+def dequantize_iq4_xs(data):
+    # C implementation
+    # https://github.com/ggerganov/ggml/blob/21d3a308fcb7f31cb9beceaeebad4fb622f3c337/src/ggml-quants.c#L3568
+    # C struct definition
+    # https://github.com/ggerganov/ggml/blob/21d3a308fcb7f31cb9beceaeebad4fb622f3c337/src/ggml-common.h#L393
+    block_size = GGML_BLOCK_SIZES["IQ4_XS"]
+    num_blocks = len(data) // block_size
+
+    d = np.frombuffer(data, dtype=np.float16)[0::block_size//2].astype(np.float32).reshape(num_blocks, 1)
+    scales_h = np.frombuffer(data, dtype=np.uint16)[1::block_size//2].reshape(num_blocks, 1)
+    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)[:, 4:]
+    scales_l = data_u8[:, :4].reshape(num_blocks, 4)
+    qs = data_u8[:, 4:].reshape(num_blocks, block_size - 8)
+
+    ls = np.zeros((num_blocks, QK_K // 32), dtype=np.int8)
+    for ib in range(QK_K // 32):
+        ls[:, ib] = ((scales_l[:, ib // 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h[:, 0] >> 2 * ib) & 3) << 4)
+    dl = (d * (ls - 32)).reshape(num_blocks, -1, 1)
+
+    qs_lo_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) & 0xf
+    qs_hi_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) >> 4
+
+    y = np.zeros((num_blocks, QK_K), dtype=np.float32)
+    for ib in range(QK_K // 32):
+        y[:, ib * 32:(ib * 32) + 16] = dl[:, ib] * kvalues_iq4nl[qs_lo_4[:, ib]]
+        y[:, (ib * 32) + 16:(ib * 32) + 32] = dl[:, ib] * kvalues_iq4nl[qs_hi_4[:, ib]]
+
+    return y.flatten()
+
+def dequantize_iq4_xs_gpu(data: np.ndarray, device: str = "cuda"):
+    block_size = GGML_BLOCK_SIZES["IQ4_XS"]
+    device = torch.device(device)
+    num_blocks = len(data) // block_size
+    data = np.frombuffer(data, dtype=data.dtype)
+    data = torch.from_numpy(data)
+    return KTransformersOps.dequantize_iq4_xs(data, block_size, device)
+
 def dequantize_q4_0(data):
     # C implementation
     # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1515

@@ -693,6 +736,7 @@ GGML_DEQUANTIZE = {
     "Q4_K": dequantize_q4_k,
     "Q5_K": dequantize_q5_k,
     "Q6_K": dequantize_q6_k,
+    "IQ4_XS": dequantize_iq4_xs,
 }

 GGML_DEQUANTIZE_GPU = {

@@ -706,6 +750,7 @@ GGML_DEQUANTIZE_GPU = {
     "Q4_K": dequantize_q4_k_gpu,
     "Q5_K": dequantize_q5_k_gpu,
     "Q6_K": dequantize_q6_k_gpu,
+    "IQ4_XS": dequantize_iq4_xs_gpu,
 }
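Taken together, the table entries spell out the IQ4_XS layout: 2 bytes of fp16 super-scale d, 2 bytes of scales_h, 4 bytes of scales_l (256 // 64), and 128 bytes of 4-bit quant indices (256 // 2), so 136 bytes per 256 weights, i.e. 4.25 bits per weight. A round-trip sketch over one synthetic block with hypothetical values, using the NumPy path added above; it assumes the ktransformers package (and its compiled extension, imported at module load) is importable and that QK_K is 256:

    import numpy as np
    from ktransformers.util.custom_gguf import dequantize_iq4_xs

    block = bytearray(136)
    block[0:2] = np.float16(0.25).tobytes()      # d: fp16 super-block scale
    block[2:4] = (0xAAAA).to_bytes(2, "little")  # scales_h: every 2-bit field = 0b10
    block[4:8] = bytes([0x88] * 4)               # scales_l: every nibble = 8, so ls = 40
    block[8:136] = bytes([0x73] * 128)           # qs: low nibble 3, high nibble 7

    y = dequantize_iq4_xs(bytes(block))
    print(y.shape)       # (256,)
    print(y[0], y[16])   # -130.0 = 0.25*(40-32)*kvalues_iq4nl[3]; -20.0 = 0.25*8*kvalues_iq4nl[7]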