transformers · Commit a0a027c2 (unverified)
Authored Mar 16, 2021 by Sylvain Gugger; committed by GitHub on Mar 16, 2021

Add DistributedSamplerWithLoop (#10746)

* Add DistributedSamplerWithLoop
* Fix typo
* Test and small fix
Parent: 14492222

Showing 5 changed files with 93 additions and 20 deletions
src/transformers/sagemaker/trainer_sm.py   +8  −0
src/transformers/trainer.py                +16 −19
src/transformers/trainer_pt_utils.py       +29 −1
src/transformers/training_args.py          +14 −0
tests/test_trainer_utils.py                +26 −0
src/transformers/sagemaker/trainer_sm.py

@@ -26,6 +26,7 @@ from ..modeling_utils import PreTrainedModel, unwrap_model
 from ..trainer import Trainer
 from ..trainer_pt_utils import (
     DistributedLengthGroupedSampler,
+    DistributedSamplerWithLoop,
     SequentialDistributedSampler,
     nested_detach,
     nested_numpify,

@@ -97,6 +98,13 @@ class SageMakerTrainer(Trainer):
                 return DistributedLengthGroupedSampler(
                     self.train_dataset, self.args.train_batch_size, num_replicas=smp.dp_size(), rank=smp.dp_rank()
                 )
+            elif not self.args.dataloader_drop_last:
+                return DistributedSamplerWithLoop(
+                    self.train_dataset,
+                    self.args.per_device_train_batch_size,
+                    num_replicas=smp.dp_size(),
+                    rank=smp.dp_rank(),
+                )
             else:
                 return DistributedSampler(self.train_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
         else:
src/transformers/trainer.py

@@ -77,6 +77,7 @@ from .trainer_callback import (
 )
 from .trainer_pt_utils import (
     DistributedLengthGroupedSampler,
+    DistributedSamplerWithLoop,
     DistributedTensorGatherer,
     LabelSmoother,
     LengthGroupedSampler,

@@ -491,24 +492,10 @@ class Trainer:
         ):
             return None

-        # Gather the number of processes and this process index.
-        if self.args.parallel_mode == ParallelMode.TPU:
-            num_processes = xm.xrt_world_size()
-            process_index = xm.get_ordinal()
-        elif (
-            self.args.parallel_mode == ParallelMode.DISTRIBUTED
-            or self.args.parallel_mode == ParallelMode.SAGEMAKER_DISTRIBUTED
-        ):
-            num_processes = dist.get_world_size()
-            process_index = dist.get_rank()
-        else:
-            num_processes = 1
-            process_index = 0
-
         # Build the sampler.
         if self.args.group_by_length:
             model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
-            if num_processes <= 1:
+            if self.args.world_size <= 1:
                 return LengthGroupedSampler(
                     self.train_dataset, self.args.train_batch_size, model_input_name=model_input_name
                 )

@@ -516,16 +503,26 @@ class Trainer:
                 return DistributedLengthGroupedSampler(
                     self.train_dataset,
                     self.args.train_batch_size,
-                    num_replicas=num_processes,
-                    rank=process_index,
+                    num_replicas=self.args.world_size,
+                    rank=self.args.process_index,
                     model_input_name=model_input_name,
                 )

         else:
-            if num_processes <= 1:
+            if self.args.world_size <= 1:
                 return RandomSampler(self.train_dataset)
+            elif self.args.parallel_mode == ParallelMode.TPU and not self.args.dataloader_drop_last:
+                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
+                return DistributedSamplerWithLoop(
+                    self.train_dataset,
+                    batch_size=self.args.per_device_train_batch_size,
+                    num_replicas=self.args.world_size,
+                    rank=self.args.process_index,
+                )
             else:
-                return DistributedSampler(self.train_dataset, num_replicas=num_processes, rank=process_index)
+                return DistributedSampler(
+                    self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index
+                )

     def get_train_dataloader(self) -> DataLoader:
         """
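The comment on the new TPU branch ("Use a loop for TPUs when drop_last is False to have all batches have the same size") is the core of this change. As a rough sketch of the arithmetic involved, with hypothetical sizes that do not come from this diff: a plain DistributedSampler only pads the dataset up to a multiple of the number of replicas, so each replica can still end on a short batch, while the looping sampler pads every shard up to a full multiple of the per-device batch size.

import math

# Hypothetical sizes, not taken from this commit.
dataset_len = 100
num_replicas = 8            # e.g. 8 TPU cores
per_device_batch_size = 16

# DistributedSampler pads the dataset to a multiple of num_replicas only:
per_replica = math.ceil(dataset_len / num_replicas)              # 13 samples per replica
short_last_batch = per_replica % per_device_batch_size           # 13 -> one short final batch

# DistributedSamplerWithLoop loops back over the shuffled shard instead:
looped = math.ceil(per_replica / per_device_batch_size) * per_device_batch_size   # 16
assert looped % per_device_batch_size == 0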
src/transformers/trainer_pt_utils.py

@@ -182,6 +182,34 @@ def torch_distributed_zero_first(local_rank: int):
         dist.barrier()


+class DistributedSamplerWithLoop(DistributedSampler):
+    """
+    Like a :obj:`torch.utils.data.distributed.DistributedSampler` but loops at the end back to the beginning of the
+    shuffled samples to make each process have a round multiple of batch_size samples.
+
+    Args:
+        dataset (:obj:`torch.utils.data.Dataset`):
+            Dataset used for sampling.
+        batch_size (:obj:`int`):
+            The batch size used with this sampler.
+        kwargs:
+            All other keyword arguments passed to :obj:`DistributedSampler`.
+    """
+
+    def __init__(self, dataset, batch_size, **kwargs):
+        super().__init__(dataset, **kwargs)
+        self.batch_size = batch_size
+
+    def __iter__(self):
+        indices = list(super().__iter__())
+        remainder = 0 if len(indices) % self.batch_size == 0 else self.batch_size - len(indices) % self.batch_size
+        # DistributedSampler already added samples from the beginning to make the number of samples a round multiple
+        # of the world size, so we skip those.
+        start_remainder = 1 if self.rank < len(self.dataset) % self.num_replicas else 0
+        indices += indices[start_remainder : start_remainder + remainder]
+        return iter(indices)
+
+
 class SequentialDistributedSampler(Sampler):
     """
     Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end.

@@ -228,7 +256,7 @@ class SequentialDistributedSampler(Sampler):
         return self.num_samples


-def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
+def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset, batch_size: int):
     if xm.xrt_world_size() <= 1:
         return RandomSampler(dataset)
     return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
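As a quick illustration of the docstring above, a minimal usage sketch with a toy 23-sample dataset and two replicas (assumes torch and this version of transformers are installed; the numbers are hypothetical): each rank starts from the 12 indices a regular DistributedSampler would give it and ends up with 16, a round multiple of the batch size.

from transformers.trainer_pt_utils import DistributedSamplerWithLoop

dataset = list(range(23))   # toy dataset
batch_size = 8

for rank in (0, 1):
    sampler = DistributedSamplerWithLoop(dataset, batch_size, num_replicas=2, rank=rank)
    sampler.set_epoch(0)    # fix the shuffling seed, as the test below does
    indices = list(sampler)
    # DistributedSampler pads 23 -> 24, giving each rank 12 indices; the loop then
    # extends each shard from 12 to 16.
    assert len(indices) == 16 and len(indices) % batch_size == 0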
src/transformers/training_args.py

@@ -742,6 +742,20 @@ class TrainingArguments:
             return torch.distributed.get_world_size()
         return 1

+    @property
+    @torch_required
+    def process_index(self):
+        """
+        The index of the current process used.
+        """
+        if is_torch_tpu_available():
+            return xm.get_ordinal()
+        elif is_sagemaker_distributed_available():
+            return sm_dist.get_rank()
+        elif self.local_rank != -1:
+            return torch.distributed.get_rank()
+        return 0
+
     @property
     def place_model_on_device(self):
         """
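The new process_index property is the counterpart of the existing world_size property, which is what lets the Trainer build its samplers from TrainingArguments alone instead of branching on xm/dist as before. A small sketch of how the pair reads outside any distributed launch (requires torch; the output directory is just a placeholder, and actual values depend on the launcher):

from transformers import TrainingArguments

args = TrainingArguments(output_dir="tmp_trainer")   # placeholder output directory
# With no distributed backend initialized, both properties fall back to their defaults.
print(args.world_size, args.process_index)           # 1 0

# Inside Trainer._get_train_sampler they replace the old xm/dist branching, e.g.:
# DistributedSampler(dataset, num_replicas=args.world_size, rank=args.process_index)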
tests/test_trainer_utils.py

@@ -27,6 +27,7 @@ if is_torch_available():
     from transformers.modeling_outputs import SequenceClassifierOutput
     from transformers.trainer_pt_utils import (
         DistributedLengthGroupedSampler,
+        DistributedSamplerWithLoop,
         DistributedTensorGatherer,
         LabelSmoother,
         LengthGroupedSampler,

@@ -141,3 +142,28 @@ class TrainerUtilsTest(unittest.TestCase):
             ['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias']
         )
         # fmt: on
+
+    def test_distributed_sampler_with_loop(self):
+        batch_size = 16
+        for length in [23, 64, 123]:
+            dataset = list(range(length))
+            shard1 = DistributedSamplerWithLoop(dataset, batch_size, num_replicas=2, rank=0)
+            shard2 = DistributedSamplerWithLoop(dataset, batch_size, num_replicas=2, rank=1)
+
+            # Set seeds
+            shard1.set_epoch(0)
+            shard2.set_epoch(0)
+
+            # Sample
+            samples1 = list(shard1)
+            samples2 = list(shard2)
+
+            self.assertTrue(len(samples1) % batch_size == 0)
+            self.assertTrue(len(samples2) % batch_size == 0)
+
+            total = []
+            for sample1, sample2 in zip(samples1, samples2):
+                total += [sample1, sample2]
+
+            self.assertEqual(set(total[:length]), set(dataset))
+            self.assertEqual(set(total[length:]), set(total[: (len(total) - length)]))
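For the first test case (length 23, batch size 16, two replicas) the expected numbers can be checked by hand; a short sketch of that arithmetic, mirroring the assertions above:

import math

length, batch_size, num_replicas = 23, 16, 2
per_replica = math.ceil(length / num_replicas)             # 12, after DistributedSampler pads 23 -> 24
looped = math.ceil(per_replica / batch_size) * batch_size  # 16 samples per shard
total = looped * num_replicas                              # 32 interleaved samples
# The first `length` interleaved samples cover the dataset exactly once; the remaining
# total - length = 9 samples are repeats of samples already drawn.
print(per_replica, looped, total)                          # 12 16 32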