Unverified Commit af4c0262 authored by Matthew Hoffman's avatar Matthew Hoffman Committed by GitHub
Browse files

Add datasets.Dataset to Trainer's train_dataset and eval_dataset type hints (#30077)

* Add datasets.Dataset to Trainer's train_dataset and eval_dataset type hints

* Add is_datasets_available check for importing datasets under TYPE_CHECKING guard

https://github.com/huggingface/transformers/pull/30077/files#r1555939352
parent 4e3490f7
......@@ -250,6 +250,8 @@ def _get_fsdp_ckpt_kwargs():
if TYPE_CHECKING:
import optuna
if is_datasets_available():
import datasets
logger = logging.get_logger(__name__)
......@@ -287,7 +289,7 @@ class Trainer:
The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
[`DataCollatorWithPadding`] otherwise.
train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
train_dataset (Union[`torch.utils.data.Dataset`, `torch.utils.data.IterableDataset`, `datasets.Dataset`], *optional*):
The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed.
......@@ -296,7 +298,7 @@ class Trainer:
`torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
sets the seed of the RNGs used.
eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*):
eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`, `datasets.Dataset`]), *optional*):
The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
dataset prepending the dictionary key to the metric name.
......@@ -358,8 +360,8 @@ class Trainer:
model: Union[PreTrainedModel, nn.Module] = None,
args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment