Unverified Commit ebd94b0f authored by Zach Mueller, committed by GitHub

🚨🚨🚨 Replace DataLoader logic for Accelerate in Trainer, remove unneeded tests 🚨🚨🚨 (#24028)

* Working integration

* Fix failing test

* Revert label host logic

* Bring it back!
parent dc42a9d7
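Note for readers skimming the diff below: the core of the change is that Trainer now builds a plain DataLoader (with a simple RandomSampler or SequentialSampler) and passes it through self.accelerator.prepare(...), which takes over the per-process sharding that DistributedSampler, ShardSampler and IterableDatasetShard used to handle. A minimal standalone sketch of that pattern, illustrative only and not the Trainer code itself (assumes accelerate is installed; the dataset and parameter values are made up):

# Sketch of the pattern adopted in this PR: build an ordinary DataLoader and
# let Accelerate handle device placement and distributed sharding.
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

from accelerate import Accelerator

accelerator = Accelerator()

# Made-up toy dataset, standing in for the Trainer's train_dataset.
dataset = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
dataloader_params = {
    "batch_size": 8,
    "sampler": RandomSampler(dataset),  # a plain sampler; no DistributedSampler needed
    "drop_last": False,
}

# prepare() wraps the loader so that each process only iterates over its own
# shard of the data, mirroring what Trainer.get_train_dataloader now does.
train_dataloader = accelerator.prepare(DataLoader(dataset, **dataloader_params))

for features, labels in train_dataloader:
    # Batches arrive already on the right device for this process.
    break

Because sharding now happens inside prepare(), the generator/seed plumbing that _get_train_sampler threaded through the distributed samplers is no longer needed, which is what the first hunks below remove.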
@@ -61,7 +61,6 @@ from huggingface_hub import Repository, create_repo
 from packaging import version
 from torch import nn
 from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
-from torch.utils.data.distributed import DistributedSampler
 
 from . import __version__
 from .configuration_utils import PretrainedConfig
@@ -73,7 +72,7 @@ from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
 from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
 from .optimization import Adafactor, get_scheduler
-from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11
+from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
     CallbackHandler,
@@ -85,14 +84,11 @@ from .trainer_callback import (
     TrainerState,
 )
 from .trainer_pt_utils import (
-    DistributedLengthGroupedSampler,
-    DistributedSamplerWithLoop,
     DistributedTensorGatherer,
     IterableDatasetShard,
     LabelSmoother,
     LengthGroupedSampler,
     SequentialDistributedSampler,
-    ShardSampler,
     distributed_broadcast_scalars,
     distributed_concat,
     find_batch_size,
@@ -102,7 +98,6 @@ from .trainer_pt_utils import (
     nested_concat,
     nested_detach,
     nested_numpify,
-    nested_truncate,
     nested_xla_mesh_reduce,
     reissue_pt_warnings,
 )
@@ -812,20 +807,6 @@ class Trainer:
         if self.train_dataset is None or not has_length(self.train_dataset):
             return None
 
-        generator = None
-        if self.args.world_size <= 1:
-            generator = torch.Generator()
-            # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
-            # `args.seed`) if data_seed isn't provided.
-            # Further on in this method, we default to `args.seed` instead.
-            if self.args.data_seed is None:
-                seed = int(torch.empty((), dtype=torch.int64).random_().item())
-            else:
-                seed = self.args.data_seed
-            generator.manual_seed(seed)
-
-        seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
-
         # Build the sampler.
         if self.args.group_by_length:
             if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
@@ -837,47 +818,15 @@ class Trainer:
             else:
                 lengths = None
             model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
-            if self.args.world_size <= 1:
-                return LengthGroupedSampler(
-                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
-                    dataset=self.train_dataset,
-                    lengths=lengths,
-                    model_input_name=model_input_name,
-                    generator=generator,
-                )
-            else:
-                return DistributedLengthGroupedSampler(
-                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
-                    dataset=self.train_dataset,
-                    num_replicas=self.args.world_size,
-                    rank=self.args.process_index,
-                    lengths=lengths,
-                    model_input_name=model_input_name,
-                    seed=seed,
-                )
+            return LengthGroupedSampler(
+                self.args.train_batch_size * self.args.gradient_accumulation_steps,
+                dataset=self.train_dataset,
+                lengths=lengths,
+                model_input_name=model_input_name,
+            )
 
         else:
-            if self.args.world_size <= 1:
-                return RandomSampler(self.train_dataset, generator=generator)
-            elif (
-                self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
-                and not self.args.dataloader_drop_last
-            ):
-                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
-                return DistributedSamplerWithLoop(
-                    self.train_dataset,
-                    batch_size=self.args.per_device_train_batch_size,
-                    num_replicas=self.args.world_size,
-                    rank=self.args.process_index,
-                    seed=seed,
-                )
-            else:
-                return DistributedSampler(
-                    self.train_dataset,
-                    num_replicas=self.args.world_size,
-                    rank=self.args.process_index,
-                    seed=seed,
-                )
+            return RandomSampler(self.train_dataset)
 
     def get_train_dataloader(self) -> DataLoader:
         """
@@ -898,36 +847,19 @@ class Trainer:
         else:
             data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
 
-        if isinstance(train_dataset, torch.utils.data.IterableDataset):
-            if self.args.world_size > 1:
-                train_dataset = IterableDatasetShard(
-                    train_dataset,
-                    batch_size=self._train_batch_size,
-                    drop_last=self.args.dataloader_drop_last,
-                    num_processes=self.args.world_size,
-                    process_index=self.args.process_index,
-                )
-
-            return DataLoader(
-                train_dataset,
-                batch_size=self._train_batch_size,
-                collate_fn=data_collator,
-                num_workers=self.args.dataloader_num_workers,
-                pin_memory=self.args.dataloader_pin_memory,
-            )
-
-        train_sampler = self._get_train_sampler()
-
-        return DataLoader(
-            train_dataset,
-            batch_size=self._train_batch_size,
-            sampler=train_sampler,
-            collate_fn=data_collator,
-            drop_last=self.args.dataloader_drop_last,
-            num_workers=self.args.dataloader_num_workers,
-            pin_memory=self.args.dataloader_pin_memory,
-            worker_init_fn=seed_worker,
-        )
+        dataloader_params = {
+            "batch_size": self._train_batch_size,
+            "collate_fn": data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+        }
+
+        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
+            dataloader_params["sampler"] = self._get_train_sampler()
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["worker_init_fn"] = seed_worker
+
+        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
 
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
         # Deprecated code
@@ -943,20 +875,13 @@ class Trainer:
                     rank=smp.dp_rank(),
                     batch_size=self.args.per_device_eval_batch_size,
                 )
-            elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
-                return SequentialDistributedSampler(eval_dataset)
             else:
                 return SequentialSampler(eval_dataset)
 
         if self.args.world_size <= 1:
             return SequentialSampler(eval_dataset)
         else:
-            return ShardSampler(
-                eval_dataset,
-                batch_size=self.args.per_device_eval_batch_size,
-                num_processes=self.args.world_size,
-                process_index=self.args.process_index,
-            )
+            return None
 
     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
         """
@@ -979,34 +904,18 @@ class Trainer:
         else:
             data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")
 
-        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
-            if self.args.world_size > 1:
-                eval_dataset = IterableDatasetShard(
-                    eval_dataset,
-                    batch_size=self.args.per_device_eval_batch_size,
-                    drop_last=self.args.dataloader_drop_last,
-                    num_processes=self.args.world_size,
-                    process_index=self.args.process_index,
-                )
-            return DataLoader(
-                eval_dataset,
-                batch_size=self.args.eval_batch_size,
-                collate_fn=data_collator,
-                num_workers=self.args.dataloader_num_workers,
-                pin_memory=self.args.dataloader_pin_memory,
-            )
-
-        eval_sampler = self._get_eval_sampler(eval_dataset)
-
-        return DataLoader(
-            eval_dataset,
-            sampler=eval_sampler,
-            batch_size=self.args.eval_batch_size,
-            collate_fn=data_collator,
-            drop_last=self.args.dataloader_drop_last,
-            num_workers=self.args.dataloader_num_workers,
-            pin_memory=self.args.dataloader_pin_memory,
-        )
+        dataloader_params = {
+            "batch_size": self.args.eval_batch_size,
+            "collate_fn": data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+        }
+
+        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
+            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+
+        return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
 
     def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
         """
@@ -1026,35 +935,19 @@ class Trainer:
         else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="test")
 
-        if isinstance(test_dataset, torch.utils.data.IterableDataset):
-            if self.args.world_size > 1:
-                test_dataset = IterableDatasetShard(
-                    test_dataset,
-                    batch_size=self.args.eval_batch_size,
-                    drop_last=self.args.dataloader_drop_last,
-                    num_processes=self.args.world_size,
-                    process_index=self.args.process_index,
-                )
-            return DataLoader(
-                test_dataset,
-                batch_size=self.args.eval_batch_size,
-                collate_fn=data_collator,
-                num_workers=self.args.dataloader_num_workers,
-                pin_memory=self.args.dataloader_pin_memory,
-            )
-
-        test_sampler = self._get_eval_sampler(test_dataset)
+        dataloader_params = {
+            "batch_size": self.args.eval_batch_size,
+            "collate_fn": data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+        }
+
+        if not isinstance(test_dataset, torch.utils.data.IterableDataset):
+            dataloader_params["sampler"] = self._get_eval_sampler(test_dataset)
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
 
         # We use the same batch_size as for eval.
-        return DataLoader(
-            test_dataset,
-            sampler=test_sampler,
-            batch_size=self.args.eval_batch_size,
-            collate_fn=data_collator,
-            drop_last=self.args.dataloader_drop_last,
-            num_workers=self.args.dataloader_num_workers,
-            pin_memory=self.args.dataloader_pin_memory,
-        )
+        return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))
 
     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
@@ -1864,26 +1757,11 @@ class Trainer:
         # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
         if not args.ignore_data_skip:
             for epoch in range(epochs_trained):
-                is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
-                    train_dataloader.sampler, RandomSampler
-                )
-                if is_torch_less_than_1_11 or not is_random_sampler:
-                    # We just need to begin an iteration to create the randomization of the sampler.
-                    # That was before PyTorch 1.11 however...
-                    for _ in train_dataloader:
-                        break
-                else:
-                    # Otherwise we need to call the whooooole sampler cause there is some random operation added
-                    # AT THE VERY END!
-                    _ = list(train_dataloader.sampler)
+                for _ in train_dataloader:
+                    break
 
         total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
-            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
-                train_dataloader.sampler.set_epoch(epoch)
-            elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
-                train_dataloader.dataset.set_epoch(epoch)
-
            if is_torch_tpu_available():
                 parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
                 epoch_iterator = parallel_loader
@@ -3250,27 +3128,29 @@ class Trainer:
             # Update containers on host
             if loss is not None:
-                losses = self._nested_gather(loss.repeat(batch_size))
-                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
+                losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size)))
+                losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
             if labels is not None:
-                labels = self._pad_across_processes(labels)
+                labels = self.accelerator.pad_across_processes(labels)
             if inputs_decode is not None:
-                inputs_decode = self._pad_across_processes(inputs_decode)
-                inputs_decode = self._nested_gather(inputs_decode)
+                inputs_decode = self.accelerator.pad_across_processes(inputs_decode)
+                inputs_decode = self.accelerator.gather_for_metrics((inputs_decode))
                 inputs_host = (
                     inputs_decode
                     if inputs_host is None
                     else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                 )
             if logits is not None:
-                logits = self._pad_across_processes(logits)
+                logits = self.accelerator.pad_across_processes(logits)
                 if self.preprocess_logits_for_metrics is not None:
                     logits = self.preprocess_logits_for_metrics(logits, labels)
-                logits = self._nested_gather(logits)
+                logits = self.accelerator.gather_for_metrics((logits))
                 preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
             if labels is not None:
-                labels = self._nested_gather(labels)
+                labels = self.accelerator.gather_for_metrics((labels))
                 labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
 
             self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
 
             # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
@@ -3303,19 +3183,13 @@ class Trainer:
 
         # Gather all remaining tensors and put them back on the CPU
         if losses_host is not None:
-            losses = nested_numpify(losses_host)
-            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
+            all_losses = nested_numpify(losses_host)
         if preds_host is not None:
-            logits = nested_numpify(preds_host)
-            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
+            all_preds = nested_numpify(preds_host)
         if inputs_host is not None:
-            inputs_decode = nested_numpify(inputs_host)
-            all_inputs = (
-                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
-            )
+            all_inputs = nested_numpify(inputs_host)
         if labels_host is not None:
-            labels = nested_numpify(labels_host)
-            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
+            all_labels = nested_numpify(labels_host)
 
         # Number of samples
         if has_length(eval_dataset):
@@ -3332,17 +3206,6 @@ class Trainer:
         if num_samples == 0 and observed_num_examples > 0:
             num_samples = observed_num_examples
 
-        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
-        # samplers has been rounded to a multiple of batch_size, so we truncate.
-        if all_losses is not None:
-            all_losses = all_losses[:num_samples]
-        if all_preds is not None:
-            all_preds = nested_truncate(all_preds, num_samples)
-        if all_labels is not None:
-            all_labels = nested_truncate(all_labels, num_samples)
-        if all_inputs is not None:
-            all_inputs = nested_truncate(all_inputs, num_samples)
-
         # Metrics!
         if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
             if args.include_inputs_for_metrics:
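Why the truncation block above could go, and a bridge to the test changes that follow: Accelerator.gather_for_metrics, now used in the evaluation loop, drops the samples Accelerate duplicates to even out the final batch across processes, so the gathered arrays already have the true dataset length. A small illustrative sketch of that behavior, not Trainer code (the dataset and batch size here are made up):

# Sketch: with a dataloader prepared by the Accelerator, gather_for_metrics()
# removes the duplicated samples added to even out the last batch, so no
# manual truncation to the dataset size is required afterwards.
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator

accelerator = Accelerator()
dataset = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1))
eval_dataloader = accelerator.prepare(DataLoader(dataset, batch_size=4))

gathered = []
for (batch,) in eval_dataloader:
    gathered.append(accelerator.gather_for_metrics(batch))

# Exactly len(dataset) == 10 rows on every process, even when the data was
# padded so that all processes saw the same number of batches.
all_batches = torch.cat(gathered)

The remaining hunks update the Trainer integration tests: test_sampler_seed and test_training_finite_iterable_dataset are removed as no longer needed (per the commit title), and the dataloader assertions now check total_batch_size, the effective batch size reported by a loader that has been through accelerator.prepare, instead of the raw batch_size attribute.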
@@ -798,9 +798,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
     def test_train_and_eval_dataloaders(self):
         n_gpu = max(1, torch.cuda.device_count())
         trainer = get_regression_trainer(learning_rate=0.1, per_device_train_batch_size=16)
-        self.assertEqual(trainer.get_train_dataloader().batch_size, 16 * n_gpu)
+        self.assertEqual(trainer.get_train_dataloader().total_batch_size, 16 * n_gpu)
         trainer = get_regression_trainer(learning_rate=0.1, per_device_eval_batch_size=16)
-        self.assertEqual(trainer.get_eval_dataloader().batch_size, 16 * n_gpu)
+        self.assertEqual(trainer.get_eval_dataloader().total_batch_size, 16 * n_gpu)
 
         # Check drop_last works
         trainer = get_regression_trainer(
@@ -833,67 +833,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         trainer.train()
         trainer.evaluate()
 
-    def test_sampler_seed(self):
-        # nb: we don't want to inherit from IterableDataset to hit the right code path
-        class DummyDataset(torch.utils.data.Dataset):
-            def __init__(self, length: int = 101):
-                self.length = length
-
-            def __len__(self):
-                return self.length
-
-            def __getitem__(self, i):
-                if (i < 0) or (i >= self.length):
-                    raise IndexError
-                return {"input_ids": [i]}
-
-        class DummyModel(PreTrainedModel):
-            def __init__(self, num_params: int):
-                super().__init__(PretrainedConfig())
-                # Add some (unused) params. the point here is that randomness in model_init shouldn't influence
-                # data loader order.
-                self.params = nn.Parameter(torch.randn(num_params))
-
-            def forward(self, input_ids, labels=None):
-                if labels is not None:
-                    return torch.tensor(0.0, device=input_ids.device), input_ids
-                else:
-                    return input_ids
-
-        def _get_first_data_sample(num_params, seed, data_seed, **kwargs):
-            with tempfile.TemporaryDirectory() as tmpdir:
-                trainer = Trainer(
-                    model_init=lambda: DummyModel(num_params),
-                    args=TrainingArguments(
-                        output_dir=tmpdir,
-                        **kwargs,
-                        seed=seed,
-                        data_seed=data_seed,
-                        local_rank=-1,
-                    ),
-                    train_dataset=DummyDataset(),
-                )
-
-                return next(iter(trainer.get_train_dataloader()))
-
-        # test that the seed is passed to the sampler
-        # the codepath we want to hit is world_size <= 1, and both group_by_length
-        for group_by_length in [True, False]:
-            sample42_1 = _get_first_data_sample(num_params=10, seed=42, data_seed=42, group_by_length=group_by_length)
-            sample42_2 = _get_first_data_sample(num_params=11, seed=42, data_seed=42, group_by_length=group_by_length)
-            self.assertTrue(torch.equal(sample42_1["input_ids"], sample42_2["input_ids"]))
-
-            # should get same samples with different seed, so long as data_seed is the same
-            sample42_3 = _get_first_data_sample(num_params=11, seed=11, data_seed=42, group_by_length=group_by_length)
-            self.assertTrue(torch.equal(sample42_1["input_ids"], sample42_3["input_ids"]))
-
-            # make sure we have some randomness in the samples if data_seed is different
-            others = [
-                _get_first_data_sample(num_params=i, seed=42, data_seed=i, group_by_length=group_by_length)
-                for i in range(10)
-            ]
-            self.assertTrue(any(not torch.equal(sample42_1["input_ids"], sample["input_ids"]) for sample in others))
-
     @require_torch_multi_gpu
     def test_data_is_not_parallelized_when_model_is_parallel(self):
         model = RegressionModel()
@@ -907,9 +846,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertEqual(trainer.args.n_gpu, 1)
 
         # The batch size of the training and evaluation dataloaders should be 16, not 16 * n_gpu
-        self.assertEqual(trainer.get_train_dataloader().batch_size, 16)
+        self.assertEqual(trainer.get_train_dataloader().total_batch_size, 16)
         self.assertEqual(len(trainer.get_train_dataloader()), 64 // 16)
-        self.assertEqual(trainer.get_eval_dataloader().batch_size, 16)
+        self.assertEqual(trainer.get_eval_dataloader().total_batch_size, 16)
         self.assertEqual(len(trainer.get_eval_dataloader()), 64 // 16)
 
     def test_evaluate(self):
@@ -1742,26 +1681,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertIsInstance(loader, torch.utils.data.DataLoader)
         self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
 
-    def test_training_finite_iterable_dataset(self):
-        config = RegressionModelConfig()
-        model = RegressionPreTrainedModel(config)
-
-        batch_size = 1
-        num_samples = 10
-
-        available_steps = num_samples // batch_size
-
-        data = FiniteIterableDataset(length=num_samples)
-        train_args = TrainingArguments(
-            "..",
-            max_steps=available_steps + 1,  # set a higher number than actually available
-            per_device_train_batch_size=batch_size,
-        )
-        trainer = Trainer(model, train_dataset=data, args=train_args)
-        with self.assertLogs("transformers.trainer", level="WARNING") as logs:
-            trainer.train()
-        self.assertIn(f"stopping training at step {available_steps}!", logs.output[0])
-
     def test_evaluation_iterable_dataset(self):
         config = RegressionModelConfig(a=1.5, b=2.5)
         model = RegressionPreTrainedModel(config)