Unverified commit 1b74af76, authored by Zhaofeng Wu, committed by GitHub

Allow dataset to be an optional argument for (Distributed)LengthGroupedSampler (#13820)

* Allow dataset to be an optional argument for (Distributed)LengthGroupedSampler

* Fix
Parent commit: d4e4efce
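With this change, `dataset` moves behind `batch_size` and defaults to `None`, so both samplers can be built from pre-computed lengths without handing them a dataset at all. A minimal sketch of the two calling conventions after the change (the toy lengths and dict dataset below are illustrative only, not taken from the PR):

```python
from transformers.trainer_pt_utils import LengthGroupedSampler

# Option 1: supply pre-computed lengths directly; no dataset is needed.
lengths = [8, 3, 12, 5, 7, 9, 4, 10]
sampler = LengthGroupedSampler(2, lengths=lengths)

# Option 2: let the sampler infer lengths from a dataset whose items are
# dicts (or BatchEncodings) containing an "input_ids" key.
data = [{"input_ids": list(range(n))} for n in lengths]
sampler = LengthGroupedSampler(2, dataset=data)

# Supplying neither raises ValueError("One of dataset and lengths must be provided.")
```

Iterating either sampler yields dataset indices arranged so that examples of similar length end up in the same batch.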
@@ -572,16 +572,16 @@ class Trainer:
             model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
             if self.args.world_size <= 1:
                 return LengthGroupedSampler(
-                    self.train_dataset,
                     self.args.train_batch_size,
+                    dataset=self.train_dataset,
                     lengths=lengths,
                     model_input_name=model_input_name,
                     generator=generator,
                 )
             else:
                 return DistributedLengthGroupedSampler(
-                    self.train_dataset,
                     self.args.train_batch_size,
+                    dataset=self.train_dataset,
                     num_replicas=self.args.world_size,
                     rank=self.args.process_index,
                     lengths=lengths,
@@ -520,25 +520,27 @@ class LengthGroupedSampler(Sampler):
     def __init__(
         self,
-        dataset: Dataset,
         batch_size: int,
+        dataset: Optional[Dataset] = None,
         lengths: Optional[List[int]] = None,
         model_input_name: Optional[str] = None,
         generator=None,
     ):
-        self.dataset = dataset
+        if dataset is None and lengths is None:
+            raise ValueError("One of dataset and lengths must be provided.")
+
         self.batch_size = batch_size
-        self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
         if lengths is None:
+            model_input_name = model_input_name if model_input_name is not None else "input_ids"
             if (
                 not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
-                or self.model_input_name not in dataset[0]
+                or model_input_name not in dataset[0]
             ):
                 raise ValueError(
                     "Can only automatically infer lengths for datasets whose items are dictionaries with an "
-                    f"'{self.model_input_name}' key."
+                    f"'{model_input_name}' key."
                 )
-            lengths = [len(feature[self.model_input_name]) for feature in dataset]
+            lengths = [len(feature[model_input_name]) for feature in dataset]
         self.lengths = lengths
         self.generator = generator
@@ -558,8 +560,8 @@ class DistributedLengthGroupedSampler(DistributedSampler):
     # Copied and adapted from PyTorch DistributedSampler.
     def __init__(
         self,
-        dataset: Dataset,
         batch_size: int,
+        dataset: Optional[Dataset] = None,
         num_replicas: Optional[int] = None,
         rank: Optional[int] = None,
         seed: int = 0,
@@ -567,6 +569,8 @@ class DistributedLengthGroupedSampler(DistributedSampler):
         lengths: Optional[List[int]] = None,
         model_input_name: Optional[str] = None,
     ):
+        if dataset is None and lengths is None:
+            raise ValueError("One of dataset and lengths must be provided.")
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
@@ -575,37 +579,38 @@ class DistributedLengthGroupedSampler(DistributedSampler):
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
             rank = dist.get_rank()
-        self.dataset = dataset
+
         self.batch_size = batch_size
         self.num_replicas = num_replicas
         self.rank = rank
         self.epoch = 0
         self.drop_last = drop_last
-        # If the dataset length is evenly divisible by # of replicas, then there
-        # is no need to drop any data, since the dataset will be split equally.
-        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
-            # Split to nearest available length that is evenly divisible.
-            # This is to ensure each rank receives the same amount of data when
-            # using this Sampler.
-            self.num_samples = math.ceil((len(self.dataset) - self.num_replicas) / self.num_replicas)
-        else:
-            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
-        self.total_size = self.num_samples * self.num_replicas
-        self.seed = seed
-        self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
+
         if lengths is None:
+            model_input_name = model_input_name if model_input_name is not None else "input_ids"
             if (
                 not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
-                or self.model_input_name not in dataset[0]
+                or model_input_name not in dataset[0]
             ):
                 raise ValueError(
                     "Can only automatically infer lengths for datasets whose items are dictionaries with an "
-                    f"'{self.model_input_name}' key."
+                    f"'{model_input_name}' key."
                )
-            lengths = [len(feature[self.model_input_name]) for feature in dataset]
+            lengths = [len(feature[model_input_name]) for feature in dataset]
         self.lengths = lengths
+
+        # If the dataset length is evenly divisible by # of replicas, then there
+        # is no need to drop any data, since the dataset will be split equally.
+        if self.drop_last and len(self.lengths) % self.num_replicas != 0:
+            # Split to nearest available length that is evenly divisible.
+            # This is to ensure each rank receives the same amount of data when
+            # using this Sampler.
+            self.num_samples = math.ceil((len(self.lengths) - self.num_replicas) / self.num_replicas)
+        else:
+            self.num_samples = math.ceil(len(self.lengths) / self.num_replicas)
+        self.total_size = self.num_samples * self.num_replicas
+        self.seed = seed

     def __iter__(self) -> Iterator:
         # Deterministically shuffle based on epoch and seed
         g = torch.Generator()
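As a concrete check of the arithmetic in the relocated block (assumed values, not from the PR): with `len(self.lengths) == 9` and `num_replicas == 2`, `drop_last=True` gives `num_samples = math.ceil((9 - 2) / 2) = 4` and `total_size = 8` (one sample dropped), while `drop_last=False` gives `num_samples = math.ceil(9 / 2) = 5` and `total_size = 10` (one index repeated so every rank receives the same count). The formulas are unchanged; they now read `len(self.lengths)` instead of `len(self.dataset)` so the sampler works when only lengths are supplied.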
@@ -181,7 +181,7 @@ class TrainerUtilsTest(unittest.TestCase):
         # Put one bigger than the others to check it ends up in first position
         lengths[32] = 50
-        indices = list(LengthGroupedSampler(lengths, 4, lengths=lengths))
+        indices = list(LengthGroupedSampler(4, lengths=lengths))
         # The biggest element should be first
         self.assertEqual(lengths[indices[0]], 50)
         # The indices should be a permutation of range(100)
@@ -196,7 +196,7 @@ class TrainerUtilsTest(unittest.TestCase):
         # Put one bigger than the others to check it ends up in first position
         data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
-        indices = list(LengthGroupedSampler(data, 4))
+        indices = list(LengthGroupedSampler(4, dataset=data))
         # The biggest element should be first
         self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
         # The indices should be a permutation of range(6)
@@ -211,7 +211,7 @@ class TrainerUtilsTest(unittest.TestCase):
         # Put one bigger than the others to check it ends up in first position
         data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
-        indices = list(LengthGroupedSampler(data, 4))
+        indices = list(LengthGroupedSampler(4, dataset=data))
         # The biggest element should be first
         self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
         # The indices should be a permutation of range(6)
@@ -223,8 +223,8 @@ class TrainerUtilsTest(unittest.TestCase):
         # Put one bigger than the others to check it ends up in first position
         lengths[32] = 50
-        indices_process_0 = list(DistributedLengthGroupedSampler(lengths, 4, 2, 0, lengths=lengths))
-        indices_process_1 = list(DistributedLengthGroupedSampler(lengths, 4, 2, 1, lengths=lengths))
+        indices_process_0 = list(DistributedLengthGroupedSampler(4, num_replicas=2, rank=0, lengths=lengths))
+        indices_process_1 = list(DistributedLengthGroupedSampler(4, num_replicas=2, rank=1, lengths=lengths))
         # The biggest element should be first
         self.assertEqual(lengths[indices_process_0[0]], 50)
         # The indices should be a permutation of range(100)
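The last test hunk shows the distributed variant driven entirely by `lengths`, with `num_replicas` and `rank` passed as keywords. A hedged sketch of the same pattern outside the test harness (toy values; no process group is needed because both `num_replicas` and `rank` are given explicitly):

```python
from transformers.trainer_pt_utils import DistributedLengthGroupedSampler

lengths = [7, 4, 11, 5, 9, 6, 10, 3]

# One sampler per process; only `rank` differs. torch.distributed does not
# need to be initialized since num_replicas and rank are passed explicitly.
sampler_rank0 = DistributedLengthGroupedSampler(4, num_replicas=2, rank=0, lengths=lengths)
sampler_rank1 = DistributedLengthGroupedSampler(4, num_replicas=2, rank=1, lengths=lengths)

# Per the refactor above, the per-rank sample count now comes from len(lengths).
assert sampler_rank0.num_samples == 4
assert sampler_rank0.total_size == 8

# Together the two ranks cover a permutation of range(len(lengths)),
# mirroring the assertion in the updated test.
assert sorted(list(sampler_rank0) + list(sampler_rank1)) == list(range(len(lengths)))
```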