Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ox696c
ktransformers
Commits
7e1fe256
Commit
7e1fe256
authored
Feb 21, 2025
by
Atream
Browse files
optimize GPU
parent
cf4da5fd
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
686 additions
and
165 deletions
+686
-165
ktransformers/ktransformers_ext/cuda/binding.cpp
ktransformers/ktransformers_ext/cuda/binding.cpp
+7
-7
ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp
ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp
+11
-11
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
+599
-100
ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
+7
-7
ktransformers/local_chat.py
ktransformers/local_chat.py
+1
-4
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
+12
-0
ktransformers/server/backend/interfaces/ktransformers.py
ktransformers/server/backend/interfaces/ktransformers.py
+1
-1
ktransformers/util/custom_gguf.py
ktransformers/util/custom_gguf.py
+48
-35
No files found.
ktransformers/ktransformers_ext/cuda/binding.cpp
View file @
7e1fe256
...
...
@@ -20,19 +20,19 @@
PYBIND11_MODULE
(
KTransformersOps
,
m
)
{
m
.
def
(
"dequantize_q8_0"
,
&
dequantize_q8_0
,
"Function to dequantize q8_0 data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
py
::
arg
(
"data"
),
py
::
arg
(
"num_bytes"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
)
,
py
::
arg
(
"target_dtype"
)
);
m
.
def
(
"dequantize_q6_k"
,
&
dequantize_q6_k
,
"Function to dequantize q6_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
py
::
arg
(
"data"
),
py
::
arg
(
"num_bytes"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
)
,
py
::
arg
(
"target_dtype"
)
);
m
.
def
(
"dequantize_q5_k"
,
&
dequantize_q5_k
,
"Function to dequantize q5_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
py
::
arg
(
"data"
),
py
::
arg
(
"num_bytes"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
)
,
py
::
arg
(
"target_dtype"
)
);
m
.
def
(
"dequantize_q4_k"
,
&
dequantize_q4_k
,
"Function to dequantize q4_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
py
::
arg
(
"data"
),
py
::
arg
(
"num_bytes"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
)
,
py
::
arg
(
"target_dtype"
)
);
m
.
def
(
"dequantize_q3_k"
,
&
dequantize_q3_k
,
"Function to dequantize q3_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
py
::
arg
(
"data"
),
py
::
arg
(
"num_bytes"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
)
,
py
::
arg
(
"target_dtype"
)
);
m
.
def
(
"dequantize_q2_k"
,
&
dequantize_q2_k
,
"Function to dequantize q2_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
py
::
arg
(
"data"
),
py
::
arg
(
"num_bytes"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
)
,
py
::
arg
(
"target_dtype"
)
);
m
.
def
(
"dequantize_iq4_xs"
,
&
dequantize_iq4_xs
,
"Function to dequantize iq4_xs data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
py
::
arg
(
"data"
),
py
::
arg
(
"num_bytes"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
)
,
py
::
arg
(
"target_dtype"
)
);
m
.
def
(
"gptq_marlin_gemm"
,
&
gptq_marlin_gemm
,
"Function to perform GEMM using Marlin quantization."
,
py
::
arg
(
"a"
),
py
::
arg
(
"b_q_weight"
),
py
::
arg
(
"b_scales"
),
py
::
arg
(
"g_idx"
),
py
::
arg
(
"perm"
),
py
::
arg
(
"workspace"
),
py
::
arg
(
"num_bits"
),
py
::
arg
(
"size_m"
),
...
...
ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp
View file @
7e1fe256
...
...
@@ -17,19 +17,19 @@ torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device de
PYBIND11_MODULE
(
cudaops
,
m
)
{
m
.
def
(
"dequantize_q8_0"
,
&
dequantize_q8_0
,
"Function to dequantize q8_0 data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
py
::
arg
(
"data"
),
py
::
arg
(
"num_bytes"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
)
,
py
::
arg
(
"target_dtype"
)
);
m
.
def
(
"dequantize_q6_k"
,
&
dequantize_q6_k
,
"Function to dequantize q6_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
py
::
arg
(
"data"
),
py
::
arg
(
"num_bytes"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
)
,
py
::
arg
(
"target_dtype"
)
);
m
.
def
(
"dequantize_q5_k"
,
&
dequantize_q5_k
,
"Function to dequantize q5_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
m
.
def
(
"dequantize_q4_k"
,
&
dequantize_q4_k
,
"Function to dequantize q4_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
m
.
def
(
"dequantize_q3_k"
,
&
dequantize_q3_k
,
"Function to dequantize q3_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
m
.
def
(
"dequantize_q2_k"
,
&
dequantize_q2_k
,
"Function to dequantize q2_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
m
.
def
(
"dequantize_iq4_xs"
,
&
dequantize_iq4_xs
,
"Function to dequantize iq4_xs data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
));
py
::
arg
(
"data"
),
py
::
arg
(
"num_bytes"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
)
,
py
::
arg
(
"target_dtype"
)
);
m
.
def
(
"dequantize_q4_k"
,
&
dequantize_q4_k
,
"Function to dequantize q4_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"num_bytes"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
)
,
py
::
arg
(
"target_dtype"
)
);
m
.
def
(
"dequantize_q3_k"
,
&
dequantize_q3_k
,
"Function to dequantize q3_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"num_bytes"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
)
,
py
::
arg
(
"target_dtype"
)
);
m
.
def
(
"dequantize_q2_k"
,
&
dequantize_q2_k
,
"Function to dequantize q2_k data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"num_bytes"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
)
,
py
::
arg
(
"target_dtype"
)
);
m
.
def
(
"dequantize_iq4_xs"
,
&
dequantize_iq4_xs
,
"Function to dequantize iq4_xs data."
,
py
::
arg
(
"data"
),
py
::
arg
(
"num_bytes"
),
py
::
arg
(
"blk_size"
),
py
::
arg
(
"device"
)
,
py
::
arg
(
"target_dtype"
)
);
m
.
def
(
"test"
,
&
test
,
"Function to test."
);
}
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
View file @
7e1fe256
This diff is collapsed.
Click to expand it.
ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h
View file @
7e1fe256
...
...
@@ -13,10 +13,10 @@
#include <torch/extension.h>
#include <torch/torch.h>
torch
::
Tensor
dequantize_q8_0
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q6_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q5_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q4_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q3_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q2_k
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_iq4_xs
(
torch
::
Tensor
data
,
int
blk_size
,
torch
::
Device
device
);
torch
::
Tensor
dequantize_q8_0
(
const
int8_t
*
data
,
const
int
num_bytes
,
const
int
blk_size
,
const
torch
::
Device
device
,
const
torch
::
ScalarType
target_dtype
);
torch
::
Tensor
dequantize_q6_k
(
const
int8_t
*
data
,
const
int
num_bytes
,
const
int
blk_size
,
const
torch
::
Device
device
,
const
torch
::
ScalarType
target_dtype
);
torch
::
Tensor
dequantize_q5_k
(
const
int8_t
*
data
,
const
int
num_bytes
,
const
int
blk_size
,
const
torch
::
Device
device
,
const
torch
::
ScalarType
target_dtype
);
torch
::
Tensor
dequantize_q4_k
(
const
int8_t
*
data
,
const
int
num_bytes
,
const
int
blk_size
,
const
torch
::
Device
device
,
const
torch
::
ScalarType
target_dtype
);
torch
::
Tensor
dequantize_q3_k
(
const
int8_t
*
data
,
const
int
num_bytes
,
const
int
blk_size
,
const
torch
::
Device
device
,
const
torch
::
ScalarType
target_dtype
);
torch
::
Tensor
dequantize_q2_k
(
const
int8_t
*
data
,
const
int
num_bytes
,
const
int
blk_size
,
const
torch
::
Device
device
,
const
torch
::
ScalarType
target_dtype
);
torch
::
Tensor
dequantize_iq4_xs
(
const
int8_t
*
data
,
const
int
num_bytes
,
const
int
blk_size
,
const
torch
::
Device
device
,
const
torch
::
ScalarType
target_dtype
);
ktransformers/local_chat.py
View file @
7e1fe256
...
...
@@ -168,10 +168,7 @@ def local_chat(
if
mode
==
'long_context'
:
assert
Config
().
long_context_config
[
'max_seq_len'
]
>
input_tensor
.
shape
[
1
]
+
max_new_tokens
,
\
"please change max_seq_len in ~/.ktransformers/config.yaml"
torch
.
set_default_dtype
(
torch
.
bfloat16
)
# TODO: Remove this, replace dtype using config
if
system
!=
"Windows"
and
(
config
.
architectures
[
0
]
==
"DeepseekV2ForCausalLM"
or
"DeepseekV3ForCausalLM"
)
and
flashinfer_enabled
:
generated
=
prefill_and_generate
(
model
,
tokenizer
,
input_tensor
.
cuda
(),
max_new_tokens
,
use_cuda_graph
,
mode
=
mode
,
force_think
=
force_think
,
...
...
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
View file @
7e1fe256
...
...
@@ -5,6 +5,18 @@
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
-
match
:
name
:
"
^lm_head$"
# regular expression
class
:
torch.nn.Linear
# only match modules matching name and class simultaneously
replace
:
class
:
ktransformers.operators.linear.KTransformersLinear
# optimized Kernel on quantized data types
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
generate_op
:
"
KLinearMarlin"
prefill_op
:
"
KLinearTorch"
-
match
:
name
:
"
^model
\\
.layers
\\
.(?!.*self_attn
\\
.kv_b_proj).*$"
# regular expression
class
:
torch.nn.Linear
# only match modules matching name and class simultaneously
...
...
ktransformers/server/backend/interfaces/ktransformers.py
View file @
7e1fe256
...
...
@@ -25,10 +25,10 @@ class KTransformersThreadContext(TransformersThreadContext):
class
KTransformersInterface
(
TransformersInterface
):
def
__init__
(
self
,
args
:
ConfigArgs
=
default_args
):
self
.
args
=
args
torch
.
set_default_dtype
(
torch
.
bfloat16
)
torch
.
set_grad_enabled
(
False
)
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
model_dir
,
device
=
args
.
device
,
trust_remote_code
=
args
.
trust_remote_code
)
config
=
AutoConfig
.
from_pretrained
(
args
.
model_dir
,
trust_remote_code
=
args
.
trust_remote_code
)
torch
.
set_default_dtype
(
config
.
torch_dtype
)
if
config
.
architectures
[
0
]
==
"Qwen2MoeForCausalLM"
:
config
.
_attn_implementation
=
"flash_attention_2"
...
...
ktransformers/util/custom_gguf.py
View file @
7e1fe256
...
...
@@ -285,7 +285,7 @@ class GGUFLoader:
itemsize
=
int
(
np
.
empty
([],
dtype
=
item_type
).
itemsize
)
return
mmap_data
[
offset
:
offset
+
itemsize
*
item_count
]
def
load_expert_tensor
(
self
,
name
,
data
,
expert_id
,
elements_per_expert
,
device
=
"
gpu"
)
->
torch
.
Tensor
:
def
load_expert_tensor
(
self
,
name
,
data
,
expert_id
,
elements_per_expert
,
device
=
"
cuda"
,
target_dtype
=
torch
.
get_default_dtype
()
)
->
torch
.
Tensor
:
t
=
self
.
tensor_info
[
name
]
if
device
.
lower
()
==
"cpu"
:
print
(
f
"loading expert
{
expert_id
}
of
{
name
}
with CPU"
)
...
...
@@ -304,7 +304,7 @@ class GGUFLoader:
data
=
data
[
offset
:
offset
+
block_size
*
blocks_per_experts
]
if
"cuda"
in
device
.
lower
():
values
=
GGML_DEQUANTIZE_GPU
[
ggml_name
](
data
,
device
)
values
=
GGML_DEQUANTIZE_GPU
[
ggml_name
](
data
,
device
,
target_dtype
)
else
:
values
=
GGML_DEQUANTIZE
[
ggml_name
](
data
)
values
=
torch
.
from_numpy
(
values
)
...
...
@@ -313,7 +313,7 @@ class GGUFLoader:
return
values
def
load_gguf_tensor
(
self
,
name
:
str
,
device
:
str
=
"cpu"
)
->
torch
.
Tensor
:
def
load_gguf_tensor
(
self
,
name
:
str
,
device
:
str
=
"cpu"
,
target_dtype
=
torch
.
get_default_dtype
()
)
->
torch
.
Tensor
:
t
=
self
.
tensor_info
[
name
]
if
device
.
lower
()
==
"cpu"
:
print
(
f
"loading
{
name
}
with CPU"
)
...
...
@@ -328,16 +328,36 @@ class GGUFLoader:
data
=
self
.
get_mmap_tensor
(
name
)
if
"cuda"
in
device
.
lower
():
values
=
GGML_DEQUANTIZE_GPU
[
ggml_name
](
data
,
device
)
#values = GGML_DEQUANTIZE[ggml_name](data)
#print("load_gguf_tensor")
#values = torch.from_numpy(values).to(device = device)
block_size
=
GGML_BLOCK_SIZES
[
ggml_name
]
elements_per_block
=
GGML_ELEMENTS_PER_BLOCK
[
ggml_name
]
num_elements
=
int
(
np
.
prod
(
shape
))
num_blocks
=
num_elements
//
elements_per_block
blocks_per_iter
=
16384
if
num_blocks
>
blocks_per_iter
:
# dequant large tensor
values
=
torch
.
empty
((
num_blocks
,
elements_per_block
),
dtype
=
torch
.
float
,
device
=
device
)
for
i
in
range
(
(
num_blocks
+
blocks_per_iter
-
1
)
//
blocks_per_iter
):
blocks_begin
=
i
*
blocks_per_iter
blocks_end
=
min
(
blocks_begin
+
blocks_per_iter
,
num_blocks
)
if
"cuda"
in
device
.
lower
():
cur_values
=
GGML_DEQUANTIZE_GPU
[
ggml_name
](
data
[
blocks_begin
*
block_size
:
blocks_end
*
block_size
],
device
,
target_dtype
)
else
:
cur_values
=
GGML_DEQUANTIZE
[
ggml_name
](
data
[
blocks_begin
*
block_size
:
blocks_end
*
block_size
])
cur_values
=
torch
.
from_numpy
(
cur_values
)
cur_values
=
cur_values
.
view
(
-
1
,
elements_per_block
)
values
[
blocks_begin
:
blocks_end
]
=
cur_values
else
:
values
=
GGML_DEQUANTIZE
[
ggml_name
](
data
)
values
=
torch
.
from_numpy
(
values
)
if
"cuda"
in
device
.
lower
():
values
=
GGML_DEQUANTIZE_GPU
[
ggml_name
](
data
,
device
)
else
:
values
=
GGML_DEQUANTIZE
[
ggml_name
](
data
)
values
=
torch
.
from_numpy
(
values
)
if
ggml_name
==
"BF16"
:
values
=
values
.
view
(
torch
.
bfloat16
)
values
=
values
.
view
(
shape
[::
-
1
])
if
"attn_q"
in
name
and
self
.
gguf_file_meta
[
'general.architecture'
]
in
[
"llama"
]:
n_head
=
self
.
gguf_file_meta
[
'llama.attention.head_count'
]
...
...
@@ -433,14 +453,13 @@ def dequantize_q2_k(data):
return
d
*
(
scales
&
15
)
*
(
tmp
&
3
)
-
dmin
*
(
scales
>>
4
)
def
dequantize_q2_k_gpu
(
data
,
device
:
str
=
"cuda"
):
def
dequantize_q2_k_gpu
(
data
,
device
:
str
=
"cuda"
,
target_dtype
=
torch
.
get_default_dtype
()
):
block_size
=
GGML_BLOCK_SIZES
[
"Q2_K"
]
data
=
np
.
frombuffer
(
data
,
dtype
=
data
.
dtype
)
device
=
torch
.
device
(
device
)
# TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
# the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
data
=
torch
.
from_numpy
(
data
)
return
KTransformersOps
.
dequantize_q2_k
(
data
,
block_size
,
device
)
return
KTransformersOps
.
dequantize_q2_k
(
data
.
data
,
data
.
size
,
block_size
,
device
,
target_dtype
)
def
dequantize_q3_k
(
data
):
# C implementation
...
...
@@ -484,14 +503,13 @@ def dequantize_q3_k(data):
(((
qs
[:,
48
:
64
]
>>
6
)
&
3
)
-
bits
[:,
16
:,
7
])
],
axis
=
1
)
def
dequantize_q3_k_gpu
(
data
,
device
:
str
=
"cuda"
):
def
dequantize_q3_k_gpu
(
data
,
device
:
str
=
"cuda"
,
target_dtype
=
torch
.
get_default_dtype
()
):
block_size
=
GGML_BLOCK_SIZES
[
"Q3_K"
]
data
=
np
.
frombuffer
(
data
,
dtype
=
data
.
dtype
)
device
=
torch
.
device
(
device
)
# TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
# the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
data
=
torch
.
from_numpy
(
data
)
return
KTransformersOps
.
dequantize_q3_k
(
data
,
block_size
,
device
)
return
KTransformersOps
.
dequantize_q3_k
(
data
.
data
,
data
.
size
,
block_size
,
device
,
target_dtype
)
def
dequantize_q4_k
(
data
):
# C implementation
...
...
@@ -515,13 +533,12 @@ def dequantize_q4_k(data):
# Dequantize final weights using scales and offsets
return
factors
*
qs2
-
offsets
def
dequantize_q4_k_gpu
(
data
,
device
:
str
=
"cuda"
):
def
dequantize_q4_k_gpu
(
data
,
device
:
str
=
"cuda"
,
target_dtype
=
torch
.
get_default_dtype
()
):
data
=
np
.
frombuffer
(
data
,
dtype
=
data
.
dtype
)
device
=
torch
.
device
(
device
)
# TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
# the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
data
=
torch
.
from_numpy
(
data
)
return
KTransformersOps
.
dequantize_q4_k
(
data
,
144
,
device
)
return
KTransformersOps
.
dequantize_q4_k
(
data
.
data
,
data
.
size
,
144
,
device
,
target_dtype
)
def
dequantize_q5_k
(
data
):
# C implementation
...
...
@@ -579,14 +596,13 @@ def dequantize_q5_k(data):
d8
*
(
qs_hi_4
[:,
3
]
+
(
bits
[:,
:,
7
]
<<
4
))
-
m8
,
],
axis
=
1
)
def
dequantize_q5_k_gpu
(
data
,
device
:
str
=
"cuda"
):
def
dequantize_q5_k_gpu
(
data
,
device
:
str
=
"cuda"
,
target_dtype
=
torch
.
get_default_dtype
()
):
block_size
=
GGML_BLOCK_SIZES
[
"Q5_K"
]
data
=
np
.
frombuffer
(
data
,
dtype
=
data
.
dtype
)
device
=
torch
.
device
(
device
)
# TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
# the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
data
=
torch
.
from_numpy
(
data
)
return
KTransformersOps
.
dequantize_q5_k
(
data
,
block_size
,
device
)
return
KTransformersOps
.
dequantize_q5_k
(
data
.
data
,
data
.
size
,
block_size
,
device
,
target_dtype
)
def
dequantize_q6_k
(
data
):
# C implementation
...
...
@@ -637,13 +653,12 @@ def dequantize_q6_k(data):
],
axis
=
1
)
# @torch.jit.script
def
dequantize_q6_k_gpu
(
data
:
np
.
ndarray
,
device
:
str
=
"cuda"
):
def
dequantize_q6_k_gpu
(
data
:
np
.
ndarray
,
device
:
str
=
"cuda"
,
target_dtype
=
torch
.
get_default_dtype
()
):
block_size
=
GGML_BLOCK_SIZES
[
"Q6_K"
]
device
=
torch
.
device
(
device
)
num_blocks
=
len
(
data
)
//
block_size
data
=
np
.
frombuffer
(
data
,
dtype
=
data
.
dtype
)
data
=
torch
.
from_numpy
(
data
)
return
KTransformersOps
.
dequantize_q6_k
(
data
,
block_size
,
device
)
return
KTransformersOps
.
dequantize_q6_k
(
data
.
data
,
data
.
size
,
block_size
,
device
,
target_dtype
)
kvalues_iq4nl
=
np
.
array
([
-
127
,
-
104
,
-
83
,
-
65
,
-
49
,
-
35
,
-
22
,
-
10
,
1
,
13
,
25
,
38
,
53
,
69
,
89
,
113
],
dtype
=
np
.
int8
)
...
...
@@ -677,13 +692,12 @@ def dequantize_iq4_xs(data):
return
y
.
flatten
()
def
dequantize_iq4_xs_gpu
(
data
:
np
.
ndarray
,
device
:
str
=
"cuda"
):
def
dequantize_iq4_xs_gpu
(
data
:
np
.
ndarray
,
device
:
str
=
"cuda"
,
target_dtype
=
torch
.
get_default_dtype
()
):
block_size
=
GGML_BLOCK_SIZES
[
"IQ4_XS"
]
device
=
torch
.
device
(
device
)
num_blocks
=
len
(
data
)
//
block_size
data
=
np
.
frombuffer
(
data
,
dtype
=
data
.
dtype
)
data
=
torch
.
from_numpy
(
data
)
return
KTransformersOps
.
dequantize_iq4_xs
(
data
,
block_size
,
device
)
return
KTransformersOps
.
dequantize_iq4_xs
(
data
.
data
,
data
.
size
,
block_size
,
device
,
target_dtype
)
def
dequantize_q4_0
(
data
):
# C implementation
...
...
@@ -700,7 +714,7 @@ def dequantize_q4_0(data):
scales
*
((
qs
>>
4
).
astype
(
np
.
int8
)
-
8
),
],
axis
=
1
)
def
dequantize_q4_0_gpu
(
data
):
def
dequantize_q4_0_gpu
(
data
,
device
:
str
=
"cuda"
,
target_dtype
=
torch
.
get_default_dtype
()
):
raise
NotImplementedError
()
def
dequantize_q5_0
(
data
):
...
...
@@ -724,7 +738,7 @@ def dequantize_q5_0(data):
scales
*
x1
,
],
axis
=
1
)
def
dequantize_q5_0_gpu
(
data
):
def
dequantize_q5_0_gpu
(
data
,
device
:
str
=
"cuda"
,
target_dtype
=
torch
.
get_default_dtype
()
):
raise
NotImplementedError
()
def
dequantize_q8_0
(
data
):
...
...
@@ -736,20 +750,19 @@ def dequantize_q8_0(data):
qs
=
np
.
frombuffer
(
data
,
dtype
=
np
.
int8
).
reshape
(
num_blocks
,
2
+
32
)[:,
2
:]
return
scales
*
qs
def
dequantize_q8_0_gpu
(
data
,
device
:
str
=
"cuda"
):
def
dequantize_q8_0_gpu
(
data
,
device
:
str
=
"cuda"
,
target_dtype
=
torch
.
get_default_dtype
()
):
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43
num_blocks
=
len
(
data
)
//
GGML_BLOCK_SIZES
[
"Q8_0"
]
device
=
torch
.
device
(
device
)
data
=
np
.
frombuffer
(
data
,
dtype
=
data
.
dtype
)
data
=
torch
.
from_numpy
(
data
)
return
KTransformersOps
.
dequantize_q8_0
(
data
,
34
,
device
)
return
KTransformersOps
.
dequantize_q8_0
(
data
.
data
,
data
.
size
,
34
,
device
,
target_dtype
)
def
dequantize_f32
(
data
):
return
np
.
frombuffer
(
data
,
dtype
=
np
.
float32
)
def
dequantize_f32_gpu
(
data
,
device
):
def
dequantize_f32_gpu
(
data
,
device
,
target_dtype
=
torch
.
get_default_dtype
()
):
data
=
np
.
frombuffer
(
data
,
dtype
=
np
.
float32
)
res
=
torch
.
from_numpy
(
data
)
res_gpu
=
torch
.
empty_like
(
res
,
device
=
device
)
...
...
@@ -759,7 +772,7 @@ def dequantize_f32_gpu(data, device):
def
dequantize_f16
(
data
):
return
np
.
frombuffer
(
data
,
dtype
=
np
.
float16
)
def
dequantize_f16_gpu
(
data
,
device
):
def
dequantize_f16_gpu
(
data
,
device
,
target_dtype
=
torch
.
get_default_dtype
()
):
data
=
np
.
frombuffer
(
data
,
dtype
=
np
.
float16
)
res
=
torch
.
from_numpy
(
data
)
res_gpu
=
torch
.
empty_like
(
res
,
device
=
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment