Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
0fca9576
Commit
0fca9576
authored
Jul 09, 2025
by
mxCynic
Browse files
fix: fix 9G4B model by adding coefficients to weight tensors when loading
parent
a5deda33
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
4 deletions
+35
-4
scripts/jiuge.py
scripts/jiuge.py
+35
-4
No files found.
scripts/jiuge.py
View file @
0fca9576
...
@@ -20,6 +20,7 @@ import safetensors
...
@@ -20,6 +20,7 @@ import safetensors
import
sys
import
sys
import
time
import
time
import
json
import
json
import
math
import
torch
import
torch
import
transformers
import
transformers
...
@@ -101,10 +102,28 @@ class JiugeMetaFromLlama(JiugeMetaCStruct):
...
@@ -101,10 +102,28 @@ class JiugeMetaFromLlama(JiugeMetaCStruct):
),
),
dh
=
config
[
"hidden_size"
]
//
config
[
"num_attention_heads"
],
dh
=
config
[
"hidden_size"
]
//
config
[
"num_attention_heads"
],
di
=
config
[
"intermediate_size"
],
di
=
config
[
"intermediate_size"
],
dctx
=
config
[
"max_position_embeddings"
]
if
max_tokens
is
None
else
max_tokens
,
dctx
=
(
config
[
"max_position_embeddings"
]
if
max_tokens
is
None
else
max_tokens
),
dvoc
=
config
[
"vocab_size"
],
dvoc
=
config
[
"vocab_size"
],
epsilon
=
config
[
"rms_norm_eps"
],
epsilon
=
config
[
"rms_norm_eps"
],
theta
=
(
config
[
"rope_theta"
]
if
"rope_theta"
in
config
else
100000.0
),
theta
=
(
config
[
"rope_theta"
]
if
"rope_theta"
in
config
else
100000.0
),
scale_input
=
(
config
[
"scale_emb"
]
if
"scale_emb"
in
config
else
1.0
),
scale_output
=
(
config
[
"hidden_size"
]
//
config
[
"dim_model_base"
]
if
"dim_model_base"
in
config
else
1.0
),
scale_o
=
(
config
[
"scale_depth"
]
/
math
.
sqrt
(
config
[
"num_hidden_layers"
])
if
"scale_depth"
in
config
else
1.0
),
scale_down
=
(
config
[
"scale_depth"
]
/
math
.
sqrt
(
config
[
"num_hidden_layers"
])
if
"scale_depth"
in
config
else
1.0
),
end_token
=
2
,
end_token
=
2
,
)
)
self
.
torch_dtype_logits
=
dtype
self
.
torch_dtype_logits
=
dtype
...
@@ -127,6 +146,10 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
...
@@ -127,6 +146,10 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
dh
=
meta
.
dh
dh
=
meta
.
dh
d
=
meta
.
d
d
=
meta
.
d
di
=
meta
.
di
di
=
meta
.
di
scale_input
=
meta
.
scale_input
scale_output
=
meta
.
scale_output
scale_o
=
meta
.
scale_o
scale_down
=
meta
.
scale_down
assert
nh
%
nkvh
==
0
assert
nh
%
nkvh
==
0
assert
nh
%
ndev
==
0
assert
nh
%
ndev
==
0
assert
nkvh
%
ndev
==
0
assert
nkvh
%
ndev
==
0
...
@@ -161,9 +184,13 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
...
@@ -161,9 +184,13 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
)
)
self
.
transpose_linear_weights
=
1
if
transpose_weight
else
0
self
.
transpose_linear_weights
=
1
if
transpose_weight
else
0
self
.
nlayer
=
nlayer
self
.
nlayer
=
nlayer
self
.
input_embd_tensor
=
state_dict
[
input_embd_naming
].
to
(
torch_dt_logits
)
self
.
input_embd_tensor
=
(
state_dict
[
input_embd_naming
].
to
(
torch_dt_logits
)
*
scale_input
)
self
.
input_embd
=
self
.
input_embd_tensor
.
data_ptr
()
self
.
input_embd
=
self
.
input_embd_tensor
.
data_ptr
()
self
.
output_norm_tensor
=
state_dict
[
naming
.
output_norm
()].
to
(
torch_dt_norm
)
self
.
output_norm_tensor
=
(
state_dict
[
naming
.
output_norm
()].
to
(
torch_dt_norm
)
*
scale_output
)
self
.
output_norm
=
self
.
output_norm_tensor
.
data_ptr
()
self
.
output_norm
=
self
.
output_norm_tensor
.
data_ptr
()
self
.
output_embd_tensor
=
state_dict
[
output_embd_naming
].
to
(
torch_dt_mat
)
self
.
output_embd_tensor
=
state_dict
[
output_embd_naming
].
to
(
torch_dt_mat
)
if
not
transpose_weight
:
if
not
transpose_weight
:
...
@@ -260,6 +287,7 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
...
@@ -260,6 +287,7 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
.
to
(
torch_dt_mat
)
.
to
(
torch_dt_mat
)
.
contiguous
()
.
contiguous
()
)
)
*
scale_o
for
i
in
range
(
nlayer
)
for
i
in
range
(
nlayer
)
]
]
self
.
attn_o_ptrs
=
[
self
.
attn_o_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
attn_o_ptrs
=
[
self
.
attn_o_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
...
@@ -310,6 +338,7 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
...
@@ -310,6 +338,7 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
.
to
(
torch_dt_mat
)
.
to
(
torch_dt_mat
)
.
contiguous
()
.
contiguous
()
)
)
*
scale_down
for
i
in
range
(
nlayer
)
for
i
in
range
(
nlayer
)
]
]
self
.
ffn_down_ptrs
=
[
self
.
ffn_down_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
ffn_down_ptrs
=
[
self
.
ffn_down_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
...
@@ -358,7 +387,9 @@ class JiugeBatchedTask:
...
@@ -358,7 +387,9 @@ class JiugeBatchedTask:
class
JiugeForCauslLM
:
class
JiugeForCauslLM
:
def
__init__
(
self
,
model_dir_path
,
device
=
DeviceType
.
DEVICE_TYPE_CPU
,
ndev
=
1
,
max_tokens
=
None
):
def
__init__
(
self
,
model_dir_path
,
device
=
DeviceType
.
DEVICE_TYPE_CPU
,
ndev
=
1
,
max_tokens
=
None
):
def
load_all_safetensors_from_dir
(
dir_path_
:
str
):
def
load_all_safetensors_from_dir
(
dir_path_
:
str
):
tensors_
=
{}
tensors_
=
{}
dir_path_
=
Path
(
dir_path_
)
dir_path_
=
Path
(
dir_path_
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment