Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ba4bce25
Unverified
Commit
ba4bce25
authored
Aug 13, 2019
by
tuvuumass
Committed by
GitHub
Aug 13, 2019
Browse files
fix issue #824
parent
a7b4cfe9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
8 deletions
+10
-8
examples/run_bertology.py
examples/run_bertology.py
+10
-8
No files found.
examples/run_bertology.py
View file @
ba4bce25
...
...
@@ -211,10 +211,12 @@ def prune_heads(args, model, eval_dataloader, head_mask):
def
main
():
parser
=
argparse
.
ArgumentParser
()
## Required parameters
parser
.
add_argument
(
"--data_dir"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The input data dir. Should contain the .tsv files (or other data files) for the task."
)
parser
.
add_argument
(
"--model_name"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Bert/XLNet/XLM pre-trained model selected in the list: "
+
", "
.
join
(
ALL_MODELS
))
parser
.
add_argument
(
"--model_name_or_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to pre-trained model or shortcut name selected in the list: "
+
", "
.
join
(
ALL_MODELS
))
parser
.
add_argument
(
"--task_name"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The name of the task to train selected in the list: "
+
", "
.
join
(
processors
.
keys
()))
parser
.
add_argument
(
"--output_dir"
,
default
=
None
,
type
=
str
,
required
=
True
,
...
...
@@ -222,9 +224,9 @@ def main():
## Other parameters
parser
.
add_argument
(
"--config_name"
,
default
=
""
,
type
=
str
,
help
=
"Pretrained config name or path if not the same as model_name"
)
help
=
"Pretrained config name or path if not the same as model_name
_or_path
"
)
parser
.
add_argument
(
"--tokenizer_name"
,
default
=
""
,
type
=
str
,
help
=
"Pretrained tokenizer name or path if not the same as model_name"
)
help
=
"Pretrained tokenizer name or path if not the same as model_name
_or_path
"
)
parser
.
add_argument
(
"--cache_dir"
,
default
=
""
,
type
=
str
,
help
=
"Where do you want to store the pre-trained models downloaded from s3"
)
parser
.
add_argument
(
"--data_subset"
,
type
=
int
,
default
=-
1
,
...
...
@@ -297,15 +299,15 @@ def main():
args
.
model_type
=
""
for
key
in
MODEL_CLASSES
:
if
key
in
args
.
model_name
.
lower
():
if
key
in
args
.
model_name
_or_path
.
lower
():
args
.
model_type
=
key
# take the first match in model types
break
config_class
,
model_class
,
tokenizer_class
=
MODEL_CLASSES
[
args
.
model_type
]
config
=
config_class
.
from_pretrained
(
args
.
config_name
if
args
.
config_name
else
args
.
model_name
,
config
=
config_class
.
from_pretrained
(
args
.
config_name
if
args
.
config_name
else
args
.
model_name
_or_path
,
num_labels
=
num_labels
,
finetuning_task
=
args
.
task_name
,
output_attentions
=
True
)
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
tokenizer_name
if
args
.
tokenizer_name
else
args
.
model_name
)
model
=
model_class
.
from_pretrained
(
args
.
model_name
,
from_tf
=
bool
(
'.ckpt'
in
args
.
model_name
),
config
=
config
)
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
tokenizer_name
if
args
.
tokenizer_name
else
args
.
model_name
_or_path
)
model
=
model_class
.
from_pretrained
(
args
.
model_name
_or_path
,
from_tf
=
bool
(
'.ckpt'
in
args
.
model_name
_or_path
),
config
=
config
)
if
args
.
local_rank
==
0
:
torch
.
distributed
.
barrier
()
# Make sure only the first process in distributed training will download model & vocab
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment