chenpangpang / transformers / Commits / 4e46affc

Commit 4e46affc, authored Nov 17, 2018 by thomwolf
updating examples
parent d0673c7d

Showing 4 changed files with 19 additions and 70 deletions:

    examples/extract_features.py   +5  -17
    examples/run_classifier.py     +5  -26
    examples/run_squad.py          +5  -23
    setup.py                       +4  -4
examples/extract_features.py  +5 -17  (view file @ 4e46affc)

@@ -193,23 +193,16 @@ def main():
     ## Required parameters
     parser.add_argument("--input_file", default=None, type=str, required=True)
-    parser.add_argument("--vocab_file", default=None, type=str, required=True,
-                        help="The vocabulary file that the BERT model was trained on.")
     parser.add_argument("--output_file", default=None, type=str, required=True)
-    parser.add_argument("--bert_config_file", default=None, type=str, required=True,
-                        help="The config json file corresponding to the pre-trained BERT model. "
-                             "This specifies the model architecture.")
-    parser.add_argument("--init_checkpoint", default=None, type=str, required=True,
-                        help="Initial checkpoint (usually from a pre-trained BERT model).")
+    parser.add_argument("--bert_model", default=None, type=str, required=True,
+                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
+                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")

     ## Other parameters
     parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
     parser.add_argument("--max_seq_length", default=128, type=int,
                         help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
                              "than this will be truncated, and sequences shorter than this will be padded.")
     parser.add_argument("--do_lower_case", default=True, action='store_true',
                         help="Whether to lower case the input text. Should be True for uncased "
                              "models and False for cased models.")
     parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.")
     parser.add_argument("--local_rank", type=int, ...

@@ -230,10 +223,7 @@ def main():
     layer_indexes = [int(x) for x in args.layers.split(",")]
-    bert_config = BertConfig.from_json_file(args.bert_config_file)
-    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
+    tokenizer = BertTokenizer.from_pretrained(args.bert_model)
     examples = read_examples(args.input_file)

@@ -244,9 +234,7 @@ def main():
     for feature in features:
         unique_id_to_feature[feature.unique_id] = feature
-    model = BertModel(bert_config)
-    if args.init_checkpoint is not None:
-        model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
+    model = BertModel.from_pretrained(args.bert_model)
     model.to(device)
     if args.local_rank != -1:
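The net effect for extract_features.py is that the separate --vocab_file, --bert_config_file and --init_checkpoint flags are replaced by a single --bert_model name, and both tokenizer and encoder are built with from_pretrained. A minimal sketch of that loading pattern follows; the import path and the exact return signature of BertModel are assumptions about the package at this point in its history, while the from_pretrained calls themselves mirror the diff above.

# Minimal sketch of the new loading pattern (import path assumed).
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

tokens = ["[CLS]"] + tokenizer.tokenize("Hello world") + ["[SEP]"]
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])

with torch.no_grad():
    # all_encoder_layers is a list with one hidden-state tensor per layer;
    # the script's --layers "-1,-2,-3,-4" indices select entries from it.
    all_encoder_layers, pooled_output = model(input_ids)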
examples/run_classifier.py  +5 -26  (view file @ 4e46affc)

@@ -343,12 +343,9 @@ def main():
                         type=str,
                         required=True,
                         help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
-    parser.add_argument("--bert_config_file", default=None, type=str, required=True,
-                        help="The config json file corresponding to the pre-trained BERT model.\n"
-                             "This specifies the model architecture.")
+    parser.add_argument("--bert_model", default=None, type=str, required=True,
+                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
+                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
     parser.add_argument("--task_name", default=None, type=str, ...

@@ -366,14 +363,6 @@ def main():
                         help="The output directory where the model checkpoints will be written.")

     ## Other parameters
-    parser.add_argument("--init_checkpoint", default=None, type=str,
-                        help="Initial checkpoint (usually from a pre-trained BERT model).")
-    parser.add_argument("--do_lower_case", default=False, action='store_true',
-                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
     parser.add_argument("--max_seq_length", default=128, type=int, ...

@@ -477,13 +466,6 @@ def main():
     if not args.do_train and not args.do_eval:
         raise ValueError("At least one of `do_train` or `do_eval` must be True.")
-    bert_config = BertConfig.from_json_file(args.bert_config_file)
-    if args.max_seq_length > bert_config.max_position_embeddings:
-        raise ValueError(
-            "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}".format(
-                args.max_seq_length, bert_config.max_position_embeddings))
     if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
         raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
     os.makedirs(args.output_dir, exist_ok=True)

@@ -496,8 +478,7 @@ def main():
     processor = processors[task_name]()
     label_list = processor.get_labels()
-    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
+    tokenizer = BertTokenizer.from_pretrained(args.bert_model)
     train_examples = None
     num_train_steps = None

@@ -507,9 +488,7 @@ def main():
             len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

     # Prepare model
-    model = BertForSequenceClassification(bert_config, len(label_list))
-    if args.init_checkpoint is not None:
-        model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
+    model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list))
     if args.fp16:
         model.half()
     model.to(device)
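In run_classifier.py the same pattern applies to the classification head: the diff passes the number of labels straight into BertForSequenceClassification.from_pretrained. A short sketch of the setup after this change; the import path is an assumption, the label list and training hyperparameters below are hypothetical stand-ins for what args.* and the task processor would provide, and the step arithmetic copies the expression in the hunk above.

# Hypothetical values standing in for args.* and processor.get_labels().
import torch
from pytorch_pretrained_bert import BertTokenizer, BertForSequenceClassification  # import path assumed

label_list = ["0", "1"]            # e.g. a two-class task
num_examples = 3668                # e.g. an MRPC-sized training set
train_batch_size = 32
gradient_accumulation_steps = 1
num_train_epochs = 3.0

# Same arithmetic as the diff: optimizer steps = examples / batch / accumulation * epochs.
num_train_steps = int(num_examples / train_batch_size / gradient_accumulation_steps * num_train_epochs)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# As in the diff, the extra positional argument sizes the classification head.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", len(label_list))
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))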
examples/run_squad.py  +5 -23  (view file @ 4e46affc)

@@ -699,11 +699,9 @@ def main():
     parser = argparse.ArgumentParser()

     ## Required parameters
-    parser.add_argument("--bert_config_file", default=None, type=str, required=True,
-                        help="The config json file corresponding to the pre-trained BERT model. "
-                             "This specifies the model architecture.")
-    parser.add_argument("--vocab_file", default=None, type=str, required=True,
-                        help="The vocabulary file that the BERT model was trained on.")
+    parser.add_argument("--bert_model", default=None, type=str, required=True,
+                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
+                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
     parser.add_argument("--output_dir", default=None, type=str, required=True,
                         help="The output directory where the model checkpoints will be written.")

@@ -711,11 +709,6 @@ def main():
     parser.add_argument("--train_file", default=None, type=str,
                         help="SQuAD json for training. E.g., train-v1.1.json")
     parser.add_argument("--predict_file", default=None, type=str,
                         help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
-    parser.add_argument("--init_checkpoint", default=None, type=str,
-                        help="Initial checkpoint (usually from a pre-trained BERT model).")
-    parser.add_argument("--do_lower_case", default=True, action='store_true',
-                        help="Whether to lower case the input text. Should be True for uncased "
-                             "models and False for cased models.")
     parser.add_argument("--max_seq_length", default=384, type=int,
                         help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                              "longer than this will be truncated, and sequences shorter than this will be padded.")

@@ -815,20 +808,11 @@ def main():
         raise ValueError("If `do_predict` is True, then `predict_file` must be specified.")
-    bert_config = BertConfig.from_json_file(args.bert_config_file)
-    if args.max_seq_length > bert_config.max_position_embeddings:
-        raise ValueError("Cannot use sequence length %d because the BERT model "
-                         "was only trained up to sequence length %d" % (
-                             args.max_seq_length, bert_config.max_position_embeddings))
     if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
         raise ValueError("Output directory () already exists and is not empty.")
     os.makedirs(args.output_dir, exist_ok=True)
-    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
+    tokenizer = BertTokenizer.from_pretrained(args.bert_model)
     train_examples = None
     num_train_steps = None

@@ -839,9 +823,7 @@ def main():
             len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

     # Prepare model
-    model = BertForQuestionAnswering(bert_config)
-    if args.init_checkpoint is not None:
-        model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
+    model = BertForQuestionAnswering.from_pretrained(args.bert_model)
     if args.fp16:
         model.half()
     model.to(device)
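run_squad.py gets the same treatment: the question-answering model now comes directly from from_pretrained, and the pre-existing --fp16 path still just halves the weights before moving them to the device. A small sketch of that sequence, with the import path assumed and the flag hard-coded purely for illustration.

import torch
from pytorch_pretrained_bert import BertTokenizer, BertForQuestionAnswering  # import path assumed

tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
model = BertForQuestionAnswering.from_pretrained("bert-large-uncased")

fp16 = False  # stands in for args.fp16
if fp16:
    # Convert all parameters to half precision before moving to the GPU,
    # matching the unchanged context lines in the hunk above.
    model.half()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)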
setup.py  +4 -4  (view file @ 4e46affc)

@@ -13,11 +13,11 @@ setup(
     url="https://github.com/huggingface/pytorch-pretrained-BERT",
     packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
-    install_requires=['numpy',
-                      'torch>=0.4.1',
+    install_requires=['torch>=0.4.1',
+                      'numpy',
                       'boto3',
-                      'requests >=2.18',
-                      'tqdm >=4.19'],
+                      'requests',
+                      'tqdm'],
     scripts=["bin/pytorch_pretrained_bert"],
     python_requires='>=3.5.0',
     tests_require=['pytest'],
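For reference, this is how the touched portion of setup() reads on the new side of the diff, reassembled from the lines visible in this hunk; setup() keywords outside the hunk (name, version, description, and so on) are deliberately omitted rather than guessed.

from setuptools import find_packages, setup

setup(
    # ... name, version, etc. are outside this hunk and omitted here ...
    url="https://github.com/huggingface/pytorch-pretrained-BERT",
    packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
    install_requires=['torch>=0.4.1',
                      'numpy',
                      'boto3',
                      'requests',
                      'tqdm'],
    scripts=["bin/pytorch_pretrained_bert"],
    python_requires='>=3.5.0',
    tests_require=['pytest'],
)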