Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5c08c8c2
"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "c8f7bb990c072c77d8cfc8a6c883ba0c352a5671"
Commit
5c08c8c2
authored
Jun 11, 2019
by
Oliver Guhr
Browse files
adds the tokenizer + model config to the output
parent
784c0ed8
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
3 deletions
+14
-3
examples/lm_finetuning/finetune_on_pregenerated.py
examples/lm_finetuning/finetune_on_pregenerated.py
+9
-2
examples/lm_finetuning/simple_lm_finetuning.py
examples/lm_finetuning/simple_lm_finetuning.py
+5
-1
No files found.
examples/lm_finetuning/finetune_on_pregenerated.py
View file @
5c08c8c2
from
argparse
import
ArgumentParser
from
argparse
import
ArgumentParser
from
pathlib
import
Path
from
pathlib
import
Path
import
os
import
torch
import
torch
import
logging
import
logging
import
json
import
json
...
@@ -12,6 +13,7 @@ from torch.utils.data import DataLoader, Dataset, RandomSampler
...
@@ -12,6 +13,7 @@ from torch.utils.data import DataLoader, Dataset, RandomSampler
from
torch.utils.data.distributed
import
DistributedSampler
from
torch.utils.data.distributed
import
DistributedSampler
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
pytorch_pretrained_bert
import
WEIGHTS_NAME
,
CONFIG_NAME
from
pytorch_pretrained_bert.modeling
import
BertForPreTraining
from
pytorch_pretrained_bert.modeling
import
BertForPreTraining
from
pytorch_pretrained_bert.tokenization
import
BertTokenizer
from
pytorch_pretrained_bert.tokenization
import
BertTokenizer
from
pytorch_pretrained_bert.optimization
import
BertAdam
,
WarmupLinearSchedule
from
pytorch_pretrained_bert.optimization
import
BertAdam
,
WarmupLinearSchedule
...
@@ -325,8 +327,13 @@ def main():
...
@@ -325,8 +327,13 @@ def main():
# Save a trained model
# Save a trained model
logging
.
info
(
"** ** * Saving fine-tuned model ** ** * "
)
logging
.
info
(
"** ** * Saving fine-tuned model ** ** * "
)
model_to_save
=
model
.
module
if
hasattr
(
model
,
'module'
)
else
model
# Only save the model it-self
model_to_save
=
model
.
module
if
hasattr
(
model
,
'module'
)
else
model
# Only save the model it-self
output_model_file
=
args
.
output_dir
/
"pytorch_model.bin"
torch
.
save
(
model_to_save
.
state_dict
(),
str
(
output_model_file
))
output_model_file
=
os
.
path
.
join
(
args
.
output_dir
,
WEIGHTS_NAME
)
output_config_file
=
os
.
path
.
join
(
args
.
output_dir
,
CONFIG_NAME
)
torch
.
save
(
model_to_save
.
state_dict
(),
output_model_file
)
model_to_save
.
config
.
to_json_file
(
output_config_file
)
tokenizer
.
save_vocabulary
(
args
.
output_dir
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
examples/lm_finetuning/simple_lm_finetuning.py
View file @
5c08c8c2
...
@@ -29,6 +29,7 @@ from torch.utils.data import DataLoader, Dataset, RandomSampler
...
@@ -29,6 +29,7 @@ from torch.utils.data import DataLoader, Dataset, RandomSampler
from
torch.utils.data.distributed
import
DistributedSampler
from
torch.utils.data.distributed
import
DistributedSampler
from
tqdm
import
tqdm
,
trange
from
tqdm
import
tqdm
,
trange
from
pytorch_pretrained_bert
import
WEIGHTS_NAME
,
CONFIG_NAME
from
pytorch_pretrained_bert.modeling
import
BertForPreTraining
from
pytorch_pretrained_bert.modeling
import
BertForPreTraining
from
pytorch_pretrained_bert.tokenization
import
BertTokenizer
from
pytorch_pretrained_bert.tokenization
import
BertTokenizer
from
pytorch_pretrained_bert.optimization
import
BertAdam
,
WarmupLinearSchedule
from
pytorch_pretrained_bert.optimization
import
BertAdam
,
WarmupLinearSchedule
...
@@ -614,9 +615,12 @@ def main():
...
@@ -614,9 +615,12 @@ def main():
# Save a trained model
# Save a trained model
logger
.
info
(
"** ** * Saving fine - tuned model ** ** * "
)
logger
.
info
(
"** ** * Saving fine - tuned model ** ** * "
)
model_to_save
=
model
.
module
if
hasattr
(
model
,
'module'
)
else
model
# Only save the model it-self
model_to_save
=
model
.
module
if
hasattr
(
model
,
'module'
)
else
model
# Only save the model it-self
output_model_file
=
os
.
path
.
join
(
args
.
output_dir
,
"pytorch_model.bin"
)
output_model_file
=
os
.
path
.
join
(
args
.
output_dir
,
WEIGHTS_NAME
)
output_config_file
=
os
.
path
.
join
(
args
.
output_dir
,
CONFIG_NAME
)
if
args
.
do_train
:
if
args
.
do_train
:
torch
.
save
(
model_to_save
.
state_dict
(),
output_model_file
)
torch
.
save
(
model_to_save
.
state_dict
(),
output_model_file
)
model_to_save
.
config
.
to_json_file
(
output_config_file
)
tokenizer
.
save_vocabulary
(
args
.
output_dir
)
def
_truncate_seq_pair
(
tokens_a
,
tokens_b
,
max_length
):
def
_truncate_seq_pair
(
tokens_a
,
tokens_b
,
max_length
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment