Commit 581a524f in OpenDAS / ktransformers
Add data loader to read special weights for fp8; Add special weight process script

Authored Feb 24, 2025 by Azure
Parent commit: 7b7c6a65
10 changed files with 481 additions and 26 deletions (+481, -26)
Changed files:
- ktransformers/ktransformers_ext/triton/fp8gemm.py (+1, -0)
- ktransformers/operators/experts.py (+10, -1)
- ktransformers/operators/gate.py (+10, -3)
- ktransformers/operators/linear.py (+21, -17)
- ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml (+63, -0, new file)
- ktransformers/tests/triton_fp8gemm_test.py (+45, -2)
- ktransformers/util/custom_gguf.py (+18, -1)
- ktransformers/util/custom_loader.py (+86, -0, new file)
- ktransformers/util/utils.py (+13, -2)
- merge_tensors/merge_safetensor_gguf.py (+214, -0, new file)
ktransformers/ktransformers_ext/triton/fp8gemm.py (+1, -0)

+# Adopted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
 from typing import Tuple
 import torch
 ...
ktransformers/operators/experts.py (+10, -1)

@@ -245,7 +245,16 @@ class KExpertsCPU(KExpertsBase):
         down_type = None

         for key in keys:
-            if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
+            if self.gguf_loader.safetensor_loader is not None:
+                # using a temp ugly way to temporarily load the tensor
+                gate = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.weight").numpy()
+                up = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.weight").numpy()
+                down = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.weight").numpy()
+                gate_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.ggml_type").item()
+                up_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.ggml_type").item()
+                down_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.ggml_type").item()
+            elif key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
                 gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
                 up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
                 down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
 ...
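For context, the ".ffn_*_exps.ggml_type" scalars read above come from the merged checkpoint produced by merge_tensors/merge_safetensor_gguf.py (added later in this commit): each still-GGML-quantized expert weight is saved next to a scalar tensor recording its quantization type. A minimal sketch of that layout follows; the tensor names and the Q4_K enum value of 12 are illustrative assumptions, not taken from this diff.

import torch
from safetensors.torch import save_file

# Raw (still GGML-quantized) expert payload plus its type tag; both names and the
# payload contents are placeholders used only to show the key layout.
raw_q4km_payload = torch.zeros(1024, dtype=torch.uint8)
save_file({
    "blk.0.ffn_gate_exps.weight": raw_q4km_payload,
    "blk.0.ffn_gate_exps.ggml_type": torch.tensor(12),  # assumed Q4_K enum value
}, "shard-00001.safetensors")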
ktransformers/operators/gate.py (+10, -3)

@@ -67,7 +67,14 @@ class KMoEGateBase(ABC):
         for key in keys:
             key = ".".join(key.split(".")[:-1])
-            if key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
+            if self.gguf_loader.safetensor_loader is not None:
+                targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
+                weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight")
+                e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(key + ".exp_probs_b.bias")
+                weight_type = weight.dtype
+                e_score_correction_bias_type = e_score_correction_bias.dtype
+                res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
+            elif key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
                 targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
                 tensors = self.load_multi(key, targets, device=device)
                 weight = tensors[".ffn_gate_inp.weight"]
 ...

@@ -116,8 +123,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
             self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
         else:
             raise ValueError("Invalid weight type")
-        self.orig_module.weight = self.orig_module.weight.to(device)
-        self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
+        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
+        self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))

     def unload(self):
         if self.weight is not None:
 ...
ktransformers/operators/linear.py (+21, -17)

@@ -76,7 +76,13 @@ class KLinearBase(ABC):
         keys = [self.key]

         for key in keys:
-            if key + ".weight" in self.gguf_loader.tensor_file_map:
+            if self.gguf_loader.safetensor_loader is not None:
+                # using safetensor_loader
+                tensor = self.gguf_loader.safetensor_loader.load_tensor(key + '.weight')
+                weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key + '.weight_scale_inv')
+                return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
+            elif key + ".weight" in self.gguf_loader.tensor_file_map:
                 if key + ".bias" in self.gguf_loader.tensor_file_map:
                     tensors = self.load_multi(key, ["weight", "bias"], device=device)
                     tensor = tensors["weight"]
 ...

@@ -166,6 +172,8 @@ class KLinearTorch(KLinearBase):
         self.bias = None


 class KLinearFP8(KLinearBase):
+    # this kernel requires special handling for weight
+    # Please load the weight file downloaded from KVCache.AI
     marlin_q_w: torch.Tensor
     marlin_s: torch.Tensor
     g_idx: torch.Tensor
 ...
@@ -191,26 +199,20 @@ class KLinearFP8(KLinearBase):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.to(self.device)
-        orig_shape = list(x.shape)
         orig_dtype = x.dtype
-        x = x.reshape(-1, orig_shape[-1])
         x_quantized, scale_x = act_quant(x, self.block_size)
-        y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight.scale)
-        if self.bias is not None:
-            y += self.bias
-        return y.to(orig_dtype).reshape(orig_shape)
+        y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight_scale_inv)
+        return y.to(dtype=orig_dtype)

     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
         if device is None: device = self.device
         if w is None: w = self.load_weight(device=device)

-        if isinstance(w, nn.Parameter):
-            self.weight = w.to(device)
-            self.has_bias = False
-        elif isinstance(w, tuple):
+        ### TODO fit weight_inv format
+        if isinstance(w, tuple):
             self.weight = w[0].to(device)
-            self.bias = w[1].to(device)
-            self.has_bias = True
+            self.weight_scale_inv = w[1].to(device)
+            self.has_bias = False
         else:
             raise ValueError("Invalid weight type")
         self.weight = self.weight.to(device)
 ...
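The forward path above delegates to the triton kernels act_quant and fp8_gemm from fp8gemm.py: activations are quantized per block of features to float8 with one scale per block, and the GEMM consumes the quantized tensor plus the scales. The sketch below emulates that idea in plain PyTorch for illustration only; it is not the kernel itself, the block size of 128 and the e4m3 maximum of 448 are assumptions carried over from the DeepSeek-V3 reference code, and it needs a PyTorch build with float8 dtypes.

import torch

def act_quant_ref(x: torch.Tensor, block_size: int = 128):
    # One scale per contiguous block of `block_size` features, sized so the block
    # fits into float8_e4m3fn's representable range (max ~448).
    blocks = x.reshape(-1, block_size)
    scale = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4) / 448.0
    x_q = (blocks / scale).to(torch.float8_e4m3fn)
    return x_q.reshape_as(x), scale.reshape(*x.shape[:-1], -1)

x = torch.randn(4, 256)
x_q, s = act_quant_ref(x)
x_deq = (x_q.to(torch.float32).reshape(-1, 128) * s.reshape(-1, 1)).reshape_as(x)
print((x - x_deq).abs().max())  # quantization error should be small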
@@ -425,7 +427,8 @@ class KLinearCPUInfer(KLinearBase):
 LINEAR_MAP = {
     "KLinearMarlin": KLinearMarlin,
     "KLinearTorch": KLinearTorch,
-    "KLinearCPUInfer": KLinearCPUInfer
+    "KLinearCPUInfer": KLinearCPUInfer,
+    "KLinearFP8": KLinearFP8,
 }

 class KTransformersLinear(BaseInjectedModule, KLinearBase):
 ...

@@ -472,10 +475,11 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
     def forward(self, x):
         if self.mode == InferenceState.PREFILL:
             assert self.prefill_linear is not None, "cpu linear is not initialized"
-            return self.prefill_linear.forward(x)
+            y = self.prefill_linear.forward(x)
         else:
             assert self.generate_linear is not None, "gpu linear is not initialized"
-            return self.generate_linear.forward(x)
+            y = self.generate_linear.forward(x)
+        return y

     def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
         if not mode:
 ...
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml (new file, +63)

- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearFP8"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE kernel with expert parallelism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
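The second rule's name pattern relies on a negative lookahead, so every torch.nn.Linear except the attention kv_b_proj is routed to KTransformersLinear (KLinearFP8 for generation, KLinearTorch for prefill). A quick, purely illustrative check of what the regex matches:

import re

pattern = re.compile(r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$")
print(bool(pattern.match("model.layers.3.mlp.gate_proj")))        # True  -> replaced
print(bool(pattern.match("model.layers.3.self_attn.q_b_proj")))   # True  -> replaced
print(bool(pattern.match("model.layers.3.self_attn.kv_b_proj")))  # False -> left as torch.nn.Linear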
ktransformers/tests/triton_fp8gemm_test.py (+45, -2)

@@ -3,7 +3,7 @@ import torch.nn.functional as F
 from typing import Optional
 import pytest
 from typing import Tuple, Optional, Literal
 import time
 # use dir path
 import os
 import sys
 ...

@@ -56,18 +56,61 @@ def test_fp8_gemm_vs_torch_matmul_load():
     print(f"weight_dequantized: {weight_dequantized.shape}")
     N, K = weight_dequantized.shape
     M = 64
-    x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
+    x = torch.randn(2, M, K, dtype=torch.bfloat16, device='cuda')
     x_quantized, scale_x = act_quant(x, block_size)

     # Test case 1: quantized x matmul with undequantized weight
     result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
     print(f"result_fp8_gemm:\n{result_fp8_gemm}")
     print(f"dtype {result_fp8_gemm.dtype}")

     # Perform torch.matmul using the original floating point tensors
     result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T)
     print(f"result_torch_matmul:\n{result_torch_matmul}")

+def test_fp8_gemm_tplops():
+    file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors"
+    with safe_open(file_path, framework="pt", device=0) as f:
+        weight = f.get_tensor("model.layers.0.mlp.down_proj.weight")
+        scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")
+
+    # weight_dequant
+    weight_dequantized = weight_dequant(weight, scale)
+    print(f"weight_dequantized: {weight_dequantized.shape}")
+    N, K = weight_dequantized.shape
+    M = 6400
+    x = torch.randn(2, M, K, dtype=torch.bfloat16, device='cuda')
+    # x_quantized, scale_x = act_quant(x, block_size)
+
+    # Calculate time for repeated fp8_gemm calls
+    i = 10
+    flops_per_gemm = 2 * M * N * K
+    total_flops = i * flops_per_gemm
+
+    x_quantized, scale_x = act_quant(x, block_size)
+    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
+    x_quantized, scale_x = act_quant(x, block_size)
+    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
+
+    t0 = time.time()
+    torch.cuda.synchronize()
+    for i in range(i):
+        x_quantized, scale_x = act_quant(x, block_size)
+        result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
+    torch.cuda.synchronize()
+    t1 = time.time()
+    total_time = t1 - t0
+
+    tflops = total_flops / total_time / 1e12
+    print(f"total_time: {total_time}")
+    print(f"tflops: {tflops}")

 if __name__ == "__main__":
     test_fp8_gemm_vs_torch_matmul()
     test_fp8_gemm_vs_torch_matmul_load()
+    test_fp8_gemm_tplops()
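The throughput test above counts 2 * M * N * K floating-point operations per GEMM and divides the total by wall-clock time. A back-of-the-envelope version of that arithmetic is sketched below; the N and K values are assumptions for DeepSeek-V3's layer-0 down_proj and the elapsed time is a placeholder, whereas the test measures both from the real checkpoint and GPU.

M, N, K = 6400, 7168, 18432             # M from the test; N, K assumed for down_proj
iters = 10
flops_per_gemm = 2 * M * N * K          # one multiply-accumulate counted as 2 FLOPs
total_flops = iters * flops_per_gemm
elapsed_s = 0.05                        # placeholder wall-clock time
print(f"issued {total_flops / 1e12:.2f} TFLOP -> {total_flops / elapsed_s / 1e12:.1f} TFLOPS at {elapsed_s}s")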
ktransformers/util/custom_gguf.py (+18, -1)

@@ -25,6 +25,7 @@ import os
 from enum import IntEnum
 import torch
 import KTransformersOps
+from .custom_loader import SafeTensorLoader

 class GGMLQuantizationType(IntEnum):
     F32 = 0
 ...

@@ -168,12 +169,15 @@ class GGUFLoader:
     gguf_path: str
     tensor_file_map: dict  # {tensor_name: tensor_file_path}
     gguf_file_meta: dict
+    safetensor_loader: SafeTensorLoader

     def __init__(self, gguf_path: str):
         # Check dir exist
         if not os.path.exists(gguf_path):
             raise FileNotFoundError(f"GGUF dir not found: {gguf_path}")
         if os.path.isfile(gguf_path):
             gguf_path = os.path.dirname(gguf_path)

+        self.safetensor_loader = None
+
         self.tensor_info = {}
         self.gguf_path = gguf_path
 ...

@@ -181,7 +185,13 @@ class GGUFLoader:
         self.file_data_map = {}
         self.gguf_file_meta = {}
         self.tensor_device_map = {}

+        # I know this is ugly, but I don't want to change the original code too much
+        # TODO: merge gguf load and other loads.
+        safetensor_loader = SafeTensorLoader(gguf_path)
+        if safetensor_loader.tensor_file_map:
+            self.safetensor_loader = safetensor_loader
+            return
         # Walk through all the .gguf files in the directory
         found_gguf = False
         for root, dirs, files in os.walk(gguf_path):
 ...

@@ -288,6 +298,13 @@ class GGUFLoader:
         itemsize = int(np.empty([], dtype=item_type).itemsize)
         return mmap_data[offset : offset + itemsize * item_count]

+    def get_undequanted_tensor_and_ggml_type(self, name):
+        t = self.tensor_info[name]
+        data = self.get_mmap_tensor(name)
+        ggml_type = t["ggml_type"]
+        data = torch.from_numpy(data)
+        return data, ggml_type
+
     def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device="gpu") -> torch.Tensor:
         t = self.tensor_info[name]
         if device.lower() == "cpu":
 ...
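With these changes GGUFLoader probes the given directory for .safetensors files first and, if any are found, keeps a SafeTensorLoader and skips the GGUF walk entirely. An illustrative usage sketch follows; the path is a placeholder and the tensor key names are only examples.

from ktransformers.util.custom_gguf import GGUFLoader

# Point the loader at a directory produced by merge_tensors/merge_safetensor_gguf.py.
loader = GGUFLoader("/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8")
if loader.safetensor_loader is not None:
    # safetensors found: read tensors through the new loader
    w = loader.safetensor_loader.load_tensor("model.embed_tokens.weight")
else:
    # plain GGUF fallback path
    w = loader.load_gguf_tensor("token_embd.weight", device="cpu")
print(w.shape)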
ktransformers/util/custom_loader.py (new file, +86)

import struct
import warnings
import numpy as np
import re
import numpy.typing as npt
from typing import Sequence
import os
from enum import IntEnum
import torch
import KTransformersOps
from safetensors import safe_open
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from safetensors.torch import save_file

class SafeTensorLoader:
    tensor_file_map = {}
    tensor_type_map = {}
    file_handle_map = {}

    def __init__(self, file_path: str):
        self.__load_tensor_file_map(file_path)

    def __load_tensor_file_map(self, file_path: str):
        # Normalize the incoming path: make sure we end up with a folder path
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Path not found: {file_path}")
        if os.path.isfile(file_path):
            folder_path = os.path.dirname(file_path)
        else:
            folder_path = file_path

        found_safetensor = False
        for root, _, files in os.walk(folder_path):
            files = sorted(files)
            for file in files:
                if file.endswith(".safetensors"):
                    found_safetensor = True
                    file_path = os.path.join(root, file)
                    if file not in self.file_handle_map:
                        try:
                            handle = safe_open(file_path, framework="pt")
                            self.file_handle_map[file] = handle
                        except Exception as e:
                            print(f"Error opening Safetensor file {file_path}: {e}")
                            continue
                    f = self.file_handle_map.get(file)
                    if f is None:
                        continue
                    try:
                        for key in f.keys():
                            self.tensor_file_map[key] = file
                    except Exception as e:
                        print(f"Error reading Safetensor file {file_path}: {e}")

        # if not found_safetensor:
        #     raise FileNotFoundError(f"No Safetensor files found in {folder_path}")

    def load_tensor(self, key: str, device: str = "cpu"):
        if key not in self.tensor_file_map:
            raise KeyError(f"Key {key} not found in Safetensor files")
        file = self.tensor_file_map[key]
        f = self.file_handle_map.get(file)
        if f is None:
            raise FileNotFoundError(f"File {file} not found in Safetensor files")
        tensor = f.get_tensor(key)
        return tensor.to(device)

    def close_all_handles(self):
        for handle in self.file_handle_map.values():
            handle.close()
        self.file_handle_map.clear()

    def load_dequantized_tensor(self, key: str, device: str = "cpu"):
        if key not in self.tensor_file_map:
            raise KeyError(f"Key {key} not found in Safetensor files")
        file = self.tensor_file_map[key]
        f = self.file_handle_map.get(file)
        if f is None:
            raise FileNotFoundError(f"File {file} not found in Safetensor files")
        tensor = f.get_tensor(key).to(device)
        if key.endswith(".weight"):
            if key[:-7] + ".weight_scale_inv" in self.tensor_file_map:
                weight_scale_inv = f.get_tensor(key[:-7] + ".weight_scale_inv").to(device)
                tensor = weight_dequant(tensor, weight_scale_inv)
        return tensor.to(device)
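load_dequantized_tensor pairs each fp8 ".weight" with its ".weight_scale_inv" and hands both to weight_dequant from fp8gemm.py. The sketch below is a simplified reference of what that dequantization does, assuming 128x128 scale blocks as in the DeepSeek-V3 kernels; the real implementation is a triton kernel, so this is only an approximation of its behavior.

import torch

def weight_dequant_ref(w_fp8: torch.Tensor, scale_inv: torch.Tensor, block: int = 128) -> torch.Tensor:
    # Each entry of scale_inv scales one (block x block) tile of the fp8 weight.
    scale_full = scale_inv.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)
    scale_full = scale_full[: w_fp8.shape[0], : w_fp8.shape[1]]  # trim padding on ragged edges
    return w_fp8.to(torch.float32) * scale_full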
ktransformers/util/utils.py (+13, -2)

@@ -66,12 +66,23 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
     for name, param in local_state.items():
         key = prefix + name
         translated_key = translate_name_to_gguf(key)
-        if translated_key in gguf_loader.tensor_file_map:
+
+        # TODO: Merge all loader.
+        # I know this is ugly but lets do it for now.
+        if gguf_loader.safetensor_loader is not None:
+            load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
+            tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
+        else:
+            load_dequantized_tensor = gguf_loader.load_gguf_tensor
+            tensor_file_map = gguf_loader.tensor_file_map
+
+        if translated_key in tensor_file_map:
             target_dtype = torch.get_default_dtype()
             device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
             print(f"loading {translated_key} to {device}")
             torch.cuda.empty_cache()  # To fit in 16G VRAM. By "wkGCaSS - 知乎 https://zhuanlan.zhihu.com/p/25491611225"
-            weights = gguf_loader.load_gguf_tensor(translated_key, device=device).to(dtype=target_dtype)
+            # weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
+            weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
             set_param(module, name, weights)
             del weights
         else:
 ...
merge_tensors/merge_safetensor_gguf.py (new file, +214)

# this script targets to merge the fp8 safetensors and the gguf quantized tensors.
import os
# insert the path of the project
import sys
sys.path.insert(0, "/home/azure/ktransformers")
import argparse
import torch
from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
from safetensors import safe_open
from safetensors.torch import save_file
import re
from collections import defaultdict

def read_safetensor_keys_from_folder(folder_path) -> dict:
    """
    :param folder_path: folder path
    :return: key_to_file_map
    """
    # check that the folder path exists
    if not os.path.exists(folder_path):
        raise FileNotFoundError(f"GGUF dir not found: {folder_path}")
    if os.path.isfile(folder_path):
        folder_path = os.path.dirname(folder_path)

    key_to_file_map = {}

    found_safetensor = False
    for root, dirs, files in os.walk(folder_path):
        # sort files
        files = sorted(files)
        for file in files:
            if file.endswith(".safetensors"):
                found_safetensor = True
                file_path = os.path.join(root, file)
                try:
                    with safe_open(file_path, framework="pt") as f:
                        for key in f.keys():
                            if "model.layers.61" in key:
                                # skip MTP layer
                                continue
                            # try:
                            #     if int(key.split('.')[2]) > 4:
                            #         continue
                            # except:
                            #     pass
                            key_to_file_map[key] = file_path
                except Exception as e:
                    print(f"Error reading Safetensor file {file_path}: {e}")

    if not found_safetensor:
        raise FileNotFoundError(f"No Safetensor files found in {folder_path}")

    return key_to_file_map

tensor_from_gguf = []
# todo: add keys in gguf that should be used in the final tensor

def translate_name(name: str) -> str:
    """
    :param name: name of the tensor
    :return: translated name
    """
    name = translate_name_to_gguf(name)
    name = name.replace(".up_proj.", ".ffn_up_exps.")
    name = name.replace(".down_proj.", ".ffn_down_exps.")
    name = name.replace(".gate_proj.", ".ffn_gate_exps.")
    name = name.replace(".ffn_gate_inp.e_score_correction_bias", ".exp_probs_b.bias")
    return name

def combine_tensor_sources(safetensor_path: str, gguf_path: str):
    gguf_loader = GGUFLoader(gguf_path)
    gguf_tensor_file_map = gguf_loader.tensor_file_map
    safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path)

    # build a map from key to the file that holds the tensor;
    # according to the key, we can get the tensor from the file
    target_tensor_map = {}
    for key in safetensor_tensor_file_map.keys():
        # for all experts, we use the gguf tensor
        if ".mlp.experts." in key:
            if '.weight_scale_inv' in key:
                continue
            key = '.'.join(key.split('.')[:5] + key.split('.')[-2:])
            translated_key = translate_name(key)
            target_tensor_map[key] = gguf_tensor_file_map[translated_key]
            continue
        if any(target_key in key for target_key in tensor_from_gguf):
            target_tensor_map[key] = gguf_tensor_file_map[translate_name(key)]
        else:
            target_tensor_map[key] = safetensor_tensor_file_map[key]

    return target_tensor_map, gguf_loader

def write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader):
    # Ensure output directory exists
    os.makedirs(output_path, exist_ok=True)

    # Cache for safetensor file handles and GGUF loaders
    safetensors_cache = {}
    gguf_cache = {}

    # Group tensors by layer
    layer_groups = defaultdict(list)
    non_layer_keys = []
    layer_pattern = re.compile(r'\.layers\.(\d+)\.')

    for key in target_tensor_map:
        match = layer_pattern.search(key)
        if match:
            layer_num = int(match.group(1))
            layer_groups[layer_num].append(key)
        else:
            non_layer_keys.append(key)

    # Calculate total shards
    total_shards = len(layer_groups) + (1 if non_layer_keys else 0) - 1
    if total_shards == 0:
        raise ValueError("No tensors to save")

    shard_idx = 0

    # Save non-layer tensors to the first shard if they exist
    if non_layer_keys:
        tensors = {}
        for key in non_layer_keys:
            file_path = target_tensor_map[key]
            tensor = None
            ggml_type = None
            if file_path.endswith('.safetensors'):
                if file_path not in safetensors_cache:
                    safetensors_cache[file_path] = safe_open(file_path, framework='pt')
                f = safetensors_cache[file_path]
                tensor = f.get_tensor(key)
            elif file_path.endswith('.gguf'):
                gguf_name = translate_name(key)
                tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)
            else:
                raise ValueError(f"Unsupported file format: {file_path}")
            tensors[translate_name(key)] = tensor
            if ggml_type:
                ggml_type = torch.tensor(ggml_type)
                ggml_key = translate_name(key)[:-7] + ".ggml_type" if translate_name(key).endswith(".weight") else translate_name(key) + ".ggml_type"
                tensors[ggml_key] = ggml_type

        output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors")
        print(f"Saving non-layer tensors to {output_file}")
        save_file(tensors, output_file)
        print(tensors.keys())
        shard_idx += 1

    # Save each layer's tensors to subsequent shards
    for layer_num in sorted(layer_groups.keys()):
        layer_keys = layer_groups[layer_num]
        tensors = {}
        for key in layer_keys:
            file_path = target_tensor_map[key]
            tensor = None
            ggml_type = None
            if file_path.endswith('.safetensors'):
                if file_path not in safetensors_cache:
                    safetensors_cache[file_path] = safe_open(file_path, framework='pt')
                f = safetensors_cache[file_path]
                tensor = f.get_tensor(key)
                tensor_info = tensor.shape
            elif file_path.endswith('.gguf'):
                gguf_name = translate_name(key)
                tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)
                # tensor_info = gguf_loader.tensor_info[gguf_name]
                # ggml_type = gguf_loader.tensor_info[gguf_name]['ggml_type']
            else:
                raise ValueError(f"Unsupported file format: {file_path}")
            tensors[translate_name(key)] = tensor
            if ggml_type:
                ggml_type = torch.tensor(ggml_type)
                ggml_key = translate_name(key)[:-7] + ".ggml_type" if translate_name(key).endswith(".weight") else translate_name(key) + ".ggml_type"
                tensors[ggml_key] = ggml_type

        output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors")
        print(f"Saving layer {layer_num} to {output_file}")
        print(tensors.keys())
        save_file(tensors, output_file)
        shard_idx += 1

    return

def main():
    # Create the command-line argument parser
    parser = argparse.ArgumentParser(description="Read parameters from Safetensor and GGUF files")
    parser.add_argument("--safetensor_path", type=str, help="Path to the Safetensor file", default="/mnt/data/model/DeepSeek-V3")
    parser.add_argument("--gguf_path", type=str, help="Path to the GGUF file", default="/mnt/data/model/DeepseekV3-q4km-gguf")
    parser.add_argument("--output_path", type=str, help="Path to the output file", default="/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8")

    # print all the arguments
    print("All the arguments:")
    print(parser.parse_args())

    # Parse the command-line arguments
    args = parser.parse_args()
    safetensor_path = args.safetensor_path
    gguf_path = args.gguf_path
    output_path = args.output_path

    target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path)
    write_combined_tensor(target_tensor_map, output_path, gguf_loader)

    return

if __name__ == "__main__":
    main()
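A possible programmatic invocation, equivalent to running the script with its default arguments (a sketch only: it assumes the repository root is on sys.path so the module can be imported, and the paths are the defaults shown above):

import sys
sys.path.insert(0, "/home/azure/ktransformers")  # as in the script header; adjust to your checkout

from merge_tensors.merge_safetensor_gguf import combine_tensor_sources, write_combined_tensor

# Merge the fp8 safetensors checkpoint with the q4_k_m GGUF dump into one hybrid layout
# that the new SafeTensorLoader / GGUFLoader path can consume.
target_tensor_map, gguf_loader = combine_tensor_sources(
    safetensor_path="/mnt/data/model/DeepSeek-V3",
    gguf_path="/mnt/data/model/DeepseekV3-q4km-gguf",
)
write_combined_tensor(target_tensor_map, "/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8", gguf_loader)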