Commit e0b617d1 (Unverified)
Authored Dec 08, 2023 by Pedro Cuenca; committed by GitHub, Dec 08, 2023

Llama conversion script: adjustments for Llama Guard (#27910)

Parent: e3669375
Showing 1 changed file with 2 additions and 1 deletion (+2 −1).
src/transformers/models/llama/convert_llama_weights_to_hf.py @ e0b617d1
@@ -91,6 +91,7 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
     params = read_json(os.path.join(input_base_path, "params.json"))
     num_shards = NUM_SHARDS[model_size]
+    params = params.get("model", params)
     n_layers = params["n_layers"]
     n_heads = params["n_heads"]
     n_heads_per_shard = n_heads // num_shards
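The added line lets the converter accept checkpoints whose params.json nests the model configuration under a top-level "model" key (which the Llama Guard release appears to do, going by the commit title), while leaving flat configs untouched. A minimal sketch of the behavior, using hypothetical config dicts rather than real checkpoint files:

# Minimal sketch with hypothetical configs (not taken from any real checkpoint):
# params.get("model", params) unwraps a nested config and is a no-op for a flat one.
flat_params = {"n_layers": 32, "n_heads": 32, "dim": 4096}
nested_params = {"model": {"n_layers": 32, "n_heads": 32, "dim": 4096}}

for params in (flat_params, nested_params):
    params = params.get("model", params)  # unwrap if nested, keep as-is otherwise
    print(params["n_layers"], params["n_heads"])  # both iterations print: 32 32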
@@ -109,7 +110,7 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
         tokenizer.save_pretrained(model_path)
     vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000

-    if "n_kv_heads" in params:
+    if params.get("n_kv_heads", None) is not None:
         num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
         num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
         key_value_dim = dim // num_key_value_heads
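Replacing the membership test with a None-aware check means a config that lists n_kv_heads with an explicit null value now falls through to the non-GQA defaults instead of entering this branch with an unusable value. A small sketch of the difference, using a hypothetical config dict:

# Sketch of the two checks on a hypothetical config where the key exists but is null.
params = {"n_heads": 32, "n_kv_heads": None}

print("n_kv_heads" in params)                       # True  -> old check enters the GQA/MQA branch
print(params.get("n_kv_heads", None) is not None)   # False -> new check skips it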