Unverified Commit 3ac040bc authored by Julian Mack's avatar Julian Mack Committed by GitHub
Browse files

Updated Trainer args typing (#20655)

parent 3994c045
...@@ -298,13 +298,13 @@ class Trainer: ...@@ -298,13 +298,13 @@ class Trainer:
args: TrainingArguments = None, args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None, data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None, train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Dataset] = None, eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Callable[[], PreTrainedModel] = None, model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None, callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None, preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
): ):
if args is None: if args is None:
output_dir = "tmp_trainer" output_dir = "tmp_trainer"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment