Unverified Commit 53155b52 authored by Joao Gante, committed by GitHub

Trainer: move Seq2SeqTrainer imports under the typing guard (#22401)

parent 0e708178
@@ -14,39 +14,42 @@
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from torch import nn
 from torch.utils.data import Dataset

-from .data.data_collator import DataCollator
 from .deepspeed import is_deepspeed_zero3_enabled
 from .generation.configuration_utils import GenerationConfig
-from .modeling_utils import PreTrainedModel
-from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer import Trainer
-from .trainer_callback import TrainerCallback
-from .trainer_utils import EvalPrediction, PredictionOutput
-from .training_args import TrainingArguments
 from .utils import logging


+if TYPE_CHECKING:
+    from .data.data_collator import DataCollator
+    from .modeling_utils import PreTrainedModel
+    from .tokenization_utils_base import PreTrainedTokenizerBase
+    from .trainer_callback import TrainerCallback
+    from .trainer_utils import EvalPrediction, PredictionOutput
+    from .training_args import TrainingArguments
+
+
 logger = logging.get_logger(__name__)


 class Seq2SeqTrainer(Trainer):
     def __init__(
         self,
-        model: Union[PreTrainedModel, nn.Module] = None,
-        args: TrainingArguments = None,
-        data_collator: Optional[DataCollator] = None,
+        model: Union["PreTrainedModel", nn.Module] = None,
+        args: "TrainingArguments" = None,
+        data_collator: Optional["DataCollator"] = None,
         train_dataset: Optional[Dataset] = None,
         eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
-        tokenizer: Optional[PreTrainedTokenizerBase] = None,
-        model_init: Optional[Callable[[], PreTrainedModel]] = None,
-        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
-        callbacks: Optional[List[TrainerCallback]] = None,
+        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
+        model_init: Optional[Callable[[], "PreTrainedModel"]] = None,
+        compute_metrics: Optional[Callable[["EvalPrediction"], Dict]] = None,
+        callbacks: Optional[List["TrainerCallback"]] = None,
         optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
     ):
@@ -161,7 +164,7 @@ class Seq2SeqTrainer(Trainer):
         ignore_keys: Optional[List[str]] = None,
         metric_key_prefix: str = "test",
         **gen_kwargs,
-    ) -> PredictionOutput:
+    ) -> "PredictionOutput":
         """
         Run prediction and returns predictions and potential metrics.
...
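For readers unfamiliar with the pattern this commit applies: `typing.TYPE_CHECKING` is a constant that is False at runtime but treated as True by static type checkers, so imports placed under the guard never execute when the module loads. The quoted annotations in the diff ("PreTrainedModel", "TrainingArguments", and so on) are forward references that checkers resolve against those guarded imports. Below is a minimal, self-contained sketch of the same idiom; the function name and the `decimal` stand-in for a "heavy" dependency are illustrative, not part of transformers.

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy/pyright treat
    # TYPE_CHECKING as True); skipped entirely at runtime.
    from decimal import Decimal

def describe(value: Optional["Decimal"] = None) -> str:
    # The quoted annotation is a forward reference, so Python never
    # evaluates the name Decimal while defining this function.
    return "empty" if value is None else f"value = {value}"

if __name__ == "__main__":
    print(describe())  # runs even though decimal was never imported here

The payoff in a large package is cheaper module imports and fewer circular-import hazards, since names needed only in annotations stay out of the runtime import graph. On Python 3.7+, `from __future__ import annotations` (PEP 563) achieves the same effect without manual quoting, because all annotations in the module are then stored as strings.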