chenpangpang / transformers / Commits / a0a027c2

Commit a0a027c2 (unverified), authored Mar 16, 2021 by Sylvain Gugger, committed via GitHub on Mar 16, 2021 (parent 14492222)

Add DistributedSamplerWithLoop (#10746)

* Add DistributedSamplerWithLoop
* Fix typo
* Test and small fix
Showing 5 changed files with 93 additions and 20 deletions (+93, -20):

  src/transformers/sagemaker/trainer_sm.py  (+8, -0)
  src/transformers/trainer.py               (+16, -19)
  src/transformers/trainer_pt_utils.py      (+29, -1)
  src/transformers/training_args.py         (+14, -0)
  tests/test_trainer_utils.py               (+26, -0)
src/transformers/sagemaker/trainer_sm.py

@@ -26,6 +26,7 @@ from ..modeling_utils import PreTrainedModel, unwrap_model
 from ..trainer import Trainer
 from ..trainer_pt_utils import (
     DistributedLengthGroupedSampler,
+    DistributedSamplerWithLoop,
     SequentialDistributedSampler,
     nested_detach,
     nested_numpify,

@@ -97,6 +98,13 @@ class SageMakerTrainer(Trainer):
                 return DistributedLengthGroupedSampler(
                     self.train_dataset, self.args.train_batch_size, num_replicas=smp.dp_size(), rank=smp.dp_rank()
                 )
+            elif not self.args.dataloader_drop_last:
+                return DistributedSamplerWithLoop(
+                    self.train_dataset,
+                    self.args.per_device_train_batch_size,
+                    num_replicas=smp.dp_size(),
+                    rank=smp.dp_rank(),
+                )
             else:
                 return DistributedSampler(self.train_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
         else:
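A note on the new branch: it only fires when dataloader_drop_last is False. With drop_last=True the ragged final batch on each data-parallel rank is discarded by the dataloader anyway, so the plain DistributedSampler in the else: branch is enough; without drop_last, DistributedSamplerWithLoop tops each rank's shard up to a round multiple of the per-device batch size. A rough sketch of that arithmetic with made-up numbers (1000 samples, 8 data-parallel ranks, per-device batch size 4; none of these values come from the commit):

import math

dataset_len, dp_size, per_device_batch_size = 1000, 8, 4   # hypothetical values
per_rank = math.ceil(dataset_len / dp_size)                # 125 indices per rank from DistributedSampler

# drop_last=True: the 1-sample leftover batch is simply dropped -> 31 full batches per rank.
print(per_rank // per_device_batch_size)                   # 31

# drop_last=False with DistributedSamplerWithLoop: the shard is topped up to 128 by
# looping back over the shuffled indices -> 32 full, identically sized batches per rank.
remainder = (per_device_batch_size - per_rank % per_device_batch_size) % per_device_batch_size
print((per_rank + remainder) // per_device_batch_size)     # 32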
src/transformers/trainer.py

@@ -77,6 +77,7 @@ from .trainer_callback import (
 )
 from .trainer_pt_utils import (
     DistributedLengthGroupedSampler,
+    DistributedSamplerWithLoop,
     DistributedTensorGatherer,
     LabelSmoother,
     LengthGroupedSampler,

@@ -491,24 +492,10 @@ class Trainer:
         ):
             return None

-        # Gather the number of processes and this process index.
-        if self.args.parallel_mode == ParallelMode.TPU:
-            num_processes = xm.xrt_world_size()
-            process_index = xm.get_ordinal()
-        elif (
-            self.args.parallel_mode == ParallelMode.DISTRIBUTED
-            or self.args.parallel_mode == ParallelMode.SAGEMAKER_DISTRIBUTED
-        ):
-            num_processes = dist.get_world_size()
-            process_index = dist.get_rank()
-        else:
-            num_processes = 1
-            process_index = 0
-
         # Build the sampler.
         if self.args.group_by_length:
             model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
-            if num_processes <= 1:
+            if self.args.world_size <= 1:
                 return LengthGroupedSampler(
                     self.train_dataset, self.args.train_batch_size, model_input_name=model_input_name
                 )

@@ -516,16 +503,26 @@ class Trainer:
                 return DistributedLengthGroupedSampler(
                     self.train_dataset,
                     self.args.train_batch_size,
-                    num_replicas=num_processes,
-                    rank=process_index,
+                    num_replicas=self.args.world_size,
+                    rank=self.args.process_index,
                     model_input_name=model_input_name,
                 )

         else:
-            if num_processes <= 1:
+            if self.args.world_size <= 1:
                 return RandomSampler(self.train_dataset)
+            elif self.args.parallel_mode == ParallelMode.TPU and not self.args.dataloader_drop_last:
+                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
+                return DistributedSamplerWithLoop(
+                    self.train_dataset,
+                    batch_size=self.args.per_device_train_batch_size,
+                    num_replicas=self.args.world_size,
+                    rank=self.args.process_index,
+                )
             else:
-                return DistributedSampler(self.train_dataset, num_replicas=num_processes, rank=process_index)
+                return DistributedSampler(
+                    self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index
+                )

     def get_train_dataloader(self) -> DataLoader:
         """
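Taken together, these hunks key the sampler choice off self.args.world_size and self.args.process_index instead of the locally computed num_processes/process_index, and add a TPU-specific path for drop_last=False. The following is an illustrative paraphrase of that decision order as a standalone sketch, not the Trainer API itself; choose_train_sampler and the bare args argument are hypothetical names:

from torch.utils.data import DistributedSampler, RandomSampler

from transformers.trainer_pt_utils import (
    DistributedLengthGroupedSampler,
    DistributedSamplerWithLoop,
    LengthGroupedSampler,
)
from transformers.training_args import ParallelMode


def choose_train_sampler(args, train_dataset, model_input_name=None):
    # Hypothetical helper mirroring the branching in the diff above.
    if args.group_by_length:
        if args.world_size <= 1:
            return LengthGroupedSampler(train_dataset, args.train_batch_size, model_input_name=model_input_name)
        return DistributedLengthGroupedSampler(
            train_dataset,
            args.train_batch_size,
            num_replicas=args.world_size,
            rank=args.process_index,
            model_input_name=model_input_name,
        )
    if args.world_size <= 1:
        return RandomSampler(train_dataset)
    if args.parallel_mode == ParallelMode.TPU and not args.dataloader_drop_last:
        # Loop so every TPU core keeps full, identically sized batches even without drop_last.
        return DistributedSamplerWithLoop(
            train_dataset,
            batch_size=args.per_device_train_batch_size,
            num_replicas=args.world_size,
            rank=args.process_index,
        )
    return DistributedSampler(train_dataset, num_replicas=args.world_size, rank=args.process_index)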
src/transformers/trainer_pt_utils.py

@@ -182,6 +182,34 @@ def torch_distributed_zero_first(local_rank: int):
         dist.barrier()


+class DistributedSamplerWithLoop(DistributedSampler):
+    """
+    Like a :obj:`torch.utils.data.distributed.DistributedSampler` but loops at the end back to the beginning of the
+    shuffled samples to make each process have a round multiple of batch_size samples.
+
+    Args:
+        dataset (:obj:`torch.utils.data.Dataset`):
+            Dataset used for sampling.
+        batch_size (:obj:`int`):
+            The batch size used with this sampler
+        kwargs:
+            All other keyword arguments passed to :obj:`DistributedSampler`.
+    """
+
+    def __init__(self, dataset, batch_size, **kwargs):
+        super().__init__(dataset, **kwargs)
+        self.batch_size = batch_size
+
+    def __iter__(self):
+        indices = list(super().__iter__())
+        remainder = 0 if len(indices) % self.batch_size == 0 else self.batch_size - len(indices) % self.batch_size
+        # DistributedSampler already added samples from the beginning to make the number of samples a round multiple
+        # of the world size, so we skip those.
+        start_remainder = 1 if self.rank < len(self.dataset) % self.num_replicas else 0
+        indices += indices[start_remainder : start_remainder + remainder]
+        return iter(indices)
+
+
 class SequentialDistributedSampler(Sampler):
     """
     Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end.

@@ -228,7 +256,7 @@ class SequentialDistributedSampler(Sampler):
         return self.num_samples


-def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
+def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset, bach_size: int):
     if xm.xrt_world_size() <= 1:
         return RandomSampler(dataset)
     return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
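A minimal usage sketch of the new sampler (the toy dataset and the batch_size/num_replicas values are invented for illustration, not taken from the commit): each rank's shard comes out as a round multiple of batch_size, so a DataLoader with drop_last=False still produces only full batches on that rank.

from torch.utils.data import DataLoader

from transformers.trainer_pt_utils import DistributedSamplerWithLoop

dataset = list(range(23))  # anything with __len__/__getitem__ works here

# One sampler per data-parallel rank; passing num_replicas and rank explicitly
# means no torch.distributed process group is needed for this illustration.
shard0 = DistributedSamplerWithLoop(dataset, 16, num_replicas=2, rank=0)
shard1 = DistributedSamplerWithLoop(dataset, 16, num_replicas=2, rank=1)
shard0.set_epoch(0)  # seeds the shuffle, as with a plain DistributedSampler
shard1.set_epoch(0)

print(len(list(shard0)), len(list(shard1)))  # 16 16, each a multiple of batch_size

# Feeding one shard to a DataLoader therefore yields only full batches on this rank.
loader = DataLoader(dataset, sampler=shard0, batch_size=16, drop_last=False)
for batch in loader:
    print(batch.shape)  # torch.Size([16])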
src/transformers/training_args.py

@@ -742,6 +742,20 @@ class TrainingArguments:
             return torch.distributed.get_world_size()
         return 1

+    @property
+    @torch_required
+    def process_index(self):
+        """
+        The number of processes used in parallel.
+        """
+        if is_torch_tpu_available():
+            return xm.get_ordinal()
+        elif is_sagemaker_distributed_available():
+            return sm_dist.get_rank()
+        elif self.local_rank != -1:
+            return torch.distributed.get_rank()
+        return 0
+
     @property
     def place_model_on_device(self):
         """
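The new property is the counterpart of the existing world_size property used elsewhere in this commit: its body resolves to the rank of the current process (TPU ordinal, SageMaker rank, or torch.distributed rank) and falls back to 0. A minimal sketch of the single-process case, assuming torch is installed and no distributed backend is initialized (output_dir is just a throwaway placeholder):

from transformers import TrainingArguments

# With no TPU, no SageMaker backend, and local_rank left at its default of -1,
# every branch of the new property falls through to the final `return 0`,
# and world_size likewise falls through to 1.
args = TrainingArguments(output_dir="tmp_trainer_output")
print(args.world_size, args.process_index)  # 1 0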
tests/test_trainer_utils.py

@@ -27,6 +27,7 @@ if is_torch_available():
     from transformers.modeling_outputs import SequenceClassifierOutput
     from transformers.trainer_pt_utils import (
         DistributedLengthGroupedSampler,
+        DistributedSamplerWithLoop,
         DistributedTensorGatherer,
         LabelSmoother,
         LengthGroupedSampler,

@@ -141,3 +142,28 @@ class TrainerUtilsTest(unittest.TestCase):
             ['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias']
         )
         # fmt: on
+
+    def test_distributed_sampler_with_loop(self):
+        batch_size = 16
+        for length in [23, 64, 123]:
+            dataset = list(range(length))
+            shard1 = DistributedSamplerWithLoop(dataset, batch_size, num_replicas=2, rank=0)
+            shard2 = DistributedSamplerWithLoop(dataset, batch_size, num_replicas=2, rank=1)
+
+            # Set seeds
+            shard1.set_epoch(0)
+            shard2.set_epoch(0)
+
+            # Sample
+            samples1 = list(shard1)
+            samples2 = list(shard2)
+
+            self.assertTrue(len(samples1) % batch_size == 0)
+            self.assertTrue(len(samples2) % batch_size == 0)
+
+            total = []
+            for sample1, sample2 in zip(samples1, samples2):
+                total += [sample1, sample2]
+
+            self.assertEqual(set(total[:length]), set(dataset))
+            self.assertEqual(set(total[length:]), set(total[: (len(total) - length)]))
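To make the assertions concrete, here is the arithmetic for the first length the test sweeps, using the same values as the test above (length 23, 2 replicas, batch_size 16):

# length = 23, num_replicas = 2, batch_size = 16 (values taken from the test above)
length, num_replicas, batch_size = 23, 2, 16

per_shard = -(-length // num_replicas)                 # DistributedSampler pads 23 -> 24, so 12 indices per shard
looped = (batch_size - per_shard % batch_size) % batch_size
print(per_shard + looped)                              # 16 -> len(samples) % batch_size == 0 holds for both shards

# Interleaving the two shards gives 2 * 16 = 32 indices: the first 23 cover the whole
# dataset once (set(total[:length]) == set(dataset)), and the trailing 9 repeat indices
# that already appeared at the start of the interleaved stream, which is what the final
# set() assertion checks.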