from typing import Any, Dict, Union

import torch
from torch.utils.data.dataloader import DataLoader

from transformers import Trainer


class FunsdTrainer(Trainer):
    """``Trainer`` subclass with a custom input-preparation step.

    Values are moved to the training device based on a duck-typed check
    (the value has both a ``to`` and a ``device`` attribute) rather than on
    an explicit type check.
    """

    def _prepare_inputs(
        self, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> Dict[str, Union[torch.Tensor, Any]]:
        """Move device-aware values of ``inputs`` onto ``self.args.device``.

        Every value exposing both ``to`` and ``device`` is replaced by the
        result of ``value.to(self.args.device)``; all other values are left
        untouched. When ``self.args.past_index`` is non-negative and
        ``self._past`` is set, ``self._past`` is stored under the ``"mems"``
        key.

        Args:
            inputs: Mapping of model input names to values. It is modified
                in place.

        Returns:
            The same ``inputs`` mapping object, after the updates above.
        """
        for name, value in inputs.items():
            # Duck typing: anything that can be moved and reports a device
            # qualifies, not only torch.Tensor instances. Rebinding an
            # existing key does not change the dict size, so this is safe
            # while iterating.
            device_aware = hasattr(value, "to") and hasattr(value, "device")
            if not device_aware:
                continue
            inputs[name] = value.to(self.args.device)

        # NOTE(review): ``_past`` is presumably maintained by the base Trainer
        # for models that return memory states (see ``past_index``); confirm.
        if self.args.past_index >= 0:
            if self._past is not None:
                inputs["mems"] = self._past

        return inputs