Commit af4c0262 (unverified), authored by Matthew Hoffman, committed by GitHub

Add datasets.Dataset to Trainer's train_dataset and eval_dataset type hints (#30077)

* Add datasets.Dataset to Trainer's train_dataset and eval_dataset type hints

* Add is_datasets_available check for importing datasets under TYPE_CHECKING guard

https://github.com/huggingface/transformers/pull/30077/files#r1555939352
parent 4e3490f7
@@ -250,6 +250,8 @@ def _get_fsdp_ckpt_kwargs():
 if TYPE_CHECKING:
     import optuna
+    if is_datasets_available():
+        import datasets
 logger = logging.get_logger(__name__)
@@ -287,7 +289,7 @@ class Trainer:
             The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
             default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
             [`DataCollatorWithPadding`] otherwise.
-        train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
+        train_dataset (Union[`torch.utils.data.Dataset`, `torch.utils.data.IterableDataset`, `datasets.Dataset`], *optional*):
             The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
             `model.forward()` method are automatically removed.
@@ -296,7 +298,7 @@ class Trainer:
             `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
             manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
             sets the seed of the RNGs used.
-        eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*):
+        eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`, `datasets.Dataset`]), *optional*):
              The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
              `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
              dataset prepending the dictionary key to the metric name.
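The docstrings above say that columns not accepted by `model.forward()` are removed automatically. A minimal sketch of that idea, using only `inspect` from the standard library, might look like the following. The names `toy_forward` and `keep_accepted_columns` are hypothetical and are not part of `transformers`; the actual Trainer logic may treat some columns (for example label columns) specially.

```python
import inspect
from typing import Callable, List


def toy_forward(input_ids, attention_mask=None, labels=None):
    """Hypothetical stand-in for a model's forward() method."""
    return None


def keep_accepted_columns(forward_fn: Callable, column_names: List[str]) -> List[str]:
    """Return only the columns whose names match a parameter of forward_fn.

    This is an illustrative assumption about how signature-based filtering
    can work, not the library's implementation.
    """
    accepted = set(inspect.signature(forward_fn).parameters)
    return [name for name in column_names if name in accepted]


# Columns such as "text" that forward() does not accept are dropped.
kept = keep_accepted_columns(toy_forward, ["input_ids", "text", "labels"])
```

Filtering by signature keeps the dataset schema decoupled from the model: extra columns can stay in the raw dataset without breaking the forward call.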
@@ -358,8 +360,8 @@ class Trainer:
         model: Union[PreTrainedModel, nn.Module] = None,
         args: TrainingArguments = None,
         data_collator: Optional[DataCollator] = None,
-        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
-        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
+        train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
+        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
         model_init: Optional[Callable[[], PreTrainedModel]] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
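The pattern this commit relies on, importing an optional package only under `TYPE_CHECKING` and referring to it through a quoted annotation, can be sketched in isolation as follows. `is_pkg_available` and `MiniTrainer` are hypothetical stand-ins for `is_datasets_available` and `Trainer`; the availability check shown here is an assumption and may differ from the one in `transformers`.

```python
import importlib.util
from typing import TYPE_CHECKING, List, Optional, Union


def is_pkg_available(name: str) -> bool:
    """Hypothetical helper: True if the package can be found, without importing it."""
    return importlib.util.find_spec(name) is not None


if TYPE_CHECKING:
    # This block is seen by static type checkers but never runs at runtime,
    # so the optional dependency is not imported when the module is loaded.
    if is_pkg_available("datasets"):
        import datasets


class MiniTrainer:
    """Hypothetical, heavily reduced stand-in for Trainer."""

    def __init__(
        self,
        # The quoted name is a forward reference, so it is not resolved at runtime.
        train_dataset: Optional[Union[List[int], "datasets.Dataset"]] = None,
    ):
        self.train_dataset = train_dataset


trainer = MiniTrainer(train_dataset=[1, 2, 3])
```

Because the annotation is a string, the class definition succeeds even when `datasets` is not installed, while a type checker can still resolve `datasets.Dataset` when the package is present.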