chenpangpang / transformers / Commits / 4428aefc

Unverified commit 4428aefc, authored Oct 11, 2019 by Thomas Wolf, committed by GitHub on Oct 11, 2019

Merge pull request #1488 from huggingface/pytorch-tpu

GLUE on TPU

Parents: 3b43b018, 639f4b71
Showing 1 changed file with 33 additions and 2 deletions

examples/run_glue.py (+33, -2)
@@ -160,7 +160,7 @@ def train(args, train_dataset, model, tokenizer):
                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
 
             tr_loss += loss.item()
-            if (step + 1) % args.gradient_accumulation_steps == 0:
+            if (step + 1) % args.gradient_accumulation_steps == 0 and not args.tpu:
                 optimizer.step()
                 scheduler.step()  # Update learning rate schedule
                 model.zero_grad()
@@ -186,6 +186,11 @@ def train(args, train_dataset, model, tokenizer):
                     torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                     logger.info("Saving model checkpoint to %s", output_dir)
 
+            if args.tpu:
+                args.xla_model.optimizer_step(optimizer, barrier=True)
+                model.zero_grad()
+                global_step += 1
+
             if args.max_steps > 0 and global_step > args.max_steps:
                 epoch_iterator.close()
                 break
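The `args.xla_model.optimizer_step(optimizer, barrier=True)` call added above is torch_xla's replacement for the plain `optimizer.step()` used on GPU: it applies the optimizer update and, with `barrier=True`, marks the XLA step so the lazily-built graph is compiled and executed on the TPU. A minimal standalone sketch of that pattern, assuming `torch_xla` is installed (the toy model and batch are purely illustrative, not part of the script):

# Minimal sketch: optimizer step on a TPU with torch_xla (illustrative only).
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()                           # acquire the XLA/TPU device
model = torch.nn.Linear(10, 2).to(device)          # toy model standing in for the GLUE model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

inputs = torch.randn(8, 10, device=device)         # toy batch
labels = torch.randint(0, 2, (8,), device=device)

loss = torch.nn.functional.cross_entropy(model(inputs), labels)
loss.backward()

# On TPU this replaces optimizer.step(); barrier=True forces the pending
# XLA graph to be compiled and run before continuing.
xm.optimizer_step(optimizer, barrier=True)
model.zero_grad()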
@@ -385,6 +390,15 @@ def main():
     parser.add_argument('--seed', type=int, default=42,
                         help="random seed for initialization")
 
+    parser.add_argument('--tpu', action='store_true',
+                        help="Whether to run on the TPU defined in the environment variables")
+    parser.add_argument('--tpu_ip_address', type=str, default='',
+                        help="TPU IP address if none are set in the environment variables")
+    parser.add_argument('--tpu_name', type=str, default='',
+                        help="TPU name if none are set in the environment variables")
+    parser.add_argument('--xrt_tpu_config', type=str, default='',
+                        help="XRT TPU config if none are set in the environment variables")
+
     parser.add_argument('--fp16', action='store_true',
                         help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
     parser.add_argument('--fp16_opt_level', type=str, default='O1',
@@ -418,6 +432,23 @@ def main():
         args.n_gpu = 1
     args.device = device
 
+    if args.tpu:
+        if args.tpu_ip_address:
+            os.environ["TPU_IP_ADDRESS"] = args.tpu_ip_address
+        if args.tpu_name:
+            os.environ["TPU_NAME"] = args.tpu_name
+        if args.xrt_tpu_config:
+            os.environ["XRT_TPU_CONFIG"] = args.xrt_tpu_config
+
+        assert "TPU_IP_ADDRESS" in os.environ
+        assert "TPU_NAME" in os.environ
+        assert "XRT_TPU_CONFIG" in os.environ
+
+        import torch_xla
+        import torch_xla.core.xla_model as xm
+        args.device = xm.xla_device()
+        args.xla_model = xm
+
     # Setup logging
     logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                         datefmt = '%m/%d/%Y %H:%M:%S',
@@ -463,7 +494,7 @@ def main():
 
     # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
-    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
+    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0) and not args.tpu:
         # Create output directory if needed
         if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
             os.makedirs(args.output_dir)
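For completeness, a rough sketch of how the new command-line flags are meant to feed torch_xla's environment-driven TPU configuration, as set up in the hunk at line 432 above. The concrete addresses and names below are hypothetical placeholders, and `torch_xla` must be installed:

# Rough sketch of the environment wiring added above (placeholder values).
import os
import torch_xla.core.xla_model as xm

# run_glue.py exports the values passed via --tpu_ip_address, --tpu_name and
# --xrt_tpu_config; here example values are set directly for illustration.
os.environ["TPU_IP_ADDRESS"] = "10.0.0.2"                      # hypothetical
os.environ["TPU_NAME"] = "my-tpu"                              # hypothetical
os.environ["XRT_TPU_CONFIG"] = "tpu_worker;0;10.0.0.2:8470"    # hypothetical

device = xm.xla_device()   # resolves to the TPU described by the environment
print(device)              # e.g. "xla:1"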