chenpangpang / transformers · Commits

Unverified commit 77950c48, authored Sep 10, 2020 by Sam Shleifer; committed by GitHub, Sep 10, 2020
[wip/s2s] DistributedSortishSampler (#7056)
parent 51448673
Showing 3 changed files with 79 additions and 24 deletions (+79 -24).
examples/seq2seq/finetune.py               +3  -6
examples/seq2seq/test_seq2seq_examples.py  +2  -2
examples/seq2seq/utils.py                  +74 -16
examples/seq2seq/finetune.py
@@ -3,7 +3,6 @@ import glob
 import logging
 import os
 import time
-import warnings
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple
@@ -67,6 +66,8 @@ class SummarizationModule(BaseTransformer):
     default_val_metric = "rouge2"

     def __init__(self, hparams, **kwargs):
+        if hparams.sortish_sampler and hparams.gpus > 1:
+            hparams.replace_sampler_ddp = False
         super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
         use_task_specific_params(self.model, "summarization")
         save_git_info(self.hparams.output_dir)
@@ -93,9 +94,6 @@ class SummarizationModule(BaseTransformer):
             "val": self.hparams.val_max_target_length,
             "test": self.hparams.test_max_target_length,
         }
-        if self.hparams.sortish_sampler and self.hparams.gpus > 1:
-            self.hparams.sortish_sampler = False
-            warnings.warn("ignoring sortish_sampler as it is unsupported on multiple GPUs")
         assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
         assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
@@ -257,8 +255,7 @@ class SummarizationModule(BaseTransformer):
         dataset = self.get_dataset(type_path)
         sampler = None
         if self.hparams.sortish_sampler and type_path == "train":
-            assert self.hparams.gpus <= 1  # this should never break because of the assertion in __init__
-            sampler = dataset.make_sortish_sampler(batch_size)
+            sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
         shuffle = False

         dataloader = DataLoader(
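A sampler produced this way drops into a stock `torch.utils.data.DataLoader`; `shuffle` stays `False` whenever an explicit `sampler` is passed, since `DataLoader` rejects the combination. A minimal usage sketch (the dataset instance and GPU count are assumptions, not part of this diff):

    from torch.utils.data import DataLoader

    # train_dataset is assumed to be an AbstractSeq2SeqDataset subclass from
    # examples/seq2seq/utils.py, i.e. it exposes src_lens, collate_fn, and
    # the make_sortish_sampler() factory changed in this commit.
    sampler = train_dataset.make_sortish_sampler(batch_size=8, distributed=num_gpus > 1)
    loader = DataLoader(
        train_dataset,
        batch_size=8,
        sampler=sampler,  # explicit sampler controls ordering...
        shuffle=False,    # ...so shuffle must stay False
        collate_fn=train_dataset.collate_fn,
    )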
examples/seq2seq/test_seq2seq_examples.py
@@ -149,9 +149,9 @@ class TestSummarizationDistiller(unittest.TestCase):
             no_teacher=True,
             freeze_encoder=True,
             gpus=2,
-            sortish_sampler=False,
+            sortish_sampler=True,
         )
-        self._test_distiller_cli(updates)
+        self._test_distiller_cli(updates, check_contents=False)

     def test_distill_no_teacher(self):
         updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
examples/seq2seq/utils.py
@@ -1,6 +1,7 @@
 import itertools
 import json
 import linecache
+import math
 import os
 import pickle
 from logging import getLogger
@@ -10,6 +11,7 @@ from typing import Callable, Dict, Iterable, List, Union
 import git
 import numpy as np
 import torch
+import torch.distributed as dist
 from rouge_score import rouge_scorer, scoring
 from sacrebleu import corpus_bleu
 from torch import nn
@@ -111,8 +113,11 @@ class AbstractSeq2SeqDataset(Dataset):
     def get_char_lens(data_file):
         return [len(x) for x in Path(data_file).open().readlines()]

-    def make_sortish_sampler(self, batch_size):
-        return SortishSampler(self.src_lens, batch_size)
+    def make_sortish_sampler(self, batch_size, distributed=False):
+        if distributed:
+            return DistributedSortishSampler(self, batch_size)
+        else:
+            return SortishSampler(self.src_lens, batch_size)

     def __getitem__(self, item):
         raise NotImplementedError("You must implement this")
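Note the asymmetry in the factory: `SortishSampler` is handed only the per-example source lengths, while `DistributedSortishSampler` is handed the dataset itself, because it must first shard the index range across ranks and only then look up `src_lens` for its own shard. A toy illustration with invented lengths (passing explicit `num_replicas`/`rank` avoids needing an initialized process group):

    from utils import SortishSampler, DistributedSortishSampler  # assumes examples/seq2seq is on sys.path

    class TinyDataset:
        src_lens = [5, 97, 12, 63, 40, 8, 71, 29]  # invented source lengths
        def __len__(self):
            return len(self.src_lens)

    ds = TinyDataset()
    local = SortishSampler(ds.src_lens, 2)                            # sorts over the raw lengths
    shard = DistributedSortishSampler(ds, 2, num_replicas=2, rank=0)  # shards, then sorts its slice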
@@ -191,24 +196,77 @@ class SortishSampler(Sampler):
     def __init__(self, data, batch_size):
         self.data, self.bs = data, batch_size

-    def key(self, i):
-        return self.data[i]
-
     def __len__(self) -> int:
         return len(self.data)

     def __iter__(self):
-        idxs = np.random.permutation(len(self.data))
-        sz = self.bs * 50
-        ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
-        sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
-        sz = self.bs
-        ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
-        max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
-        ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
-        sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
-        sort_idx = np.concatenate((ck_idx[0], sort_idx))
-        return iter(sort_idx)
+        return iter(sortish_sampler_indices(self.data, self.bs))
+
+
+def sortish_sampler_indices(data: List, bs: int) -> np.array:
+    "Go through the text data by order of src length with a bit of randomness. From fastai repo."
+
+    def key_fn(i):
+        return data[i]
+
+    idxs = np.random.permutation(len(data))
+    sz = bs * 50
+    ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
+    sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
+    sz = bs
+    ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
+    max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
+    ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
+    sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
+    sort_idx = np.concatenate((ck_idx[0], sort_idx))
+    return sort_idx
+
+
+class DistributedSortishSampler(Sampler):
+    """Copied from torch DistributedSampler"""
+
+    def __init__(self, dataset, batch_size, num_replicas=None, rank=None):
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+        self.total_size = self.num_samples * self.num_replicas
+        self.batch_size = batch_size
+
+    def __iter__(self) -> Iterable:
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        available_indices = self.get_indices_for_rank()  # indices[self.rank : self.total_size : self.num_replicas]
+        sortish_data = [self.dataset.src_lens[i] for i in available_indices]
+        sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size)
+        indices = [available_indices[i] for i in sortish_indices]
+        assert len(indices) == self.num_samples
+        return iter(indices)
+
+    def get_indices_for_rank(self) -> np.array:
+        indices = list(range(len(self.dataset)))
+        # add extra samples to make it evenly divisible
+        indices += indices[: (self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+        # subsample
+        available_indices = indices[self.rank : self.total_size : self.num_replicas]
+        return available_indices
+
+    def __len__(self):
+        return self.num_samples
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch


 logger = getLogger(__name__)
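Taken together, `sortish_sampler_indices` returns a permutation in which each consecutive `bs`-sized slice is roughly sorted by source length (minimizing padding inside a batch), the slice containing the longest example comes first (so the first step exercises peak memory), and slice order is otherwise random. A small sanity check of those properties, assuming `examples/seq2seq/utils.py` is importable and a NumPy version old enough that `np.int` (used in the fallback branch above) still exists:

    import numpy as np
    from utils import sortish_sampler_indices, DistributedSortishSampler

    lengths = np.random.randint(1, 512, size=100).tolist()
    order = sortish_sampler_indices(lengths, bs=4)

    assert sorted(int(i) for i in order) == list(range(100))   # a true permutation of all indices
    assert lengths[order[0]] == max(lengths)                   # the longest example leads
    first = [lengths[i] for i in order[:4]]
    assert first == sorted(first, reverse=True)                # the first batch is internally length-sorted

    # The distributed variant gives each rank a disjoint shard of the index range
    # (exactly disjoint here because 100 divides evenly by num_replicas).
    class _Toy:
        src_lens = lengths
        def __len__(self):
            return len(self.src_lens)

    ds = _Toy()
    r0 = set(DistributedSortishSampler(ds, 4, num_replicas=2, rank=0))
    r1 = set(DistributedSortishSampler(ds, 4, num_replicas=2, rank=1))
    assert r0.isdisjoint(r1) and len(r0 | r1) == 100

One quirk worth noting: `set_epoch` stores the epoch and `__iter__` seeds a `torch.Generator` from it, but the shuffling itself goes through `np.random`, so as written the per-epoch seed does not actually influence the ordering.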