chenpangpang / transformers · Commits

Commit 5e1bea4f (unverified)
Authored Jan 14, 2021 by Sylvain Gugger; committed by GitHub on Jan 14, 2021.
Fix Trainer with a parallel model (#9578)

* Fix Trainer with a parallel model
* More clean up
Parent: 126fd281

Changes: 2 changed files, with 14 additions and 13 deletions (+14 -13).

    src/transformers/training_args.py   +11 -12
    tests/test_trainer.py               +3  -1
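Why the GPU count matters here: the per-device batch sizes get multiplied by n_gpu when the Trainer drives several GPUs with nn.DataParallel, whereas a model-parallel model is already spread over the available GPUs and must receive plain per-device batches. The test below pins this down ("16, not 16 * n_gpu"). A minimal sketch of that arithmetic, using a hypothetical effective_batch_size helper that is not part of transformers:

    # Hypothetical helper mirroring the batch-size scaling asserted in the test below:
    # with nn.DataParallel the dataloader batch size is per_device_batch_size * n_gpu,
    # so a model-parallel model has to report n_gpu == 1 to keep it at 16.
    def effective_batch_size(per_device_batch_size: int, n_gpu: int) -> int:
        return per_device_batch_size * max(1, n_gpu)

    assert effective_batch_size(16, n_gpu=1) == 16  # model parallel: one process, one logical GPU
    assert effective_batch_size(16, n_gpu=4) == 64  # DataParallel over 4 GPUs scales the batch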
src/transformers/training_args.py

@@ -16,7 +16,7 @@ import json
 import os
 from dataclasses import asdict, dataclass, field
 from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional
 
 from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
 from .trainer_utils import EvaluationStrategy, SchedulerType
@@ -426,7 +426,6 @@ class TrainingArguments:
 
         if is_torch_available() and self.device.type != "cuda" and self.fp16:
             raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
-        self._n_gpu = torch.cuda.device_count()
 
     def __repr__(self):
         # We override the default repr to remove deprecated arguments from the repr. This method should be removed once
@@ -467,14 +466,14 @@ class TrainingArguments:
 
     @cached_property
     @torch_required
-    def _setup_devices(self) -> Tuple["torch.device", int]:
+    def _setup_devices(self) -> "torch.device":
         logger.info("PyTorch: setting up devices")
         if self.no_cuda:
             device = torch.device("cpu")
-            n_gpu = 0
+            self._n_gpu = 0
         elif is_torch_tpu_available():
             device = xm.xla_device()
-            n_gpu = 0
+            self._n_gpu = 0
         elif self.local_rank == -1:
             # if n_gpu is > 1 we'll use nn.DataParallel.
             # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
@@ -485,9 +484,7 @@ class TrainingArguments:
             device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
             # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
             # the default value.
-            if self._n_gpu == -1:
-                self._n_gpu = torch.cuda.device_count()
-            n_gpu = self._n_gpu
+            self._n_gpu = torch.cuda.device_count()
         else:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
@@ -507,12 +504,12 @@ class TrainingArguments:
             else:
                 torch.distributed.init_process_group(backend="nccl")
             device = torch.device("cuda", self.local_rank)
-            n_gpu = 1
+            self._n_gpu = 1
 
         if device.type == "cuda":
             torch.cuda.set_device(device)
 
-        return device, n_gpu
+        return device
 
     @property
     @torch_required
@@ -520,7 +517,7 @@ class TrainingArguments:
         """
         The device used by this process.
        """
-        return self._setup_devices[0]
+        return self._setup_devices
 
     @property
     @torch_required
@@ -532,7 +529,9 @@ class TrainingArguments:
        This will only be greater than one when you have multiple GPUs available but are not using distributed
        training. For distributed training, it will always be 1.
        """
-        return self._setup_devices[1]
+        # Make sure `self._n_gpu` is properly setup.
+        _ = self._setup_devices
+        return self._n_gpu
 
     @property
     @torch_required
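The shape of the change above: _setup_devices is a cached property, so before this commit the GPU count was frozen inside the cached (device, n_gpu) tuple and could not be adjusted afterwards; now the property caches only the device and records the count in the plain attribute _n_gpu, which the n_gpu property reads after making sure setup has run. A stripped-down sketch of that pattern, assuming functools.cached_property as a stand-in for the library's own cached_property helper and a plain string in place of torch.device:

    from functools import cached_property  # stand-in for transformers' own cached_property helper


    class ArgsSketch:
        """Toy model of the new _setup_devices / n_gpu split; not the real TrainingArguments."""

        def __init__(self):
            self._n_gpu = -1  # "not set up yet" sentinel, as in the pre-change code above

        @cached_property
        def _setup_devices(self) -> str:
            # Stand-in for the real device-selection logic; pretend four GPUs were found.
            self._n_gpu = 4
            return "cuda:0"

        @property
        def device(self) -> str:
            return self._setup_devices

        @property
        def n_gpu(self) -> int:
            # Make sure `self._n_gpu` is properly set up before reading it, as in the diff above.
            _ = self._setup_devices
            return self._n_gpu


    args = ArgsSketch()
    assert args.n_gpu == 4  # first access runs the cached setup once
    args._n_gpu = 1         # e.g. a caller such as the Trainer overriding it for a model-parallel model
    assert args.n_gpu == 1  # works because n_gpu no longer reads a cached (device, n_gpu) tuple

That override is what lets the test below observe trainer.args.n_gpu == 1 for a model flagged as model parallel.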
tests/test_trainer.py

@@ -381,9 +381,11 @@ class TrainerIntegrationTest(unittest.TestCase):
         # Make the Trainer believe it's a parallelized model
         model.is_parallelizable = True
         model.model_parallel = True
-        trainer = Trainer(model=model, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
+        args = TrainingArguments("./regression", per_device_train_batch_size=16, per_device_eval_batch_size=16)
+        trainer = Trainer(model, args, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
         # Check the Trainer was fooled
         self.assertTrue(trainer.is_model_parallel)
+        self.assertEqual(trainer.args.n_gpu, 1)
         # The batch size of the training and evaluation dataloaders should be 16, not 16 * n_gpu
         self.assertEqual(trainer.get_train_dataloader().batch_size, 16)
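Read as user-facing behaviour, the test amounts to the sketch below. The check_model_parallel_batch_size function and its arguments are placeholders introduced here, not transformers APIs; the model is assumed to already be split across GPUs (for instance via parallelize() on models that support it), which is what setting is_parallelizable and model_parallel by hand imitates in the test.

    from transformers import Trainer, TrainingArguments


    def check_model_parallel_batch_size(model, train_dataset, eval_dataset):
        # `model` is assumed to already be model parallel (model.model_parallel == True).
        args = TrainingArguments(
            "./regression",
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
        )
        trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)

        # With this fix the Trainer treats the parallelized model as a single logical GPU,
        # so the dataloader batch size stays at 16 instead of 16 * torch.cuda.device_count().
        assert trainer.is_model_parallel
        assert trainer.args.n_gpu == 1
        assert trainer.get_train_dataloader().batch_size == 16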