ox696c / ktransformers / Commits

Commit c189d55b, authored Feb 15, 2025 by Atream

toy support for experts on GPU, no CUDA Graph

parent 1548c992
Showing 6 changed files with 201 additions and 67 deletions (+201 −67)
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu (+14 −14)
ktransformers/operators/experts.py (+93 −43)
ktransformers/operators/linear.py (+13 −10)
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml (+49 −0)
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml (+2 −0)
ktransformers/util/custom_gguf.py (+30 −0)
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu

@@ -17,8 +17,8 @@
 #include <c10/cuda/CUDAGuard.h>
 __global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
         for (int i = 0; i < blk_size; i++){
             float scale = scales[block_id];
             output[block_id * blk_size + i] = scale * qs[block_id * blk_size + i];

@@ -37,8 +37,8 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
 }
 __global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
         const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 80)));

@@ -72,10 +72,10 @@ __global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size
 __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     const uint32_t kmask1 = 0x03030303;
     const uint32_t kmask2 = 0x0f0f0f0f;
-    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
         uint32_t aux[4];

@@ -128,8 +128,8 @@ __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size
 __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
         // const uint8_t * q = data[i].qs;
         const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);

@@ -152,8 +152,8 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size
 }
 __global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
         const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 0)));

@@ -181,8 +181,8 @@ __global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size
 }
 __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
         const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 208)));

@@ -215,8 +215,8 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
 static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 __global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
         const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size)));
         const uint16_t scales_h = *(reinterpret_cast<uint16_t*>(data + block_id * blk_size + 2));
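The change in this file is uniform: every dequant kernel's flat index is widened from int to long long, so block_id and products such as block_id * 256 or block_id * blk_size are computed in 64-bit. A quick sanity check of why 32-bit indices are not enough when a whole layer's fused expert tensor is dequantized in one call, using assumed DeepSeek-V3 sizes (256 routed experts, hidden_size 7168, moe_intermediate_size 2048; these come from the model config, not from this diff):

# Element count of one fused expert tensor (experts x intermediate x hidden), assumed sizes.
experts, intermediate, hidden = 256, 2048, 7168
elements = experts * intermediate * hidden      # 3,758,096,384
int32_max = 2**31 - 1                           # 2,147,483,647
print(elements > int32_max)                     # True: a 32-bit flat index would overflow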
ktransformers/operators/experts.py

@@ -18,6 +18,7 @@ import torch.nn.functional as F
 import torch
 import sys, os
 from ktransformers.operators.base_operator import BaseInjectedModule
+from tqdm import tqdm
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))

@@ -225,6 +226,7 @@ class KExpertsCPU(KExpertsBase):
         return

     def load_weights(self, override_key: str | None = None, device: str = "cpu"):
+        # TODO: support Bias
         res = {}
         if override_key is not None:
             keys = override_key

@@ -288,6 +290,8 @@ class KExpertsMarlin(KExpertsBase):
         self.act_fn = ACT2FN[config.hidden_act]
         assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
         self.device = device
+        self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size
         # create empty marlin experts according to the number of experts per token
         # up
         self.up_projs = [KLinearMarlin(key + "." + "ffn_up_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]

@@ -299,17 +303,34 @@ class KExpertsMarlin(KExpertsBase):
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
         if device is None: device = self.device
         assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
-        if w is None: w = self.load_weights()[self.key]
-        if isinstance(w, dict):
-            self.gate = w["gate"]
-            self.up = (w["up"])
-            self.down = (w["down"])
-        for i in range(self.expert_num):
-            self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
-            self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
-            self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
-            self.loaded_experts_idx.append(i)
+        if w is None: w = self.load_weights()
+        load_by_experts = True
+
+        if load_by_experts:
+            if isinstance(w, dict):
+                self.gate = w["gate"]
+                self.up = (w["up"])
+                self.down = (w["down"])
+                for i in tqdm(range(self.expert_num), desc=f"Dequanting and quanting for KExpertsMarlin {self.key}"):
+                    up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", self.up, i, self.elements_per_tensor, device=self.device)
+                    gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", self.gate, i, self.elements_per_tensor, device=self.device)
+                    down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", self.down, i, self.elements_per_tensor, device=self.device)
+                    self.up_projs[i].load(nn.Parameter(up_weights), device=device)
+                    self.gate_projs[i].load(nn.Parameter(gate_weights), device=device)
+                    self.down_projs[i].load(nn.Parameter(down_weights), device=device)
+                    self.loaded_experts_idx.append(i)
+        else:
+            if isinstance(w, dict):
+                self.gate = w["gate"]
+                self.up = (w["up"])
+                self.down = (w["down"])
+                for i in range(self.expert_num):
+                    self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
+                    self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
+                    self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
+                    self.loaded_experts_idx.append(i)
         return

     def unload(self):

@@ -329,20 +350,13 @@ class KExpertsMarlin(KExpertsBase):
         gate = None
         up = None
         down = None
-        gate_type = None
-        up_type = None
-        down_type = None

         for key in keys:
             if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
-                gate = self.gguf_loader.load_gguf_tensor(key + ".ffn_gate_exps.weight")
-                up = self.gguf_loader.load_gguf_tensor(key + ".ffn_up_exps.weight")
-                down = self.gguf_loader.load_gguf_tensor(key + ".ffn_down_exps.weight")
-                gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"]
-                up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
-                down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
-            # tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"])
-            res = {key:{"gate": nn.Parameter(gate), "up": nn.Parameter(up), "down": nn.Parameter(down), "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
+                gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
+                up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
+                down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
+            res = {"gate": gate, "up": up, "down": down}
         return res

     def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:

@@ -381,6 +395,7 @@ class KExpertsMarlin(KExpertsBase):
         return final_hidden_states.to(dtype=org_dtype, device=org_device)

+# untested, CUDA OOM
 class KExpertsTorch(KExpertsBase):
     expert_num: int
     loaded_experts_idx: list[int]

@@ -402,19 +417,39 @@ class KExpertsTorch(KExpertsBase):
         # self.loaded_experts_idx = []
         self.act_fn = ACT2FN[config.hidden_act]
         self.device = device
-        self.gate = None
-        self.up = None
-        self.donw = None
+        self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size
+        self.gate = [None for _ in range(self.expert_num)]
+        self.up = [None for _ in range(self.expert_num)]
+        self.down = [None for _ in range(self.expert_num)]
         self.dtype = torch.get_default_dtype()

     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
         if device is None: device = self.device
-        if w is None: w = self.load_weights(device=device)[self.key]
-        if isinstance(w, dict):
-            self.gate = w["gate"].to(device=device, dtype=self.dtype)
-            self.up = w["up"].to(device=device, dtype=self.dtype)
-            self.down = w["down"].to(device=device, dtype=self.dtype)
+        if w is None: w = self.load_weights()
+        load_by_experts = True
+
+        if load_by_experts:
+            if isinstance(w, dict):
+                for i in tqdm(range(self.expert_num), desc=f"Dequanting for KExpertsTorch {self.key}"):
+                    up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", w["up"], i, self.elements_per_tensor, device=self.device)
+                    gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", w["gate"], i, self.elements_per_tensor, device=self.device)
+                    down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", w["down"], i, self.elements_per_tensor, device=self.device)
+                    self.up[i] = up_weights
+                    self.gate[i] = gate_weights
+                    self.down[i] = down_weights
+        else:
+            if isinstance(w, dict):
+                for i in range(self.expert_num):
+                    self.gate[i] = w["gate"][i, ...].to(device=device, dtype=self.dtype)
+                    self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
+                    self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)
+        self.up = torch.cat(self.gate, dim=0)
+        self.gate = torch.cat(self.gate, dim=0)
+        self.down = torch.cat(self.gate, dim=0)
+        return

     def unload(self):
         if self.gate is not None:

@@ -422,6 +457,25 @@ class KExpertsTorch(KExpertsBase):
         self.up = None
         self.down = None

+    def load_weights(self, override_key: str | None = None):
+        res = {}
+        if override_key is not None:
+            keys = override_key
+        else:
+            keys = [self.key]
+
+        gate = None
+        up = None
+        down = None
+
+        for key in keys:
+            if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
+                gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
+                up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
+                down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
+            res = {"gate": gate, "up": up, "down": down}
+        return res
+
     def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
         org_device = hidden_states_cpu.device

@@ -582,7 +636,7 @@ class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
         if isinstance(self.experts, KExpertsBase):
             y = (
-                self.moe_on_cpuinfer(
+                self.moe_kexperts(
                     hidden_states_expert, selected_experts_expert, routing_weights_expert
                 )
                 .view(*orig_shape)

@@ -601,8 +655,7 @@ class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
         return y, router_logits

     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs

@@ -672,7 +725,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
             y_ = self.shared_experts(identity).squeeze(0)
         if isinstance(self.experts, KExpertsBase):
-            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
+            y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
         elif hidden_states.size(0) > 10:
             # TODO may bugs here
             y = (

@@ -692,8 +745,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
         return y

     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs

@@ -773,7 +825,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
             y_ = self.shared_experts(identity).squeeze(0)
         if isinstance(self.experts, KExpertsBase):
-            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
+            y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
         elif hidden_states.size(0) > 10:
             # TODO may bugs here
             y = (

@@ -793,8 +845,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
         return y

     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs

@@ -881,7 +932,7 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
         if isinstance(self.experts, KExpertsBase):
             y = (
-                self.moe_on_cpuinfer(
+                self.moe_kexperts(
                     hidden_states_expert, selected_experts_expert, routing_weights_expert
                 )
                 .view(*orig_shape)

@@ -900,8 +951,7 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
         return y, router_logits

     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs
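The reworked load() paths in KExpertsMarlin and KExpertsTorch share one idea: instead of dequantizing a whole fused expert tensor at once, they keep the GGUF data memory-mapped (get_mmap_tensor), then loop over experts and dequantize one expert slice at a time via gguf_loader.load_expert_tensor, handing each result to a per-expert projection. This is why the tqdm progress bar was added: with hundreds of routed experts per layer, each one is dequantized (and, for Marlin, re-quantized) individually. A simplified sketch of that pattern; dequantize_one_expert and consume are placeholder names for illustration, not part of the ktransformers API:

from typing import Callable
import numpy as np

def load_experts_streaming(
    fused_mmap: np.ndarray,                       # raw bytes of the fused, quantized expert tensor
    expert_num: int,
    bytes_per_expert: int,                        # blocks_per_expert * block_size
    dequantize_one_expert: Callable[[np.ndarray], np.ndarray],
    consume: Callable[[int, np.ndarray], None],   # e.g. hand one expert to a per-expert linear
) -> None:
    # Peak memory stays at one dequantized expert instead of all expert_num of them.
    for expert_id in range(expert_num):
        start = expert_id * bytes_per_expert
        raw = fused_mmap[start : start + bytes_per_expert]
        consume(expert_id, dequantize_one_expert(raw))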
ktransformers/operators/linear.py

@@ -119,7 +119,7 @@ class KLinearTorch(KLinearBase):
         super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
         self.has_bias = False
         self.dtype = torch.get_default_dtype()
-        self.w = None
+        self.weight = None
         self.has_bias = False

     def forward(self, x: torch.Tensor) -> torch.Tensor:

@@ -127,7 +127,7 @@ class KLinearTorch(KLinearBase):
         out_device = x.device
         # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
         x = x.to(device=self.device, dtype=self.dtype)
-        x = x @ self.w
+        x = x @ self.weight
         if self.has_bias:
             x = x + self.bias
         x = x.to(dtype=dtype, device=out_device)

@@ -140,27 +140,27 @@ class KLinearTorch(KLinearBase):
         if isinstance(w, nn.Parameter):
             try:
-                self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
+                self.weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
             except:
-                self.w = w.to(dtype=self.dtype).T
+                self.weight = w.to(dtype=self.dtype).T
             self.has_bias = False
         elif isinstance(w, tuple):
             try:
-                self.w = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
+                self.weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
             except:
-                self.w = w[0].to(dtype=self.dtype).T
+                self.weight = w[0].to(dtype=self.dtype).T
             self.bias = w[1].to(dtype=self.dtype)
             self.has_bias = True
         else:
             raise ValueError("Invalid weight type")
         # self.linear = self.linear.to(device)
-        self.w = self.w.to(device)
+        self.weight = self.weight.to(device)
         if self.has_bias:
             self.bias = self.bias.to(device)

     def unload(self):
-        if self.w is not None:
-            self.w = None
+        if self.weight is not None:
+            self.weight = None
         if self.has_bias:
             self.bias = None

@@ -218,6 +218,7 @@ class KLinearMarlin(KLinearBase):
         self.workspace = MarlinWorkspace(
             self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL, self.device
         )
+        self.weight = marlin_q_w # modeling_xxx.py may use linear.weight
         self.marlin_q_w = marlin_q_w
         self.marlin_s = marlin_s
         self.g_idx = g_idx

@@ -424,11 +425,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
         if mode == InferenceState.PREFILL:
             self.generate_linear.unload()
             self.prefill_linear.load(w=w)
             self.device = self.prefill_linear.device
+            self.weight = self.prefill_linear.weight # modeling_xxx.py may use linear.weight
         elif mode == InferenceState.GENERATE:
             self.prefill_linear.unload()
             self.generate_linear.load(w=w)
             self.device = self.generate_linear.device
+            self.weight = self.generate_linear.weight # modeling_xxx.py may use linear.weight
         elif mode == InferenceState.UNLOAD:
             self.prefill_linear.unload()
             self.generate_linear.unload()
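The rename from self.w to self.weight, together with the new self.weight assignments in KLinearMarlin and KTransformersLinear, follows the inline comment "modeling_xxx.py may use linear.weight": model code outside the injected module often introspects a linear layer's .weight attribute (for dtype, device, or shape) and breaks if the attribute is missing. A minimal sketch of that interaction, with InjectedLinear as a hypothetical stand-in for the real classes:

import torch

class InjectedLinear(torch.nn.Module):
    # Hypothetical stand-in: keep the packed tensor under the conventional
    # name `weight` so external code can still introspect it.
    def __init__(self, in_features: int, out_features: int, dtype=torch.float16):
        super().__init__()
        self.weight = torch.zeros(in_features, out_features, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(self.weight.dtype) @ self.weight

layer = InjectedLinear(16, 8)
print(layer.weight.dtype, layer.weight.device)   # what modeling code typically reads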
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml

@@ -182,6 +182,53 @@
     generate_device: "cuda:3"
     prefill_device: "cuda:3"

+# === MLP Experts Replacement ===
+# replace with marlin expert. Open and modify layer-num as needed.
+# Each layer of marlin experts takes about 6GB of GPU memory.
+# !!!Do remember 'close' cuda graph if you are using marlin expert.!!!
+# !!!KExpertsTorch is untested, we don't have enough VRAM.!!!
+
+# # GPU 0: layers 3–4
+# - match:
+#     name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$"
+#   replace:
+#     class: ktransformers.operators.experts.KTransformersExperts
+#     kwargs:
+#       generate_device: "cuda:0"
+#       generate_op: "KExpertsMarlin"
+#   recursive: False
+
+# # GPU 1: layers 15–17
+# - match:
+#     name: "^model\\.layers\\.(1[5-7])\\.mlp\\.experts$"
+#   replace:
+#     class: ktransformers.operators.experts.KTransformersExperts
+#     kwargs:
+#       generate_device: "cuda:1"
+#       generate_op: "KExpertsMarlin"
+#   recursive: False
+
+# # GPU 2: layers 30–32
+# - match:
+#     name: "^model\\.layers\\.(3[0-2])\\.mlp\\.experts$"
+#   replace:
+#     class: ktransformers.operators.experts.KTransformersExperts
+#     kwargs:
+#       generate_device: "cuda:2"
+#       generate_op: "KExpertsMarlin"
+#   recursive: False
+
+# # GPU 3: layers 45–46
+# - match:
+#     name: "^model\\.layers\\.(4[5-6])\\.mlp\\.experts$"
+#   replace:
+#     class: ktransformers.operators.experts.KTransformersExperts
+#     kwargs:
+#       generate_device: "cuda:3"
+#       generate_op: "KExpertsMarlin"
+#   recursive: False
+
 # === MLP Experts Replacement ===
 # GPU 0: layers 0–14

@@ -316,6 +363,8 @@
     generate_device: "cuda:2"
     prefill_device: "cuda:2"

+# don't inject lm_head if already inject marlin experts
 # For final modules (model.norm and lm_head), ensure they are on GPU 3 (as in your original config)
 - match:
     name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)"
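The new comments budget roughly 6 GB of GPU memory per layer of marlin experts, which is why each commented-out rule only moves two or three layers to a given device. A rough sizing sketch under assumed numbers (a 24 GB card with 10 GB already taken by attention and other weights are assumptions; only the ~6 GB/layer figure comes from the comment above):

# Back-of-the-envelope: how many marlin-expert layers fit on one GPU (assumed numbers).
gpu_vram_gb = 24           # assumed card size
reserved_gb = 10           # assumed budget for attention, shared weights, cache
per_layer_gb = 6           # figure quoted in the config comment
print((gpu_vram_gb - reserved_gb) // per_layer_gb)   # -> 2 layers per GPU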
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml

@@ -713,6 +713,8 @@
     generate_device: "cuda:7"
     prefill_device: "cuda:7"

+# don't inject lm_head if already inject marlin experts
 # For final modules (model.norm and lm_head), ensure they are on GPU 7 (as in your original config)
 - match:
     name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)"
ktransformers/util/custom_gguf.py

@@ -276,8 +276,38 @@ class GGUFLoader:
         itemsize = int(np.empty([], dtype=item_type).itemsize)
         return mmap_data[offset : offset + itemsize * item_count]

+    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device="gpu") -> torch.Tensor:
+        t = self.tensor_info[name]
+        if device.lower() == "cpu":
+            print(f"loading expert {expert_id} of {name} with CPU")
+        shape = t["shape"]
+        ggml_type = t["ggml_type"]
+        if ggml_type not in GGML_NAMES:
+            raise NotImplementedError(f"ggml_type {ggml_type} not implemented")
+        ggml_name = GGML_NAMES[ggml_type]
+
+        # TODO: experts may fused in quant block, split it
+        assert elements_per_expert % GGML_ELEMENTS_PER_BLOCK[ggml_name] == 0, "experts may fused in quant block, please use CPU dequant"
+        blocks_per_experts = elements_per_expert // GGML_ELEMENTS_PER_BLOCK[ggml_name]
+        block_size = GGML_BLOCK_SIZES[ggml_name]
+        offset = expert_id * block_size * blocks_per_experts
+        data = data[offset: offset + block_size * blocks_per_experts]
+
+        if "cuda" in device.lower():
+            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
+        else:
+            values = GGML_DEQUANTIZE[ggml_name](data)
+            values = torch.from_numpy(values)
+
+        values = values.view(shape[-2::-1])
+        return values
+
     def load_gguf_tensor(self, name: str, device: str = "cpu") -> torch.Tensor:
         t = self.tensor_info[name]
+        if device.lower() == "cpu":
+            print(f"loading {name} with CPU")
         shape = t["shape"]
         ggml_type = t["ggml_type"]
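load_expert_tensor locates one expert inside the fused, still-quantized tensor by pure block arithmetic: the expert's element count must divide evenly into GGML blocks (hence the assert), and its byte range is then expert_id * blocks_per_expert * block_size. A worked example of that arithmetic, assuming Q4_K quantization (256 elements and 144 bytes per block, the same constants the q4_k kernel in dequant.cu uses) and an assumed DeepSeek-V3 expert of 2048 x 7168 elements (model dimensions not stated in this diff):

# Per-expert slicing arithmetic for a Q4_K fused expert tensor (assumed sizes).
elements_per_block = 256                    # k-quant super-blocks hold 256 elements
bytes_per_block = 144                       # Q4_K block size in bytes
elements_per_expert = 2048 * 7168           # moe_intermediate_size * hidden_size

assert elements_per_expert % elements_per_block == 0
blocks_per_expert = elements_per_expert // elements_per_block    # 57,344 blocks
bytes_per_expert = blocks_per_expert * bytes_per_block           # 8,257,536 bytes (~7.9 MiB)
offset = 5 * bytes_per_expert                                    # start of expert_id == 5
print(blocks_per_expert, bytes_per_expert, offset)

If an expert ever shared a quant block with its neighbour, this slice-by-offset trick would split a block in half, which is exactly the case the assert rules out by falling back to CPU dequantization.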