Unverified Commit c2cd02ac authored by Takuya Makino's avatar Takuya Makino Committed by GitHub
Browse files

Accepts BatchEncoding in LengthSampler (#11431)

parent 30ede899
...@@ -33,6 +33,7 @@ from torch.utils.data.distributed import DistributedSampler ...@@ -33,6 +33,7 @@ from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler from torch.utils.data.sampler import RandomSampler, Sampler
from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
from .tokenization_utils_base import BatchEncoding
from .utils import logging from .utils import logging
...@@ -514,7 +515,10 @@ class LengthGroupedSampler(Sampler): ...@@ -514,7 +515,10 @@ class LengthGroupedSampler(Sampler):
self.batch_size = batch_size self.batch_size = batch_size
self.model_input_name = model_input_name if model_input_name is not None else "input_ids" self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
if lengths is None: if lengths is None:
if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]: if (
not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
or self.model_input_name not in dataset[0]
):
raise ValueError( raise ValueError(
"Can only automatically infer lengths for datasets whose items are dictionaries with an " "Can only automatically infer lengths for datasets whose items are dictionaries with an "
f"'{self.model_input_name}' key." f"'{self.model_input_name}' key."
...@@ -575,7 +579,10 @@ class DistributedLengthGroupedSampler(DistributedSampler): ...@@ -575,7 +579,10 @@ class DistributedLengthGroupedSampler(DistributedSampler):
self.model_input_name = model_input_name if model_input_name is not None else "input_ids" self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
if lengths is None: if lengths is None:
if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]: if (
not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
or self.model_input_name not in dataset[0]
):
raise ValueError( raise ValueError(
"Can only automatically infer lengths for datasets whose items are dictionaries with an " "Can only automatically infer lengths for datasets whose items are dictionaries with an "
f"'{self.model_input_name}' key." f"'{self.model_input_name}' key."
......
...@@ -27,6 +27,7 @@ if is_torch_available(): ...@@ -27,6 +27,7 @@ if is_torch_available():
from torch.utils.data import IterableDataset from torch.utils.data import IterableDataset
from transformers.modeling_outputs import SequenceClassifierOutput from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.tokenization_utils_base import BatchEncoding
from transformers.trainer_pt_utils import ( from transformers.trainer_pt_utils import (
DistributedLengthGroupedSampler, DistributedLengthGroupedSampler,
DistributedSamplerWithLoop, DistributedSamplerWithLoop,
...@@ -185,6 +186,36 @@ class TrainerUtilsTest(unittest.TestCase): ...@@ -185,6 +186,36 @@ class TrainerUtilsTest(unittest.TestCase):
# The indices should be a permutation of range(100) # The indices should be a permutation of range(100)
self.assertEqual(list(sorted(indices)), list(range(100))) self.assertEqual(list(sorted(indices)), list(range(100)))
def test_group_by_length_with_dict(self):
    """LengthGroupedSampler should infer lengths from items that are plain dicts."""
    # Build six samples, each a dict holding 100 random token ids.
    data = [{"input_ids": torch.randint(0, 25, (100,)).tolist()} for _ in range(6)]
    # Make sample 3 the longest so we can check it is emitted first.
    data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
    indices = list(LengthGroupedSampler(data, 4))
    # The longest sample must come out first.
    self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
    # Every dataset index must appear exactly once.
    self.assertEqual(sorted(indices), list(range(6)))
def test_group_by_length_with_batch_encoding(self):
    """LengthGroupedSampler should infer lengths from items that are BatchEncoding objects."""
    # Build six samples, each wrapped in a BatchEncoding with 100 random token ids.
    data = [BatchEncoding({"input_ids": torch.randint(0, 25, (100,)).tolist()}) for _ in range(6)]
    # Make sample 3 the longest so we can check it is emitted first.
    data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
    indices = list(LengthGroupedSampler(data, 4))
    # The longest sample must come out first.
    self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
    # Every dataset index must appear exactly once.
    self.assertEqual(sorted(indices), list(range(6)))
def test_distributed_length_grouped(self): def test_distributed_length_grouped(self):
# Get some inputs of random lengths # Get some inputs of random lengths
lengths = torch.randint(0, 25, (100,)).tolist() lengths = torch.randint(0, 25, (100,)).tolist()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment