Unverified Commit 34791613 authored by Steven Madere's avatar Steven Madere Committed by GitHub
Browse files

Fix type hint for train_dataset param of Trainer.__init__() to allow...

Fix type hint for train_dataset param of Trainer.__init__() to allow IterableDataset.  Issue 29678 (#29738)

* Fixed typehint for train_dataset param in Trainer.__init__().  Added IterableDataset option.

* make fixup
parent e68ff304
......@@ -52,7 +52,7 @@ import torch.distributed as dist
from huggingface_hub import ModelCard, create_repo, upload_folder
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
from . import __version__
from .configuration_utils import PretrainedConfig
......@@ -353,7 +353,7 @@ class Trainer:
model: Union[PreTrainedModel, nn.Module] = None,
args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment