chenpangpang / transformers · commit 77950c48
"projects/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "2468016b42cf3aeda5ff9d569964aa8f4c4f33ae"
Unverified commit 77950c48, authored Sep 10, 2020 by Sam Shleifer, committed by GitHub on Sep 10, 2020

[wip/s2s] DistributedSortishSampler (#7056)

Parent: 51448673
Showing 3 changed files with 79 additions and 24 deletions.
examples/seq2seq/finetune.py                +3   -6
examples/seq2seq/test_seq2seq_examples.py   +2   -2
examples/seq2seq/utils.py                   +74  -16
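Summary of the change, as inferred from the diffs below: SortishSampler previously worked only on a single GPU, and finetune.py silently disabled it when gpus > 1. This commit factors the chunked length-sorting logic into a standalone sortish_sampler_indices helper and adds a DistributedSortishSampler that shards indices across ranks, so the sortish path now also works under multi-GPU training. The new entry point, sketched with illustrative names:

    # Illustrative one-liner the rest of the diff builds toward:
    sampler = dataset.make_sortish_sampler(batch_size, distributed=hparams.gpus > 1)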
examples/seq2seq/finetune.py
@@ -3,7 +3,6 @@ import glob
 import logging
 import os
 import time
-import warnings
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple
@@ -67,6 +66,8 @@ class SummarizationModule(BaseTransformer):
     default_val_metric = "rouge2"

     def __init__(self, hparams, **kwargs):
+        if hparams.sortish_sampler and hparams.gpus > 1:
+            hparams.replace_sampler_ddp = False
         super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
         use_task_specific_params(self.model, "summarization")
         save_git_info(self.hparams.output_dir)
@@ -93,9 +94,6 @@ class SummarizationModule(BaseTransformer):
             "val": self.hparams.val_max_target_length,
             "test": self.hparams.test_max_target_length,
         }

-        if self.hparams.sortish_sampler and self.hparams.gpus > 1:
-            self.hparams.sortish_sampler = False
-            warnings.warn("ignoring sortish_sampler as it is unsupported on multiple GPUs")
         assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
         assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
@@ -257,8 +255,7 @@ class SummarizationModule(BaseTransformer):
         dataset = self.get_dataset(type_path)
         sampler = None
         if self.hparams.sortish_sampler and type_path == "train":
-            assert self.hparams.gpus <= 1  # this should never break because of the assertion in __init__
-            sampler = dataset.make_sortish_sampler(batch_size)
+            sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
             shuffle = False

         dataloader = DataLoader(
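Taken together, the finetune.py hunks move multi-GPU handling from "disable sortish sampling and warn" to "opt out of Lightning's sampler injection and build our own". A minimal sketch of the resulting dataloader wiring (not the repo's exact code; the assumption here is that hparams.replace_sampler_ddp is forwarded to pytorch_lightning.Trainer, whose default replace_sampler_ddp=True would otherwise swap a custom sampler for a plain DistributedSampler under DDP):

    from torch.utils.data import DataLoader

    def get_dataloader(dataset, batch_size, gpus, sortish_sampler, type_path="train"):
        sampler = None
        shuffle = type_path == "train"
        if sortish_sampler and type_path == "train":
            # One call now covers both cases: SortishSampler on a single device,
            # DistributedSortishSampler when gpus > 1.
            sampler = dataset.make_sortish_sampler(batch_size, distributed=gpus > 1)
            shuffle = False  # DataLoader rejects shuffle=True alongside an explicit sampler
        return DataLoader(dataset, batch_size=batch_size, sampler=sampler, shuffle=shuffle)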
examples/seq2seq/test_seq2seq_examples.py
@@ -149,9 +149,9 @@ class TestSummarizationDistiller(unittest.TestCase):
             no_teacher=True,
             freeze_encoder=True,
             gpus=2,
-            sortish_sampler=False,
+            sortish_sampler=True,
         )
-        self._test_distiller_cli(updates)
+        self._test_distiller_cli(updates, check_contents=False)

     def test_distill_no_teacher(self):
         updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
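The updated test turns sortish_sampler on with gpus=2 and relaxes the output check (check_contents=False), presumably because a 2-GPU run does not produce byte-identical outputs. For local experimentation, the new sampler can also be driven without a real process group, since num_replicas and rank are plain constructor arguments; a minimal sketch with illustrative names:

    # Assumes examples/seq2seq is on sys.path so utils.py is importable.
    from utils import DistributedSortishSampler

    class ToyDataset:
        """Stand-in exposing the two attributes the sampler actually touches."""
        src_lens = [5, 1, 9, 3, 7, 2, 8, 4]

        def __len__(self):
            return len(self.src_lens)

    ds = ToyDataset()
    shards = [DistributedSortishSampler(ds, batch_size=2, num_replicas=2, rank=r) for r in (0, 1)]
    covered = set(shards[0]) | set(shards[1])
    assert covered == set(range(len(ds)))  # together, the two ranks cover every example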
examples/seq2seq/utils.py
@@ -1,6 +1,7 @@
 import itertools
 import json
 import linecache
+import math
 import os
 import pickle
 from logging import getLogger

@@ -10,6 +11,7 @@ from typing import Callable, Dict, Iterable, List, Union
 import git
 import numpy as np
 import torch
+import torch.distributed as dist
 from rouge_score import rouge_scorer, scoring
 from sacrebleu import corpus_bleu
 from torch import nn
@@ -111,8 +113,11 @@ class AbstractSeq2SeqDataset(Dataset):
     def get_char_lens(data_file):
         return [len(x) for x in Path(data_file).open().readlines()]

-    def make_sortish_sampler(self, batch_size):
-        return SortishSampler(self.src_lens, batch_size)
+    def make_sortish_sampler(self, batch_size, distributed=False):
+        if distributed:
+            return DistributedSortishSampler(self, batch_size)
+        else:
+            return SortishSampler(self.src_lens, batch_size)

     def __getitem__(self, item):
         raise NotImplementedError("You must implement this")
@@ -191,24 +196,77 @@ class SortishSampler(Sampler):
     def __init__(self, data, batch_size):
         self.data, self.bs = data, batch_size

-    def key(self, i):
-        return self.data[i]
-
     def __len__(self) -> int:
         return len(self.data)

     def __iter__(self):
-        idxs = np.random.permutation(len(self.data))
-        sz = self.bs * 50
-        ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
-        sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
-        sz = self.bs
-        ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
-        max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
-        ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
-        sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
-        sort_idx = np.concatenate((ck_idx[0], sort_idx))
-        return iter(sort_idx)
+        return iter(sortish_sampler_indices(self.data, self.bs))
+
+
+def sortish_sampler_indices(data: List, bs: int) -> np.array:
+    "Go through the text data by order of src length with a bit of randomness. From fastai repo."
+
+    def key_fn(i):
+        return data[i]
+
+    idxs = np.random.permutation(len(data))
+    sz = bs * 50
+    ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
+    sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
+    sz = bs
+    ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
+    max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
+    ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
+    sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
+    sort_idx = np.concatenate((ck_idx[0], sort_idx))
+    return sort_idx
+
+
+class DistributedSortishSampler(Sampler):
+    """Copied from torch DistributedSampler"""
+
+    def __init__(self, dataset, batch_size, num_replicas=None, rank=None):
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+        self.total_size = self.num_samples * self.num_replicas
+        self.batch_size = batch_size
+
+    def __iter__(self) -> Iterable:
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        available_indices = self.get_indices_for_rank()  # indices[self.rank: self.total_size: self.num_replicas]
+        sortish_data = [self.dataset.src_lens[i] for i in available_indices]
+        sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size)
+        indices = [available_indices[i] for i in sortish_indices]
+        assert len(indices) == self.num_samples
+        return iter(indices)
+
+    def get_indices_for_rank(self) -> np.array:
+        indices = list(range(len(self.dataset)))
+        # add extra samples to make it evenly divisible
+        indices += indices[: (self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+        # subsample
+        available_indices = indices[self.rank : self.total_size : self.num_replicas]
+        return available_indices
+
+    def __len__(self):
+        return self.num_samples
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+
 logger = getLogger(__name__)
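The refactor leaves the sortish algorithm itself untouched: shuffle, sort within large chunks by descending length, re-chunk at batch size, put the longest batch first, then shuffle the remaining batches. The payoff is that consecutive indices have similar source lengths, so length-padded batches waste few tokens. A small self-contained check of that property (illustrative, not part of the commit):

    import numpy as np
    from utils import sortish_sampler_indices  # assumes examples/seq2seq is on sys.path

    np.random.seed(0)
    src_lens = [int(x) for x in np.random.randint(1, 512, size=64)]  # fake source lengths
    order = sortish_sampler_indices(src_lens, bs=8)

    batches = [order[i : i + 8] for i in range(0, len(order), 8)]
    spreads = [max(src_lens[j] for j in b) - min(src_lens[j] for j in b) for b in batches]
    print("max within-batch length spread:", max(spreads))  # far smaller than the global ~510 spread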