chenpangpang / transformers · Commits · d5b0a0e2
"ppocr/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "7b7c8f3bb78532ebf471242786bf0f39c200f32a"
Commit d5b0a0e2 (unverified), authored Aug 04, 2020 by Sam Shleifer, committed by GitHub on Aug 04, 2020

mBART Conversion script (#6230)

Parent: 268bf346
Showing 2 changed files with 36 additions and 13 deletions (+36 / -13).
src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py   (+0 / -13)
src/transformers/convert_mbart_original_checkpoint_to_pytorch.py          (+36 / -0)
src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py (+0 / -13)
```diff
@@ -78,19 +78,6 @@ def load_xsum_checkpoint(checkpoint_path):
     return hub_interface
 
 
-def convert_checkpoint_from_disk(checkpoint_path, **config_kwargs):
-    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
-    remove_ignore_keys_(state_dict)
-    vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
-    state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
-    mbart_config = BartConfig(vocab_size=vocab_size, **config_kwargs)
-    model = BartForConditionalGeneration(mbart_config)
-    model.model.load_state_dict(state_dict)
-    if hasattr(model, "lm_head"):
-        model.lm_head = _make_linear_from_emb(model.model.shared)
-    return model
-
-
 @torch.no_grad()
 def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
     """
```
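Both the remaining BART converter and the new mBART script below rely on remove_ignore_keys_, which is defined elsewhere in convert_bart_original_pytorch_checkpoint_to_pytorch.py and does not appear in this diff. A rough sketch of what such a helper typically strips from a fairseq checkpoint follows; the exact key list is an assumption, not taken from this commit:

```python
def remove_ignore_keys_(state_dict):
    # Drop fairseq bookkeeping tensors that have no counterpart in the
    # Hugging Face model; this key list is illustrative only.
    ignore_keys = [
        "encoder.version",
        "decoder.version",
        "model.encoder.version",
        "model.decoder.version",
        "_float_tensor",
    ]
    for k in ignore_keys:
        state_dict.pop(k, None)
```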
src/transformers/convert_mbart_original_checkpoint_to_pytorch.py (new file, mode 100644, +36 / -0)
```python
import argparse

import torch

from transformers import BartForConditionalGeneration, MBartConfig

from .convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_


def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"):
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
    remove_ignore_keys_(state_dict)
    vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
    mbart_config = MBartConfig.from_pretrained(hf_config_path, vocab_size=vocab_size)
    state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
    model = BartForConditionalGeneration(mbart_config)
    model.model.load_state_dict(state_dict)
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
    )
    parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument(
        "--hf_config",
        default="facebook/mbart-large-cc25",
        type=str,
        help="Which huggingface architecture to use: bart-large-xsum",
    )
    args = parser.parse_args()
    model = convert_fairseq_mbart_checkpoint_from_disk(args.fairseq_path, hf_config_path=args.hf_config)
    model.save_pretrained(args.pytorch_dump_folder_path)
```
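For context, a minimal usage sketch of the new script follows. The checkpoint path and output directory are placeholders, and the module path is an assumption based on the file's location under src/transformers/; the script expects a fairseq mBART model.pt whose weights sit under the "model" key, as shown above.

```python
# Hypothetical command-line invocation (module form, since the script uses a
# relative import); paths below are placeholders:
#   python -m transformers.convert_mbart_original_checkpoint_to_pytorch \
#       /path/to/mbart.cc25/model.pt ./mbart-cc25-converted --hf_config facebook/mbart-large-cc25

# Equivalent programmatic use of the function defined above:
from transformers import BartForConditionalGeneration
from transformers.convert_mbart_original_checkpoint_to_pytorch import (
    convert_fairseq_mbart_checkpoint_from_disk,
)

model = convert_fairseq_mbart_checkpoint_from_disk(
    "/path/to/mbart.cc25/model.pt",              # placeholder fairseq checkpoint
    hf_config_path="facebook/mbart-large-cc25",  # HF config used to build the MBartConfig
)
model.save_pretrained("./mbart-cc25-converted")

# The dump then reloads like any other Hugging Face checkpoint:
reloaded = BartForConditionalGeneration.from_pretrained("./mbart-cc25-converted")
```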