Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4e46affc
Commit
4e46affc
authored
Nov 17, 2018
by
thomwolf
Browse files
updating examples
parent
d0673c7d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
19 additions
and
70 deletions
+19
-70
examples/extract_features.py
examples/extract_features.py
+5
-17
examples/run_classifier.py
examples/run_classifier.py
+5
-26
examples/run_squad.py
examples/run_squad.py
+5
-23
setup.py
setup.py
+4
-4
No files found.
examples/extract_features.py
View file @
4e46affc
...
...
@@ -193,23 +193,16 @@ def main():
## Required parameters
parser
.
add_argument
(
"--input_file"
,
default
=
None
,
type
=
str
,
required
=
True
)
parser
.
add_argument
(
"--vocab_file"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The vocabulary file that the BERT model was trained on."
)
parser
.
add_argument
(
"--output_file"
,
default
=
None
,
type
=
str
,
required
=
True
)
parser
.
add_argument
(
"--bert_config_file"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The config json file corresponding to the pre-trained BERT model. "
"This specifies the model architecture."
)
parser
.
add_argument
(
"--init_checkpoint"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Initial checkpoint (usually from a pre-trained BERT model)."
)
parser
.
add_argument
(
"--bert_model"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
)
## Other parameters
parser
.
add_argument
(
"--layers"
,
default
=
"-1,-2,-3,-4"
,
type
=
str
)
parser
.
add_argument
(
"--max_seq_length"
,
default
=
128
,
type
=
int
,
help
=
"The maximum total input sequence length after WordPiece tokenization. Sequences longer "
"than this will be truncated, and sequences shorter than this will be padded."
)
parser
.
add_argument
(
"--do_lower_case"
,
default
=
True
,
action
=
'store_true'
,
help
=
"Whether to lower case the input text. Should be True for uncased "
"models and False for cased models."
)
parser
.
add_argument
(
"--batch_size"
,
default
=
32
,
type
=
int
,
help
=
"Batch size for predictions."
)
parser
.
add_argument
(
"--local_rank"
,
type
=
int
,
...
...
@@ -230,10 +223,7 @@ def main():
layer_indexes
=
[
int
(
x
)
for
x
in
args
.
layers
.
split
(
","
)]
bert_config
=
BertConfig
.
from_json_file
(
args
.
bert_config_file
)
tokenizer
=
BertTokenizer
(
vocab_file
=
args
.
vocab_file
,
do_lower_case
=
args
.
do_lower_case
)
tokenizer
=
BertTokenizer
.
from_pretrained
(
args
.
bert_model
)
examples
=
read_examples
(
args
.
input_file
)
...
...
@@ -244,9 +234,7 @@ def main():
for
feature
in
features
:
unique_id_to_feature
[
feature
.
unique_id
]
=
feature
model
=
BertModel
(
bert_config
)
if
args
.
init_checkpoint
is
not
None
:
model
.
load_state_dict
(
torch
.
load
(
args
.
init_checkpoint
,
map_location
=
'cpu'
))
model
=
BertModel
.
from_pretrained
(
args
.
bert_model
)
model
.
to
(
device
)
if
args
.
local_rank
!=
-
1
:
...
...
examples/run_classifier.py
View file @
4e46affc
...
...
@@ -343,12 +343,9 @@ def main():
type
=
str
,
required
=
True
,
help
=
"The input data dir. Should contain the .tsv files (or other data files) for the task."
)
parser
.
add_argument
(
"--bert_config_file"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The config json file corresponding to the pre-trained BERT model.
\n
"
"This specifies the model architecture."
)
parser
.
add_argument
(
"--bert_model"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
)
parser
.
add_argument
(
"--task_name"
,
default
=
None
,
type
=
str
,
...
...
@@ -366,14 +363,6 @@ def main():
help
=
"The output directory where the model checkpoints will be written."
)
## Other parameters
parser
.
add_argument
(
"--init_checkpoint"
,
default
=
None
,
type
=
str
,
help
=
"Initial checkpoint (usually from a pre-trained BERT model)."
)
parser
.
add_argument
(
"--do_lower_case"
,
default
=
False
,
action
=
'store_true'
,
help
=
"Whether to lower case the input text. True for uncased models, False for cased models."
)
parser
.
add_argument
(
"--max_seq_length"
,
default
=
128
,
type
=
int
,
...
...
@@ -477,13 +466,6 @@ def main():
if
not
args
.
do_train
and
not
args
.
do_eval
:
raise
ValueError
(
"At least one of `do_train` or `do_eval` must be True."
)
bert_config
=
BertConfig
.
from_json_file
(
args
.
bert_config_file
)
if
args
.
max_seq_length
>
bert_config
.
max_position_embeddings
:
raise
ValueError
(
"Cannot use sequence length {} because the BERT model was only trained up to sequence length {}"
.
format
(
args
.
max_seq_length
,
bert_config
.
max_position_embeddings
))
if
os
.
path
.
exists
(
args
.
output_dir
)
and
os
.
listdir
(
args
.
output_dir
):
raise
ValueError
(
"Output directory ({}) already exists and is not empty."
.
format
(
args
.
output_dir
))
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
...
...
@@ -496,8 +478,7 @@ def main():
processor
=
processors
[
task_name
]()
label_list
=
processor
.
get_labels
()
tokenizer
=
BertTokenizer
(
vocab_file
=
args
.
vocab_file
,
do_lower_case
=
args
.
do_lower_case
)
tokenizer
=
BertTokenizer
.
from_pretrained
(
args
.
bert_model
)
train_examples
=
None
num_train_steps
=
None
...
...
@@ -507,9 +488,7 @@ def main():
len
(
train_examples
)
/
args
.
train_batch_size
/
args
.
gradient_accumulation_steps
*
args
.
num_train_epochs
)
# Prepare model
model
=
BertForSequenceClassification
(
bert_config
,
len
(
label_list
))
if
args
.
init_checkpoint
is
not
None
:
model
.
bert
.
load_state_dict
(
torch
.
load
(
args
.
init_checkpoint
,
map_location
=
'cpu'
))
model
=
BertForSequenceClassification
.
from_pretrained
(
args
.
bert_model
,
len
(
label_list
))
if
args
.
fp16
:
model
.
half
()
model
.
to
(
device
)
...
...
examples/run_squad.py
View file @
4e46affc
...
...
@@ -699,11 +699,9 @@ def main():
parser
=
argparse
.
ArgumentParser
()
## Required parameters
parser
.
add_argument
(
"--bert_config_file"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The config json file corresponding to the pre-trained BERT model. "
"This specifies the model architecture."
)
parser
.
add_argument
(
"--vocab_file"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The vocabulary file that the BERT model was trained on."
)
parser
.
add_argument
(
"--bert_model"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
)
parser
.
add_argument
(
"--output_dir"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The output directory where the model checkpoints will be written."
)
...
...
@@ -711,11 +709,6 @@ def main():
parser
.
add_argument
(
"--train_file"
,
default
=
None
,
type
=
str
,
help
=
"SQuAD json for training. E.g., train-v1.1.json"
)
parser
.
add_argument
(
"--predict_file"
,
default
=
None
,
type
=
str
,
help
=
"SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json"
)
parser
.
add_argument
(
"--init_checkpoint"
,
default
=
None
,
type
=
str
,
help
=
"Initial checkpoint (usually from a pre-trained BERT model)."
)
parser
.
add_argument
(
"--do_lower_case"
,
default
=
True
,
action
=
'store_true'
,
help
=
"Whether to lower case the input text. Should be True for uncased "
"models and False for cased models."
)
parser
.
add_argument
(
"--max_seq_length"
,
default
=
384
,
type
=
int
,
help
=
"The maximum total input sequence length after WordPiece tokenization. Sequences "
"longer than this will be truncated, and sequences shorter than this will be padded."
)
...
...
@@ -815,20 +808,11 @@ def main():
raise
ValueError
(
"If `do_predict` is True, then `predict_file` must be specified."
)
bert_config
=
BertConfig
.
from_json_file
(
args
.
bert_config_file
)
if
args
.
max_seq_length
>
bert_config
.
max_position_embeddings
:
raise
ValueError
(
"Cannot use sequence length %d because the BERT model "
"was only trained up to sequence length %d"
%
(
args
.
max_seq_length
,
bert_config
.
max_position_embeddings
))
if
os
.
path
.
exists
(
args
.
output_dir
)
and
os
.
listdir
(
args
.
output_dir
):
raise
ValueError
(
"Output directory () already exists and is not empty."
)
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
tokenizer
=
BertTokenizer
(
vocab_file
=
args
.
vocab_file
,
do_lower_case
=
args
.
do_lower_case
)
tokenizer
=
BertTokenizer
.
from_pretrained
(
args
.
bert_model
)
train_examples
=
None
num_train_steps
=
None
...
...
@@ -839,9 +823,7 @@ def main():
len
(
train_examples
)
/
args
.
train_batch_size
/
args
.
gradient_accumulation_steps
*
args
.
num_train_epochs
)
# Prepare model
model
=
BertForQuestionAnswering
(
bert_config
)
if
args
.
init_checkpoint
is
not
None
:
model
.
bert
.
load_state_dict
(
torch
.
load
(
args
.
init_checkpoint
,
map_location
=
'cpu'
))
model
=
BertForQuestionAnswering
.
from_pretrained
(
args
.
bert_model
)
if
args
.
fp16
:
model
.
half
()
model
.
to
(
device
)
...
...
setup.py
View file @
4e46affc
...
...
@@ -13,11 +13,11 @@ setup(
url
=
"https://github.com/huggingface/pytorch-pretrained-BERT"
,
packages
=
find_packages
(
exclude
=
[
"*.tests"
,
"*.tests.*"
,
"tests.*"
,
"tests"
]),
install_requires
=
[
'
numpy
'
,
'
torch>=0.4.1
'
,
install_requires
=
[
'
torch>=0.4.1
'
,
'
numpy
'
,
'boto3'
,
'requests
>=2.18
'
,
'tqdm
>=4.19
'
],
'requests'
,
'tqdm'
],
scripts
=
[
"bin/pytorch_pretrained_bert"
],
python_requires
=
'>=3.5.0'
,
tests_require
=
[
'pytest'
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment