chenpangpang / transformers · Commits

Commit 16ce15ed, authored Jan 08, 2020 by Lysandre (parent: f24232cd)

DistilBERT token type ids removed from inputs in run_squad
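Why the change matters (an illustration of mine, not part of the commit): DistilBERT has no segment embeddings, so its forward() takes no token_type_ids argument. The old code set the key to None for such models, but unpacking the dict still forwards the keyword, which raises a TypeError; building the full dict and then deleting the key avoids that. The toy forward() below is hypothetical and only mimics the missing parameter:

```python
# Toy stand-in for a forward() that, like DistilBERT's, has no
# token_type_ids parameter (hypothetical function, for illustration only).
def forward(input_ids, attention_mask=None):
    return input_ids

inputs = {"input_ids": [101, 102], "attention_mask": [1, 1], "token_type_ids": None}

try:
    forward(**inputs)  # the keyword is passed even though its value is None
except TypeError as err:
    print(err)  # forward() got an unexpected keyword argument 'token_type_ids'

del inputs["token_type_ids"]  # the fix this commit applies in run_squad.py
forward(**inputs)  # now accepted
```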
Showing 1 changed file with 18 additions and 7 deletions.
examples/run_squad.py (+18, -7)
```diff
@@ -207,11 +207,14 @@ def train(args, train_dataset, model, tokenizer):
             inputs = {
                 "input_ids": batch[0],
                 "attention_mask": batch[1],
-                "token_type_ids": None if args.model_type in ["xlm", "roberta", "distilbert"] else batch[2],
+                "token_type_ids": batch[2],
                 "start_positions": batch[3],
                 "end_positions": batch[4],
             }
 
+            if args.model_type in ["xlm", "roberta", "distilbert"]:
+                del inputs["token_type_ids"]
+
             if args.model_type in ["xlnet", "xlm"]:
                 inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                 if args.version_2_with_negative:
```
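A note on the pattern (my reading of the diff): the dict is now built with batch[2] unconditionally, so the tensor layout stays uniform across model types, with start_positions and end_positions always at batch[3] and batch[4], and the guard afterwards strips the one key that XLM, RoBERTa, and DistilBERT cannot accept before the dict is unpacked into the model call. Deleting the key, rather than setting it to None as before, is the substantive fix, since **inputs forwards the keyword either way.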
```diff
@@ -316,8 +319,12 @@ def evaluate(args, model, tokenizer, prefix=""):
             inputs = {
                 "input_ids": batch[0],
                 "attention_mask": batch[1],
-                "token_type_ids": None if args.model_type in ["xlm", "roberta", "distilbert"] else batch[2],
+                "token_type_ids": batch[2],
             }
+
+            if args.model_type in ["xlm", "roberta", "distilbert"]:
+                del inputs["token_type_ids"]
+
             example_indices = batch[3]
 
             # XLNet and XLM use more arguments for their predictions
```
```diff
@@ -427,10 +434,14 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
     )
 
     # Init features and dataset from cache if it exists
-    if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
+    if os.path.exists(cached_features_file) and not args.overwrite_cache:
         logger.info("Loading features from cached file %s", cached_features_file)
         features_and_dataset = torch.load(cached_features_file)
-        features, dataset = features_and_dataset["features"], features_and_dataset["dataset"]
+        features, dataset, examples = (
+            features_and_dataset["features"],
+            features_and_dataset["dataset"],
+            features_and_dataset["examples"],
+        )
     else:
         logger.info("Creating features from dataset file at %s", input_dir)
```
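One consequence worth noting (my observation, with a hypothetical file name): the loader now unconditionally reads an "examples" key, so a cache file written by an older run_squad.py, which saved only "features" and "dataset", will raise a KeyError here; deleting the stale cache or rerunning with --overwrite_cache regenerates it. A minimal sketch:

```python
import torch

# Simulate a cache written by the pre-commit code: no "examples" key.
# ("old_cache.bin" is a hypothetical file name for this sketch.)
torch.save({"features": [0, 1], "dataset": [2, 3]}, "old_cache.bin")

features_and_dataset = torch.load("old_cache.bin")
try:
    examples = features_and_dataset["examples"]
except KeyError:
    print("stale cache from an older run_squad.py; delete it or rerun with --overwrite_cache")
```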
```diff
@@ -465,7 +476,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
         if args.local_rank in [-1, 0]:
             logger.info("Saving features into cached file %s", cached_features_file)
-            torch.save({"features": features, "dataset": dataset}, cached_features_file)
+            torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)
 
     if args.local_rank == 0 and not evaluate:
         # Make sure only the first process in distributed training process the dataset, and the others will use the cache
```
```diff
@@ -776,7 +787,7 @@ def main():
         torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
 
         # Load a trained model and vocabulary that you have fine-tuned
-        model = model_class.from_pretrained(args.output_dir, force_download=True)
+        model = model_class.from_pretrained(args.output_dir)  # , force_download=True)
         tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         model.to(args.device)
```
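On the commented-out force_download (my reading, not stated in the commit): args.output_dir is the local directory the fine-tuned model was just saved to, and from_pretrained loads local paths directly from disk, so a hub re-download flag had no effect in this call; keeping it as a trailing comment preserves the hint for anyone loading from a remote identifier instead. The same cleanup is applied to the checkpoint-evaluation loop below.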
```diff
@@ -801,7 +812,7 @@ def main():
         for checkpoint in checkpoints:
             # Reload the model
             global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
-            model = model_class.from_pretrained(checkpoint, force_download=True)
+            model = model_class.from_pretrained(checkpoint)  # , force_download=True)
             model.to(args.device)
 
             # Evaluate
```