# coding=utf-8 # Copyright 2021 The OneFlow Authors. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import itertools # import unittest # import oneflow.utils.data as flowdata # from libai.data.samplers import CyclicSampler, SingleRoundSampler # class TestCyclicSampler(unittest.TestCase): # def test_cyclic_sampler_iterable(self): # sampler = CyclicSampler( # list(range(100)), # micro_batch_size=4, # shuffle=True, # consumed_samples=0, # seed=123, # ) # output_iter = itertools.islice(sampler, 25) # iteration=100/4=25 # sample_output = list() # for batch in output_iter: # sample_output.extend(batch) # self.assertEqual(set(sample_output), set(range(100))) # data_sampler = CyclicSampler( # list(range(100)), # micro_batch_size=4, # shuffle=True, # consumed_samples=0, # seed=123, # ) # data_loader = flowdata.DataLoader( # list(range(100)), batch_sampler=data_sampler, num_workers=0, collate_fn=lambda x: x # ) # data_loader_iter = itertools.islice(data_loader, 25) # output = list() # for data in data_loader_iter: # output.extend(data) # self.assertEqual(output, sample_output) # def test_cyclic_sampler_seed(self): # sampler = CyclicSampler( # list(range(100)), # micro_batch_size=4, # shuffle=True, # seed=123, # ) # data = list(itertools.islice(sampler, 65)) # sampler = CyclicSampler( # list(range(100)), # micro_batch_size=4, # shuffle=True, # seed=123, # ) # data2 = list(itertools.islice(sampler, 65)) # self.assertEqual(data, data2) # def 
test_cyclic_sampler_resume(self): # # Single rank # sampler = CyclicSampler( # list(range(10)), # micro_batch_size=4, # shuffle=True, # seed=123, # ) # all_output = list(itertools.islice(sampler, 50)) # iteration 50 times # sampler = CyclicSampler( # list(range(10)), # micro_batch_size=4, # shuffle=True, # seed=123, # consumed_samples=4 * 11, # consumed 11 iters # ) # resume_output = list(itertools.islice(sampler, 39)) # self.assertEqual(all_output[11:], resume_output) # def test_cyclic_sampler_resume_multi_rank(self): # # Multiple ranks # sampler_rank0 = CyclicSampler( # list(range(10)), # micro_batch_size=4, # shuffle=True, # seed=123, # data_parallel_rank=0, # data_parallel_size=2, # ) # sampler_rank1 = CyclicSampler( # list(range(10)), # micro_batch_size=4, # shuffle=True, # seed=123, # data_parallel_rank=1, # data_parallel_size=2, # ) # all_output_rank0 = list(itertools.islice(sampler_rank0, 50)) # iteration 50 times # all_output_rank1 = list(itertools.islice(sampler_rank1, 50)) # iteration 50 times # sampler_rank0 = CyclicSampler( # list(range(10)), # micro_batch_size=4, # shuffle=True, # seed=123, # data_parallel_rank=0, # data_parallel_size=2, # consumed_samples=4 * 11, # consumed 11 iters # ) # sampler_rank1 = CyclicSampler( # list(range(10)), # micro_batch_size=4, # shuffle=True, # seed=123, # data_parallel_rank=1, # data_parallel_size=2, # consumed_samples=4 * 11, # consumed 11 iters # ) # resume_output_rank0 = list(itertools.islice(sampler_rank0, 39)) # resume_output_rank1 = list(itertools.islice(sampler_rank1, 39)) # self.assertEqual(all_output_rank0[11:], resume_output_rank0) # self.assertEqual(all_output_rank1[11:], resume_output_rank1) # class TestSingleRoundSampler(unittest.TestCase): # def test_single_sampler_iterable(self): # sampler = SingleRoundSampler( # list(range(100)), # micro_batch_size=4, # shuffle=False, # ) # output_iter = itertools.islice(sampler, 30) # exceed iteration number # sample_output = list() # for batch in output_iter: # 
sample_output.extend(batch) # self.assertEqual(sample_output, list(range(100))) # def test_single_sampler_multi_rank(self): # sampler_rank0 = SingleRoundSampler( # list(range(101)), # micro_batch_size=4, # shuffle=False, # data_parallel_rank=0, # data_parallel_size=2, # ) # sampler_rank1 = SingleRoundSampler( # list(range(101)), # micro_batch_size=4, # shuffle=False, # data_parallel_rank=1, # data_parallel_size=2, # ) # output_iter_rank0 = itertools.islice(sampler_rank0, 30) # sample_output_rank0 = list() # for batch in output_iter_rank0: # sample_output_rank0.extend(batch) # output_iter_rank1 = itertools.islice(sampler_rank1, 30) # sample_output_rank1 = list() # for batch in output_iter_rank1: # sample_output_rank1.extend(batch) # # Pad with index 0 when a rank lacks the samples to fill its last batch; otherwise # # `to_global` will raise errors for mismatched data shapes across ranks # self.assertEqual(sample_output_rank0, list(range(51))) # self.assertEqual(sample_output_rank1, list(range(51, 101)) + [0]) # if __name__ == "__main__": # unittest.main()