Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4e6a5575
Commit
4e6a5575
authored
Sep 13, 2019
by
Simon Layton
Browse files
Force einsum to fp16
parent
f62f992c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
2 deletions
+12
-2
examples/run_squad.py
examples/run_squad.py
+12
-2
No files found.
examples/run_squad.py
View file @
4e6a5575
...
...
@@ -138,8 +138,8 @@ def train(args, train_dataset, model, tokenizer):
model
.
train
()
batch
=
tuple
(
t
.
to
(
args
.
device
)
for
t
in
batch
)
inputs
=
{
'input_ids'
:
batch
[
0
],
'attention_mask'
:
batch
[
1
],
'start_positions'
:
batch
[
3
],
'attention_mask'
:
batch
[
1
],
'start_positions'
:
batch
[
3
],
'end_positions'
:
batch
[
4
]}
if
args
.
model_type
!=
'distilbert'
:
inputs
[
'token_type_ids'
]
=
None
if
args
.
model_type
==
'xlm'
else
batch
[
2
]
...
...
@@ -481,6 +481,16 @@ def main():
logger
.
info
(
"Training/evaluation parameters %s"
,
args
)
# Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
# Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
# remove the need for this code, but it is still valid.
if
args
.
fp16
:
try
:
import
apex
apex
.
amp
.
register_half_function
(
torch
,
'einsum'
)
except
ImportError
:
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
)
# Training
if
args
.
do_train
:
train_dataset
=
load_and_cache_examples
(
args
,
tokenizer
,
evaluate
=
False
,
output_examples
=
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment