Unverified Commit 9022ef02 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Only put tensors on a device (#5223)

* Only put tensors on a device

* Type hint and unpack list comprehension
parent 173528e3
......@@ -7,7 +7,7 @@ import shutil
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
......@@ -570,10 +570,11 @@ class Trainer:
logger.info(output)
def _training_step(
self, model: nn.Module, inputs: Dict[str, torch.Tensor], optimizer: torch.optim.Optimizer
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
) -> float:
model.train()
for k, v in inputs.items():
if isinstance(v, torch.Tensor):
inputs[k] = v.to(self.args.device)
outputs = model(**inputs)
......@@ -758,6 +759,7 @@ class Trainer:
has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
for k, v in inputs.items():
if isinstance(v, torch.Tensor):
inputs[k] = v.to(self.args.device)
with torch.no_grad():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment