Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
19e4ebbe
Commit
19e4ebbe
authored
Oct 02, 2019
by
VictorSanh
Committed by
Victor SANH
Oct 03, 2019
Browse files
grouped_batch_sampler
parent
594202a9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
105 additions
and
0 deletions
+105
-0
examples/distillation/grouped_batch_sampler.py
examples/distillation/grouped_batch_sampler.py
+105
-0
No files found.
examples/distillation/grouped_batch_sampler.py
0 → 100644
View file @
19e4ebbe
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Adapted from PyTorch Vision (https://github.com/pytorch/vision/blob/master/references/detection/group_by_aspect_ratio.py)
"""
import
bisect
import
copy
from
collections
import
defaultdict
import
numpy
as
np
from
torch.utils.data.sampler
import
BatchSampler
,
Sampler
from
utils
import
logger
def
_quantize
(
x
,
bins
):
bins
=
copy
.
deepcopy
(
bins
)
bins
=
sorted
(
bins
)
quantized
=
list
(
map
(
lambda
y
:
bisect
.
bisect_right
(
bins
,
y
),
x
))
return
quantized
def create_lengths_groups(lengths, k=0):
    """Bucket sequence lengths into groups and log the bucket statistics.

    Args:
        lengths: Iterable of sequence lengths, one per sample.
        k: When positive, bin edges are ``3, 7, 11, ...`` (step 4) strictly
            below ``k``. Otherwise a single edge at ``10`` is used.

    Returns:
        A list of group ids, one per entry of ``lengths``, as produced by
        ``_quantize(lengths, bins)``.
    """
    if k > 0:
        bins = np.arange(start=3, stop=k, step=4).tolist()
    else:
        bins = [10]

    groups = _quantize(lengths, bins)

    # np.unique returns (sorted unique ids, occurrences); only the occurrence
    # counts of the group ids that actually appear are reported.
    _, counts = np.unique(groups, return_counts=True)

    # Full edge list used only for the log message: [0, edges..., +inf].
    fbins = [0, *bins, np.inf]
    logger.info("Using {} as bins for aspect lengths quantization".format(fbins))
    logger.info("Count of instances per bin: {}".format(counts))
    return groups
class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield a mini-batch of indices.
    It enforces that the batch only contain elements from the same group.
    It also tries to provide mini-batches which follows an ordering which is
    as close as possible to the ordering from the original sampler.

    Once the base sampler is exhausted, the indices left over in partially
    filled groups are emitted as well (concatenated in increasing group id
    order), so these trailing batches may mix groups and the last one may be
    smaller than `batch_size`.

    Arguments:
        sampler (Sampler): Base sampler.
        group_ids (list[int]): If the sampler produces indices in range [0, N),
            `group_ids` must be a list of `N` ints which contains the group id of each sample.
            The group ids must be a continuous set of integers starting from
            0, i.e. they must be in the range [0, num_groups).
        batch_size (int): Size of mini-batch. Must be strictly positive.

    Raises:
        ValueError: If `sampler` is not a `torch.utils.data.Sampler` or if
            `batch_size` is not strictly positive.
    """

    def __init__(self, sampler, group_ids, batch_size):
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of "
                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )
        # A non-positive batch size would otherwise only surface later as a
        # ZeroDivisionError in __len__ or as batches that never fill up.
        if batch_size <= 0:
            raise ValueError(
                "batch_size should be a positive integer value, "
                "but got batch_size={}".format(batch_size)
            )
        self.sampler = sampler
        self.group_ids = group_ids
        self.batch_size = batch_size

    def __iter__(self):
        """Yield lists of indices, one list per mini-batch."""
        # Indices seen so far for each group that have not been emitted yet.
        buffer_per_group = defaultdict(list)

        num_batches = 0
        for idx in self.sampler:
            group_id = self.group_ids[idx]
            buffer = buffer_per_group[group_id]
            buffer.append(idx)
            if len(buffer) == self.batch_size:
                yield buffer
                num_batches += 1
                # Drop the entry so the next index of this group starts a
                # fresh list instead of mutating the batch just handed out.
                del buffer_per_group[group_id]

        # now we have run out of elements that satisfy
        # the group criteria, let's return the remaining
        # elements so that the size of the sampler is
        # deterministic
        expected_num_batches = len(self)
        num_remaining = expected_num_batches - num_batches
        if num_remaining > 0:
            # for the remaining batches, group the batches by similar lengths
            batch_idx = []
            for _, idxs in sorted(buffer_per_group.items(), key=lambda item: item[0]):
                batch_idx.extend(idxs)
                # Each buffer holds fewer than batch_size indices, so at most
                # one full batch can be completed per group visited.
                if len(batch_idx) >= self.batch_size:
                    yield batch_idx[: self.batch_size]
                    batch_idx = batch_idx[self.batch_size:]
                    num_remaining -= 1
            if batch_idx:
                yield batch_idx
                num_remaining -= 1
        # Sanity check: the number of batches produced matches len(self).
        assert num_remaining == 0

    def __len__(self):
        """
        Return the number of mini-batches rather than the number of samples.
        """
        # Ceiling division: a trailing partial batch counts as one batch.
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment