OpenDAS / ktransformers · Commit c189d55b
Authored Feb 15, 2025 by Atream
Parent: 1548c992

toy support for experts on GPU, no CUDA Graph
Showing 6 changed files with 201 additions and 67 deletions
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu  (+14 −14)
ktransformers/operators/experts.py  (+93 −43)
ktransformers/operators/linear.py  (+13 −10)
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml  (+49 −0)
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml  (+2 −0)
ktransformers/util/custom_gguf.py  (+30 −0)
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu

```diff
@@ -17,8 +17,8 @@
 #include <c10/cuda/CUDAGuard.h>
 
 __global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
         for (int i = 0; i < blk_size; i++){
             float scale = scales[block_id];
             output[block_id * blk_size + i] = scale * qs[block_id * blk_size + i];
@@ -37,8 +37,8 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
 }
 
 __global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
         const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 80)));
@@ -72,10 +72,10 @@ __global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size
 
 __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     const uint32_t kmask1 = 0x03030303;
     const uint32_t kmask2 = 0x0f0f0f0f;
-    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
         uint32_t aux[4];
@@ -128,8 +128,8 @@ __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size
 
 __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
         // const uint8_t * q = data[i].qs;
         const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);
@@ -152,8 +152,8 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size
 }
 
 __global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
         const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 0)));
@@ -181,8 +181,8 @@ __global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size
 }
 
 __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
         const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 208)));
@@ -215,8 +215,8 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
 static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
 __global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
         const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size)));
         const uint16_t scales_h = *(reinterpret_cast<uint16_t*>(data + block_id * blk_size + 2));
```
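The only change in this file is widening the grid-stride index in every dequant kernel from int to long long. A quick back-of-the-envelope check (not part of the commit; the tensor sizes below are assumed DeepSeek-V3 shapes) shows why 32-bit indexing overflows once a fused expert tensor is dequantized as a single buffer:

```python
# Assumed sizes for one fused DeepSeek-V3 expert tensor: 256 experts,
# moe_intermediate_size 2048, hidden_size 7168 (not taken from the commit).
INT32_MAX = 2**31 - 1

num_elements = 256 * 2048 * 7168        # ~3.76e9 elements in one fused tensor
blk_size = 32                           # Q8_0: 32 weights per quant block
num_blocks = num_elements // blk_size

# The kernels index output[block_id * blk_size + i]; with a 32-bit block_id
# the largest index would wrap around long before the end of the tensor.
largest_index = (num_blocks - 1) * blk_size + (blk_size - 1)
print(largest_index, largest_index > INT32_MAX)   # 3758096383 True
```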
ktransformers/operators/experts.py

```diff
@@ -18,6 +18,7 @@ import torch.nn.functional as F
 import torch
 import sys, os
 from ktransformers.operators.base_operator import BaseInjectedModule
+from tqdm import tqdm
 
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
@@ -225,6 +226,7 @@ class KExpertsCPU(KExpertsBase):
         return
 
     def load_weights(self, override_key: str | None = None, device: str = "cpu"):
+        # TODO: support Bias
         res = {}
         if override_key is not None:
             keys = override_key
@@ -288,6 +290,8 @@ class KExpertsMarlin(KExpertsBase):
         self.act_fn = ACT2FN[config.hidden_act]
         assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
         self.device = device
+        self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size
+
         # create empty marlin experts according to the number of experts per token
         # up
         self.up_projs = [KLinearMarlin(key + "." + "ffn_up_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
@@ -299,17 +303,34 @@ class KExpertsMarlin(KExpertsBase):
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
         if device is None: device = self.device
         assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
-        if w is None: w = self.load_weights()[self.key]
-        if isinstance(w, dict):
-            self.gate = w["gate"]
-            self.up = (w["up"])
-            self.down = (w["down"])
-            for i in range(self.expert_num):
-                self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
-                self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
-                self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
-                self.loaded_experts_idx.append(i)
+        if w is None:
+            w = self.load_weights()
+            load_by_experts = True
+
+        if load_by_experts:
+            if isinstance(w, dict):
+                self.gate = w["gate"]
+                self.up = (w["up"])
+                self.down = (w["down"])
+                for i in tqdm(range(self.expert_num), desc=f"Dequanting and quanting for KExpertsMarlin {self.key}"):
+                    up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", self.up, i, self.elements_per_tensor, device=self.device)
+                    gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", self.gate, i, self.elements_per_tensor, device=self.device)
+                    down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", self.down, i, self.elements_per_tensor, device=self.device)
+                    self.up_projs[i].load(nn.Parameter(up_weights), device=device)
+                    self.gate_projs[i].load(nn.Parameter(gate_weights), device=device)
+                    self.down_projs[i].load(nn.Parameter(down_weights), device=device)
+                    self.loaded_experts_idx.append(i)
+        else:
+            if isinstance(w, dict):
+                self.gate = w["gate"]
+                self.up = (w["up"])
+                self.down = (w["down"])
+                for i in range(self.expert_num):
+                    self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
+                    self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
+                    self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
+                    self.loaded_experts_idx.append(i)
         return
 
     def unload(self):
@@ -329,20 +350,13 @@ class KExpertsMarlin(KExpertsBase):
         gate = None
         up = None
         down = None
-        gate_type = None
-        up_type = None
-        down_type = None
 
         for key in keys:
             if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
-                gate = self.gguf_loader.load_gguf_tensor(key + ".ffn_gate_exps.weight")
-                up = self.gguf_loader.load_gguf_tensor(key + ".ffn_up_exps.weight")
-                down = self.gguf_loader.load_gguf_tensor(key + ".ffn_down_exps.weight")
-                gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"]
-                up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
-                down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
-                # tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"])
-            res = {key:{"gate": nn.Parameter(gate), "up": nn.Parameter(up), "down": nn.Parameter(down), "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
+                gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
+                up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
+                down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
+            res = {"gate": gate, "up": up, "down": down}
         return res
 
     def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
@@ -381,6 +395,7 @@ class KExpertsMarlin(KExpertsBase):
         return final_hidden_states.to(dtype=org_dtype, device=org_device)
 
+# untested, CUDA OOM
 class KExpertsTorch(KExpertsBase):
     expert_num: int
     loaded_experts_idx: list[int]
@@ -402,19 +417,39 @@ class KExpertsTorch(KExpertsBase):
         # self.loaded_experts_idx = []
         self.act_fn = ACT2FN[config.hidden_act]
         self.device = device
-        self.gate = None
-        self.up = None
-        self.donw = None
+        self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size
+        self.gate = [None for _ in range(self.expert_num)]
+        self.up = [None for _ in range(self.expert_num)]
+        self.down = [None for _ in range(self.expert_num)]
         self.dtype = torch.get_default_dtype()
 
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
         if device is None: device = self.device
-        if w is None: w = self.load_weights(device=device)[self.key]
-        if isinstance(w, dict):
-            self.gate = w["gate"].to(device=device, dtype=self.dtype)
-            self.up = w["up"].to(device=device, dtype=self.dtype)
-            self.down = w["down"].to(device=device, dtype=self.dtype)
+        if w is None:
+            w = self.load_weights()
+            load_by_experts = True
+
+        if load_by_experts:
+            if isinstance(w, dict):
+                for i in tqdm(range(self.expert_num), desc=f"Dequanting for KExpertsTorch {self.key}"):
+                    up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", w["up"], i, self.elements_per_tensor, device=self.device)
+                    gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", w["gate"], i, self.elements_per_tensor, device=self.device)
+                    down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", w["down"], i, self.elements_per_tensor, device=self.device)
+                    self.up[i] = up_weights
+                    self.gate[i] = gate_weights
+                    self.down[i] = down_weights
+        else:
+            if isinstance(w, dict):
+                for i in range(self.expert_num):
+                    self.gate[i] = w["gate"][i, ...].to(device=device, dtype=self.dtype)
+                    self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
+                    self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)
+        self.up = torch.cat(self.gate, dim=0)
+        self.gate = torch.cat(self.gate, dim=0)
+        self.down = torch.cat(self.gate, dim=0)
+        return
 
     def unload(self):
         if self.gate is not None:
@@ -422,6 +457,25 @@ class KExpertsTorch(KExpertsBase):
         self.up = None
         self.down = None
 
+    def load_weights(self, override_key: str | None = None):
+        res = {}
+        if override_key is not None:
+            keys = override_key
+        else:
+            keys = [self.key]
+
+        gate = None
+        up = None
+        down = None
+
+        for key in keys:
+            if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
+                gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
+                up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
+                down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
+            res = {"gate": gate, "up": up, "down": down}
+        return res
+
     def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
         org_device = hidden_states_cpu.device
@@ -582,7 +636,7 @@ class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
         if isinstance(self.experts, KExpertsBase):
             y = (
-                self.moe_on_cpuinfer(
+                self.moe_kexperts(
                     hidden_states_expert, selected_experts_expert, routing_weights_expert
                 )
                 .view(*orig_shape)
@@ -601,8 +655,7 @@ class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
         return y, router_logits
 
     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs
@@ -672,7 +725,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
             y_ = self.shared_experts(identity).squeeze(0)
         if isinstance(self.experts, KExpertsBase):
-            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
+            y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
         elif hidden_states.size(0) > 10:
             # TODO may bugs here
             y = (
@@ -692,8 +745,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
         return y
 
     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs
@@ -773,7 +825,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
             y_ = self.shared_experts(identity).squeeze(0)
         if isinstance(self.experts, KExpertsBase):
-            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
+            y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
         elif hidden_states.size(0) > 10:
             # TODO may bugs here
             y = (
@@ -793,8 +845,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
         return y
 
     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs
@@ -881,7 +932,7 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
         if isinstance(self.experts, KExpertsBase):
             y = (
-                self.moe_on_cpuinfer(
+                self.moe_kexperts(
                     hidden_states_expert, selected_experts_expert, routing_weights_expert
                 )
                 .view(*orig_shape)
@@ -900,8 +951,7 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
         return y, router_logits
 
     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs
```
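The substantive change here is that KExpertsMarlin and KExpertsTorch no longer dequantize a whole fused expert tensor up front; they pull one expert's slice at a time through the new GGUFLoader.load_expert_tensor and load it into that expert's projection. A minimal sketch of that loading pattern follows (not the repository's code; the loader and projection names simply mirror the diff):

```python
# Sketch only: per-expert loading keeps peak memory near one expert's weights.
import torch.nn as nn

def load_experts_one_by_one(gguf_loader, key, up_projs, expert_num,
                            elements_per_tensor, device="cuda:0"):
    # raw (still quantized) bytes of the fused tensor, memory-mapped
    raw = gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
    for i in range(expert_num):
        # dequantize only expert i's slice, directly on the target device
        w_i = gguf_loader.load_expert_tensor(
            key + ".ffn_up_exps.weight", raw, i, elements_per_tensor, device=device
        )
        # the temporary slice can be released once the projection has consumed it
        up_projs[i].load(nn.Parameter(w_i), device=device)
```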
ktransformers/operators/linear.py

```diff
@@ -119,7 +119,7 @@ class KLinearTorch(KLinearBase):
         super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
         self.has_bias = False
         self.dtype = torch.get_default_dtype()
-        self.w = None
+        self.weight = None
         self.has_bias = False
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -127,7 +127,7 @@ class KLinearTorch(KLinearBase):
         out_device = x.device
         # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
         x = x.to(device=self.device, dtype=self.dtype)
-        x = x @ self.w
+        x = x @ self.weight
         if self.has_bias:
             x = x + self.bias
         x = x.to(dtype=dtype, device=out_device)
@@ -140,27 +140,27 @@ class KLinearTorch(KLinearBase):
         if isinstance(w, nn.Parameter):
             try:
-                self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
+                self.weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
             except:
-                self.w = w.to(dtype=self.dtype).T
+                self.weight = w.to(dtype=self.dtype).T
             self.has_bias = False
         elif isinstance(w, tuple):
             try:
-                self.w = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
+                self.weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
             except:
-                self.w = w[0].to(dtype=self.dtype).T
+                self.weight = w[0].to(dtype=self.dtype).T
             self.bias = w[1].to(dtype=self.dtype)
             self.has_bias = True
         else:
             raise ValueError("Invalid weight type")
         # self.linear = self.linear.to(device)
-        self.w = self.w.to(device)
+        self.weight = self.weight.to(device)
         if self.has_bias:
             self.bias = self.bias.to(device)
 
     def unload(self):
-        if self.w is not None:
-            self.w = None
+        if self.weight is not None:
+            self.weight = None
         if self.has_bias:
             self.bias = None
@@ -218,6 +218,7 @@ class KLinearMarlin(KLinearBase):
         self.workspace = MarlinWorkspace(
             self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL, self.device
         )
+        self.weight = marlin_q_w # modeling_xxx.py may use linear.weight
         self.marlin_q_w = marlin_q_w
         self.marlin_s = marlin_s
         self.g_idx = g_idx
@@ -424,11 +425,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
         if mode == InferenceState.PREFILL:
             self.generate_linear.unload()
             self.prefill_linear.load(w=w)
             self.device = self.prefill_linear.device
+            self.weight = self.prefill_linear.weight # modeling_xxx.py may use linear.weight
         elif mode == InferenceState.GENERATE:
             self.prefill_linear.unload()
             self.generate_linear.load(w=w)
             self.device = self.generate_linear.device
+            self.weight = self.generate_linear.weight # modeling_xxx.py may use linear.weight
         elif mode == InferenceState.UNLOAD:
             self.prefill_linear.unload()
             self.generate_linear.unload()
```
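Besides renaming KLinearTorch's internal buffer from self.w to self.weight, the diff exposes a .weight attribute on KLinearMarlin and KTransformersLinear because, per the inline comment, modeling code may reach for linear.weight directly. A small illustration of that expectation (hypothetical calling code, not from the repository):

```python
# Hypothetical caller: model-side code that assumes any injected linear still
# behaves like nn.Linear as far as `.weight` access is concerned.
def describe_linear(linear) -> str:
    # would raise AttributeError against a wrapper that only keeps `.w`
    w = linear.weight
    return f"{type(linear).__name__}: weight on {w.device}, shape {tuple(w.shape)}"
```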
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml

```diff
@@ -182,6 +182,53 @@
       generate_device: "cuda:3"
       prefill_device: "cuda:3"
 
+# === MLP Experts Replacement ===
+# replace with marlin expert. Open and modify layer-num as needed.
+# Each layer of malin experts takes about 6GB of GPU memory.
+# !!!Do remember 'close' cuda graph if you are using marlin expert.!!!
+# !!!KExpertsTorch is untested, we don't have enough VRAM.!!!
+
+# # GPU 0: layers 3–4
+# - match:
+#     name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$"
+#   replace:
+#     class: ktransformers.operators.experts.KTransformersExperts
+#     kwargs:
+#       generate_device: "cuda:0"
+#       generate_op: "KExpertsMarlin"
+#   recursive: False
+
+# # GPU 1: layers 15–17
+# - match:
+#     name: "^model\\.layers\\.(1[5-7])\\.mlp\\.experts$"
+#   replace:
+#     class: ktransformers.operators.experts.KTransformersExperts
+#     kwargs:
+#       generate_device: "cuda:1"
+#       generate_op: "KExpertsMarlin"
+#   recursive: False
+
+# # GPU 2: layers 30–32
+# - match:
+#     name: "^model\\.layers\\.(3[0-2])\\.mlp\\.experts$"
+#   replace:
+#     class: ktransformers.operators.experts.KTransformersExperts
+#     kwargs:
+#       generate_device: "cuda:2"
+#       generate_op: "KExpertsMarlin"
+#   recursive: False
+
+# # GPU 3: layers 45–46
+# - match:
+#     name: "^model\\.layers\\.(4[5-6])\\.mlp\\.experts$"
+#   replace:
+#     class: ktransformers.operators.experts.KTransformersExperts
+#     kwargs:
+#       generate_device: "cuda:3"
+#       generate_op: "KExpertsMarlin"
+#   recursive: False
+
 # === MLP Experts Replacement ===
 # GPU 0: layers 0–14
@@ -316,6 +363,8 @@
       generate_device: "cuda:2"
       prefill_device: "cuda:2"
 
+# don't inject lm_head if already inject marlin experts
+
 # For final modules (model.norm and lm_head), ensure they are on GPU 3 (as in your original config)
 - match:
     name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)"
```
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml

```diff
@@ -713,6 +713,8 @@
       generate_device: "cuda:7"
       prefill_device: "cuda:7"
 
+# don't inject lm_head if already inject marlin experts
+
 # For final modules (model.norm and lm_head), ensure they are on GPU 7 (as in your original config)
 - match:
     name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)"
```
ktransformers/util/custom_gguf.py

```diff
@@ -276,8 +276,38 @@ class GGUFLoader:
         itemsize = int(np.empty([], dtype=item_type).itemsize)
         return mmap_data[offset : offset + itemsize * item_count]
 
+    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "gpu") -> torch.Tensor:
+        t = self.tensor_info[name]
+        if device.lower() == "cpu":
+            print(f"loading expert {expert_id} of {name} with CPU")
+        shape = t["shape"]
+        ggml_type = t["ggml_type"]
+        if ggml_type not in GGML_NAMES:
+            raise NotImplementedError(f"ggml_type {ggml_type} not implemented")
+        ggml_name = GGML_NAMES[ggml_type]
+
+        # TODO: experts may fused in quant block, split it
+        assert elements_per_expert % GGML_ELEMENTS_PER_BLOCK[ggml_name] == 0, "experts may fused in quant block, please use CPU dequant"
+        blocks_per_experts = elements_per_expert // GGML_ELEMENTS_PER_BLOCK[ggml_name]
+        block_size = GGML_BLOCK_SIZES[ggml_name]
+        offset = expert_id * block_size * blocks_per_experts
+        data = data[offset: offset + block_size * blocks_per_experts]
+
+        if "cuda" in device.lower():
+            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
+        else:
+            values = GGML_DEQUANTIZE[ggml_name](data)
+            values = torch.from_numpy(values)
+
+        values = values.view(shape[-2::-1])
+
+        return values
+
     def load_gguf_tensor(self, name: str, device: str = "cpu") -> torch.Tensor:
         t = self.tensor_info[name]
+        if device.lower() == "cpu":
+            print(f"loading {name} with CPU")
         shape = t["shape"]
         ggml_type = t["ggml_type"]
```
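The new load_expert_tensor slices the fused expert tensor at quant-block granularity before dequantizing. A worked example of the offset arithmetic (the Q4_K constants match the 144-byte, 256-element blocks used by the CUDA kernels above; the per-expert shape is an assumed DeepSeek-V3 size, not taken from the commit):

```python
# Assumed: one expert's up/gate/down matrix has 2048 * 7168 elements.
elements_per_expert = 2048 * 7168

# Q4_K: 256 elements per quant block, 144 bytes per block.
GGML_ELEMENTS_PER_BLOCK_Q4_K = 256
GGML_BLOCK_SIZE_Q4_K = 144

assert elements_per_expert % GGML_ELEMENTS_PER_BLOCK_Q4_K == 0   # same guard as the diff
blocks_per_expert = elements_per_expert // GGML_ELEMENTS_PER_BLOCK_Q4_K  # 57344 blocks
bytes_per_expert = blocks_per_expert * GGML_BLOCK_SIZE_Q4_K              # 8257536 bytes (~7.9 MiB)

expert_id = 5
offset = expert_id * bytes_per_expert
# data[offset : offset + bytes_per_expert] is exactly expert 5's still-quantized
# slice of the fused GGUF tensor; only that slice is handed to the dequantizer.
print(blocks_per_expert, bytes_per_expert, offset)
```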