Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
ktransformers
Commits
80815dbc
Unverified
Commit
80815dbc
authored
Aug 12, 2024
by
UnicornChan
Committed by
GitHub
Aug 12, 2024
Browse files
Merge pull request #30 from BITcyman/feature-lc-dequantize_q2k_q3k_gpu
[feature] support q2_k & q3_k dequantize on gpu
parents
650c368c
7c4cb520
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
160 additions
and
11 deletions
+160
-11
ktransformers/ktransformers_ext/cuda/binding.cpp
ktransformers/ktransformers_ext/cuda/binding.cpp
+5
-1
ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp
ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp
+5
-0
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
+129
-3
ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
+4
-2
ktransformers/util/custom_gguf.py
ktransformers/util/custom_gguf.py
+17
-5
No files found.
ktransformers/ktransformers_ext/cuda/binding.cpp
View file @
80815dbc
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
* @Date : 2024-07-25 13:38:30
* @Date : 2024-07-25 13:38:30
* @Version : 1.0.0
* @Version : 1.0.0
* @LastEditors : kkk1nak0
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-
09 01:4
5:0
2
* @LastEditTime : 2024-08-
12 03:0
5:0
4
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
**/
...
@@ -27,6 +27,10 @@ PYBIND11_MODULE(KTransformersOps, m) {
...
@@ -27,6 +27,10 @@ PYBIND11_MODULE(KTransformersOps, m) {
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
m
.
def
(
"dequantize_q4_k"
,
&
dequantize_q4_k
,
"Function to dequantize q4_k data."
,
m
.
def
(
"dequantize_q4_k"
,
&
dequantize_q4_k
,
"Function to dequantize q4_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
m
.
def
(
"dequantize_q3_k"
,
&
dequantize_q3_k
,
"Function to dequantize q3_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
m
.
def
(
"dequantize_q2_k"
,
&
dequantize_q2_k
,
"Function to dequantize q2_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
m
.
def
(
"gptq_marlin_gemm"
,
&
gptq_marlin_gemm
,
"Function to perform GEMM using Marlin quantization."
,
m
.
def
(
"gptq_marlin_gemm"
,
&
gptq_marlin_gemm
,
"Function to perform GEMM using Marlin quantization."
,
py
::
arg
(
"a"
),
py
::
arg
(
"b_q_weight"
),
py
::
arg
(
"b_scales"
),
py
::
arg
(
"g_idx"
),
py
::
arg
(
"a"
),
py
::
arg
(
"b_q_weight"
),
py
::
arg
(
"b_scales"
),
py
::
arg
(
"g_idx"
),
py
::
arg
(
"perm"
),
py
::
arg
(
"workspace"
),
py
::
arg
(
"num_bits"
),
py
::
arg
(
"size_m"
),
py
::
arg
(
"perm"
),
py
::
arg
(
"workspace"
),
py
::
arg
(
"num_bits"
),
py
::
arg
(
"size_m"
),
...
...
ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp
View file @
80815dbc
...
@@ -13,6 +13,7 @@ int test(){
...
@@ -13,6 +13,7 @@ int test(){
torch
::
Tensor
dequantize_q6_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q6_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q5_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q5_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q2_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
PYBIND11_MODULE
(
cudaops
,
m
)
{
PYBIND11_MODULE
(
cudaops
,
m
)
{
m
.
def
(
"dequantize_q8_0"
,
&
dequantize_q8_0
,
"Function to dequantize q8_0 data."
,
m
.
def
(
"dequantize_q8_0"
,
&
dequantize_q8_0
,
"Function to dequantize q8_0 data."
,
...
@@ -23,6 +24,10 @@ PYBIND11_MODULE(cudaops, m) {
...
@@ -23,6 +24,10 @@ PYBIND11_MODULE(cudaops, m) {
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
m
.
def
(
"dequantize_q4_k"
,
&
dequantize_q4_k
,
"Function to dequantize q4_k data."
,
m
.
def
(
"dequantize_q4_k"
,
&
dequantize_q4_k
,
"Function to dequantize q4_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
m
.
def
(
"dequantize_q3_k"
,
&
dequantize_q3_k
,
"Function to dequantize q3_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
m
.
def
(
"dequantize_q2_k"
,
&
dequantize_q2_k
,
"Function to dequantize q2_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
m
.
def
(
"test"
,
&
test
,
"Function to test."
);
m
.
def
(
"test"
,
&
test
,
"Function to test."
);
}
}
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
View file @
80815dbc
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
* @Date : 2024-07-25 13:38:30
* @Date : 2024-07-25 13:38:30
* @Version : 1.0.0
* @Version : 1.0.0
* @LastEditors : kkk1nak0
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-
09 07:57
:0
6
* @LastEditTime : 2024-08-
12 04:18
:0
4
* Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c
* Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c
* Copyright (c) 2023-2024 The ggml authors
* Copyright (c) 2023-2024 The ggml authors
* Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
* Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
...
@@ -36,6 +36,97 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
...
@@ -36,6 +36,97 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
}
}
}
}
__global__
void
dequantize_q2_k_kernel
(
int8_t
*
data
,
float
*
output
,
int
blk_size
,
int
num_blocks
)
{
int
global_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(
auto
block_id
=
global_idx
;
block_id
<
num_blocks
;
block_id
+=
blockDim
.
x
*
gridDim
.
x
){
float
*
__restrict__
output_blk
=
(
float
*
)(
output
+
block_id
*
256
);
const
float
d
=
__half2float
(
*
(
reinterpret_cast
<
half
*>
(
data
+
block_id
*
blk_size
+
80
)));
const
float
min
=
__half2float
(
*
(
reinterpret_cast
<
half
*>
(
data
+
block_id
*
blk_size
+
82
)));
const
uint8_t
*
__restrict__
q
=
(
uint8_t
*
)(
data
+
block_id
*
blk_size
+
16
);
int
is
=
0
;
float
dl
,
ml
;
for
(
int
n
=
0
;
n
<
256
;
n
+=
128
)
{
int
shift
=
0
;
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
uint8_t
*
scales
=
(
uint8_t
*
)(
data
+
block_id
*
blk_size
+
(
is
++
));
uint8_t
sc
=
*
scales
;
dl
=
d
*
(
sc
&
0xF
);
ml
=
min
*
(
sc
>>
4
);
for
(
int
l
=
0
;
l
<
16
;
++
l
)
*
output_blk
++
=
dl
*
((
int8_t
)((
q
[
l
]
>>
shift
)
&
3
))
-
ml
;
scales
=
(
uint8_t
*
)(
data
+
block_id
*
blk_size
+
(
is
++
));
sc
=
*
scales
;
dl
=
d
*
(
sc
&
0xF
);
ml
=
min
*
(
sc
>>
4
);
for
(
int
l
=
0
;
l
<
16
;
++
l
)
*
output_blk
++
=
dl
*
((
int8_t
)((
q
[
l
+
16
]
>>
shift
)
&
3
))
-
ml
;
shift
+=
2
;
}
q
+=
32
;
}
}
}
__global__
void
dequantize_q3_k_kernel
(
int8_t
*
data
,
float
*
output
,
int
blk_size
,
int
num_blocks
)
{
int
global_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
uint32_t
kmask1
=
0x03030303
;
const
uint32_t
kmask2
=
0x0f0f0f0f
;
for
(
auto
block_id
=
global_idx
;
block_id
<
num_blocks
;
block_id
+=
blockDim
.
x
*
gridDim
.
x
){
float
*
__restrict__
output_blk
=
(
float
*
)(
output
+
block_id
*
256
);
uint32_t
aux
[
4
];
const
int8_t
*
scales
=
(
const
int8_t
*
)
aux
;
const
float
d_all
=
__half2float
(
*
(
reinterpret_cast
<
half
*>
(
data
+
block_id
*
blk_size
+
108
)));
const
uint8_t
*
__restrict__
q
=
(
uint8_t
*
)(
data
+
block_id
*
blk_size
+
32
);
const
uint8_t
*
__restrict__
hm
=
(
uint8_t
*
)(
data
+
block_id
*
blk_size
+
0
);
uint8_t
m
=
1
;
uint8_t
*
block_scales
=
(
uint8_t
*
)(
data
+
block_id
*
blk_size
+
96
);
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
aux
[
i
]
=
0
;
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
aux
[
i
]
|=
((
uint32_t
)
block_scales
[
i
*
4
+
j
])
<<
(
j
*
8
);
}
}
uint32_t
tmp
=
aux
[
2
];
aux
[
2
]
=
((
aux
[
0
]
>>
4
)
&
kmask2
)
|
(((
tmp
>>
4
)
&
kmask1
)
<<
4
);
aux
[
3
]
=
((
aux
[
1
]
>>
4
)
&
kmask2
)
|
(((
tmp
>>
6
)
&
kmask1
)
<<
4
);
aux
[
0
]
=
(
aux
[
0
]
&
kmask2
)
|
(((
tmp
>>
0
)
&
kmask1
)
<<
4
);
aux
[
1
]
=
(
aux
[
1
]
&
kmask2
)
|
(((
tmp
>>
2
)
&
kmask1
)
<<
4
);
int
is
=
0
;
float
dl
;
for
(
int
n
=
0
;
n
<
256
;
n
+=
128
)
{
int
shift
=
0
;
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
dl
=
d_all
*
(
scales
[
is
++
]
-
32
);
for
(
int
l
=
0
;
l
<
16
;
++
l
)
{
*
output_blk
++
=
dl
*
((
int8_t
)((
q
[
l
+
0
]
>>
shift
)
&
3
)
-
((
hm
[
l
+
0
]
&
m
)
?
0
:
4
));
}
dl
=
d_all
*
(
scales
[
is
++
]
-
32
);
for
(
int
l
=
0
;
l
<
16
;
++
l
)
{
*
output_blk
++
=
dl
*
((
int8_t
)((
q
[
l
+
16
]
>>
shift
)
&
3
)
-
((
hm
[
l
+
16
]
&
m
)
?
0
:
4
));
}
shift
+=
2
;
m
<<=
1
;
}
q
+=
32
;
}
}
}
__global__
void
dequantize_q4_k_kernel
(
int8_t
*
data
,
float
*
output
,
int
blk_size
,
int
num_blocks
)
{
__global__
void
dequantize_q4_k_kernel
(
int8_t
*
data
,
float
*
output
,
int
blk_size
,
int
num_blocks
)
{
int
global_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
global_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(
auto
block_id
=
global_idx
;
block_id
<
num_blocks
;
block_id
+=
blockDim
.
x
*
gridDim
.
x
){
for
(
auto
block_id
=
global_idx
;
block_id
<
num_blocks
;
block_id
+=
blockDim
.
x
*
gridDim
.
x
){
...
@@ -176,6 +267,24 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
...
@@ -176,6 +267,24 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
return
output
;
return
output
;
}
}
torch
::
Tensor
dequantize_q5_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
)
{
int
num_blocks
=
data
.
numel
()
/
blk_size
;
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt8
).
device
(
device
).
memory_format
(
torch
::
MemoryFormat
::
Contiguous
);
auto
data_gpu
=
torch
::
empty
({
data
.
numel
()},
options
);
data_gpu
.
copy_
(
data
,
false
);
// Create output tensor
auto
output
=
torch
::
zeros
({
num_blocks
,
256
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
device
));
// Launch kernel
dequantize_q5_k_kernel
<<<
512
,
256
>>>
(
data_gpu
.
data_ptr
<
int8_t
>
(),
output
.
data_ptr
<
float
>
(),
blk_size
,
num_blocks
);
cudaDeviceSynchronize
();
return
output
;
}
torch
::
Tensor
dequantize_q4_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
)
{
torch
::
Tensor
dequantize_q4_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
)
{
// data.numel%blk_size should be 0, else raise err
// data.numel%blk_size should be 0, else raise err
int
num_blocks
=
data
.
numel
()
/
blk_size
;
int
num_blocks
=
data
.
numel
()
/
blk_size
;
...
@@ -196,8 +305,25 @@ torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device de
...
@@ -196,8 +305,25 @@ torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device de
return
output
;
return
output
;
}
}
torch
::
Tensor
dequantize_q3_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
)
{
int
num_blocks
=
data
.
numel
()
/
blk_size
;
torch
::
Tensor
dequantize_q5_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
)
{
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt8
).
device
(
device
).
memory_format
(
torch
::
MemoryFormat
::
Contiguous
);
auto
data_gpu
=
torch
::
empty
({
data
.
numel
()},
options
);
data_gpu
.
copy_
(
data
,
false
);
// Create output tensor
auto
output
=
torch
::
zeros
({
num_blocks
,
256
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
device
));
// Launch kernel
dequantize_q3_k_kernel
<<<
512
,
256
>>>
(
data_gpu
.
data_ptr
<
int8_t
>
(),
output
.
data_ptr
<
float
>
(),
blk_size
,
num_blocks
);
cudaDeviceSynchronize
();
return
output
;
}
torch
::
Tensor
dequantize_q2_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
)
{
int
num_blocks
=
data
.
numel
()
/
blk_size
;
int
num_blocks
=
data
.
numel
()
/
blk_size
;
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt8
).
device
(
device
).
memory_format
(
torch
::
MemoryFormat
::
Contiguous
);
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt8
).
device
(
device
).
memory_format
(
torch
::
MemoryFormat
::
Contiguous
);
...
@@ -209,7 +335,7 @@ torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device de
...
@@ -209,7 +335,7 @@ torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device de
auto
output
=
torch
::
zeros
({
num_blocks
,
256
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
device
));
auto
output
=
torch
::
zeros
({
num_blocks
,
256
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
device
));
// Launch kernel
// Launch kernel
dequantize_q
5
_k_kernel
<<<
512
,
256
>>>
(
data_gpu
.
data_ptr
<
int8_t
>
(),
output
.
data_ptr
<
float
>
(),
blk_size
,
num_blocks
);
dequantize_q
2
_k_kernel
<<<
512
,
256
>>>
(
data_gpu
.
data_ptr
<
int8_t
>
(),
output
.
data_ptr
<
float
>
(),
blk_size
,
num_blocks
);
cudaDeviceSynchronize
();
cudaDeviceSynchronize
();
return
output
;
return
output
;
...
...
ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
View file @
80815dbc
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
* @Date : 2024-07-22 09:27:55
* @Date : 2024-07-22 09:27:55
* @Version : 1.0.0
* @Version : 1.0.0
* @LastEditors : kkk1nak0
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-
09
0
1
:4
4:21
* @LastEditTime : 2024-08-
12
0
3
:4
8:46
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
**/
#pragma once
#pragma once
...
@@ -17,3 +17,5 @@ torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device de
...
@@ -17,3 +17,5 @@ torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device de
torch
::
Tensor
dequantize_q6_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q6_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q5_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q5_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q4_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q4_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q3_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q2_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
\ No newline at end of file
ktransformers/util/custom_gguf.py
View file @
80815dbc
...
@@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang, chenht2022
...
@@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang, chenht2022
Date : 2024-07-26 08:48:54
Date : 2024-07-26 08:48:54
Version : 1.0.0
Version : 1.0.0
LastEditors : kkk1nak0
LastEditors : kkk1nak0
LastEditTime : 2024-08-
09 08:03:44
LastEditTime : 2024-08-
12 07:21:55
Adapted from https://github.com/99991/pygguf/blob/main/gguf.py
Adapted from https://github.com/99991/pygguf/blob/main/gguf.py
Copyright (c) 2023-2024 The ggml authors
Copyright (c) 2023-2024 The ggml authors
Copyright (c) 2024 Thomas Germer
Copyright (c) 2024 Thomas Germer
...
@@ -390,8 +390,14 @@ def dequantize_q2_k(data):
...
@@ -390,8 +390,14 @@ def dequantize_q2_k(data):
return
d
*
(
scales
&
15
)
*
(
tmp
&
3
)
-
dmin
*
(
scales
>>
4
)
return
d
*
(
scales
&
15
)
*
(
tmp
&
3
)
-
dmin
*
(
scales
>>
4
)
def
dequantize_q2_k_gpu
(
data
):
def
dequantize_q2_k_gpu
(
data
,
device
:
str
=
"cuda"
):
raise
NotImplementedError
()
block_size
=
GGML_BLOCK_SIZES
[
"Q2_K"
]
data
=
np
.
frombuffer
(
data
,
dtype
=
data
.
dtype
)
device
=
torch
.
device
(
device
)
# TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
# the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
data
=
torch
.
from_numpy
(
data
)
return
KTransformersOps
.
dequantize_q2_k
(
data
,
block_size
,
device
)
def
dequantize_q3_k
(
data
):
def
dequantize_q3_k
(
data
):
# C implementation
# C implementation
...
@@ -435,8 +441,14 @@ def dequantize_q3_k(data):
...
@@ -435,8 +441,14 @@ def dequantize_q3_k(data):
(((
qs
[:,
48
:
64
]
>>
6
)
&
3
)
-
bits
[:,
16
:,
7
])
(((
qs
[:,
48
:
64
]
>>
6
)
&
3
)
-
bits
[:,
16
:,
7
])
],
axis
=
1
)
],
axis
=
1
)
def
dequantize_q3_k_gpu
(
data
):
def
dequantize_q3_k_gpu
(
data
,
device
:
str
=
"cuda"
):
raise
NotImplementedError
()
block_size
=
GGML_BLOCK_SIZES
[
"Q3_K"
]
data
=
np
.
frombuffer
(
data
,
dtype
=
data
.
dtype
)
device
=
torch
.
device
(
device
)
# TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
# the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
data
=
torch
.
from_numpy
(
data
)
return
KTransformersOps
.
dequantize_q3_k
(
data
,
block_size
,
device
)
def
dequantize_q4_k
(
data
):
def
dequantize_q4_k
(
data
):
# C implementation
# C implementation
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment