Commit 73fde1de authored by ctheodoris, committed by GitHub

Faster list concat for trainer_pt_utils.get_length_grouped_indices() (#11825)



get_length_grouped_indices() in LengthGroupedSampler and DistributedLengthGroupedSampler
is prohibitively slow for a large number of megabatches (in a test case it takes hours for
~270k megabatches of 100 items each) because sum(megabatches, []) builds a new list at every
step of the concatenation. Replace it with a single-pass list comprehension.

Resolves: #11795
Co-authored-by: ctheodoris <cvtheodo@ds.dfci.harvard.edu>
parent da22245e
@@ -495,7 +495,7 @@ def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, genera
     # Switch to put the longest element in first position
     megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]
-    return sum(megabatches, [])
+    return [i for megabatch in megabatches for i in megabatch]
 class LengthGroupedSampler(Sampler):
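
The following is a minimal sketch, not part of the commit, of the two flattening strategies in the hunk above. It runs them on toy data to show that they return the same indices in the same order. The helper names (flatten_with_sum, flatten_with_comprehension, flatten_with_chain) and the sample megabatches are hypothetical, introduced only for illustration; the itertools.chain variant is an alternative the commit does not use.

```python
from itertools import chain


def flatten_with_sum(megabatches):
    # Approach removed by the commit: every `+` allocates a new list and
    # copies all elements accumulated so far, so total work grows roughly
    # quadratically with the number of megabatches.
    return sum(megabatches, [])


def flatten_with_comprehension(megabatches):
    # Approach added by the commit: one pass, each index appended once.
    return [i for megabatch in megabatches for i in megabatch]


def flatten_with_chain(megabatches):
    # Alternative (not used in the commit): also a single linear pass.
    return list(chain.from_iterable(megabatches))


# Hypothetical toy input: three megabatches of dataset indices.
megabatches = [[3, 1, 2], [6, 4, 5], [8, 7]]

flat = flatten_with_comprehension(megabatches)
print(flat)
```

All three helpers preserve the order of the megabatches and of the indices inside each one, so the change affects only running time, not the sampler's output.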