chenpangpang / transformers · Commits

Commit cad88e19 (unverified)
Authored Jun 14, 2019 by Thomas Wolf; committed by GitHub on Jun 14, 2019

Merge pull request #672 from oliverguhr/master

Add vocabulary and model config to the finetune output

Parents: c6de6252, 5c08c8c2
Changes: 2 changed files, with 14 additions and 3 deletions (+14 / -3)

- examples/lm_finetuning/finetune_on_pregenerated.py (+9 / -2)
- examples/lm_finetuning/simple_lm_finetuning.py (+5 / -1)
examples/lm_finetuning/finetune_on_pregenerated.py

 from argparse import ArgumentParser
 from pathlib import Path
+import os
 import torch
 import logging
 import json
...
@@ -12,6 +13,7 @@ from torch.utils.data import DataLoader, Dataset, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm

+from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
 from pytorch_pretrained_bert.modeling import BertForPreTraining
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
...
@@ -325,8 +327,13 @@ def main():

     # Save a trained model
     logging.info("** ** * Saving fine-tuned model ** ** * ")
     model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = args.output_dir / "pytorch_model.bin"
-    torch.save(model_to_save.state_dict(), str(output_model_file))
+
+    output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
+    output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+
+    torch.save(model_to_save.state_dict(), output_model_file)
+    model_to_save.config.to_json_file(output_config_file)
+    tokenizer.save_vocabulary(args.output_dir)

 if __name__ == '__main__':
...
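
With the model configuration and vocabulary now written next to the weights, the finetuned output directory can be reloaded directly through the library's from_pretrained helpers instead of hand-wiring a config and vocab file. A minimal reload sketch under that assumption; the directory name finetuned_lm is only a stand-in for whatever --output_dir was passed to the script:

from pytorch_pretrained_bert.modeling import BertForPreTraining
from pytorch_pretrained_bert.tokenization import BertTokenizer

output_dir = "finetuned_lm"  # hypothetical --output_dir used during finetuning

# from_pretrained accepts a local directory as well as a shortcut name: it reads
# the config JSON, the saved state dict, and the vocab.txt written by the script.
model = BertForPreTraining.from_pretrained(output_dir)
tokenizer = BertTokenizer.from_pretrained(output_dir)  # pass do_lower_case to match the finetuning run

model.eval()  # reloaded model is ready for evaluation or further finetuning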
examples/lm_finetuning/simple_lm_finetuning.py

...
@@ -29,6 +29,7 @@ from torch.utils.data import DataLoader, Dataset, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange

+from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
 from pytorch_pretrained_bert.modeling import BertForPreTraining
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
...
@@ -614,9 +615,12 @@ def main():

     # Save a trained model
     logger.info("** ** * Saving fine - tuned model ** ** * ")
     model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
+    output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
+    output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
     if args.do_train:
         torch.save(model_to_save.state_dict(), output_model_file)
+        model_to_save.config.to_json_file(output_config_file)
+        tokenizer.save_vocabulary(args.output_dir)

 def _truncate_seq_pair(tokens_a, tokens_b, max_length):
...
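
Both scripts now end with the same three-step save: the (unwrapped) state dict, the config as JSON, and the tokenizer vocabulary, all under args.output_dir. Purely as an illustration of that shared pattern, and not code from this commit, a hypothetical helper named save_finetuned could collect the steps in one place:

import os

import torch
from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME


def save_finetuned(model, tokenizer, output_dir):
    # Hypothetical helper mirroring the save logic this commit adds to both scripts.
    os.makedirs(output_dir, exist_ok=True)

    # Unwrap DataParallel / DistributedDataParallel so only the bare model is saved.
    model_to_save = model.module if hasattr(model, 'module') else model

    torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
    model_to_save.config.to_json_file(os.path.join(output_dir, CONFIG_NAME))
    tokenizer.save_vocabulary(output_dir)

Reloading then only needs the directory path, as in the sketch after the first file above.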