Unverified commit 9022ef02 authored by Sylvain Gugger, committed by GitHub

Only put tensors on a device (#5223)

* Only put tensors on a device

* Type hint and unpack list comprehension
parent 173528e3
@@ -7,7 +7,7 @@ import shutil
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -570,11 +570,12 @@ class Trainer:
             logger.info(output)
 
     def _training_step(
-        self, model: nn.Module, inputs: Dict[str, torch.Tensor], optimizer: torch.optim.Optimizer
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
     ) -> float:
         model.train()
         for k, v in inputs.items():
-            inputs[k] = v.to(self.args.device)
+            if isinstance(v, torch.Tensor):
+                inputs[k] = v.to(self.args.device)
 
         outputs = model(**inputs)
         loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
@@ -758,7 +759,8 @@ class Trainer:
         has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
 
         for k, v in inputs.items():
-            inputs[k] = v.to(self.args.device)
+            if isinstance(v, torch.Tensor):
+                inputs[k] = v.to(self.args.device)
 
         with torch.no_grad():
             outputs = model(**inputs)
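The change above only calls `.to(device)` on values that are actually `torch.Tensor` instances, so non-tensor entries in the inputs dict (for example lists of strings or other plain Python objects) pass through untouched instead of raising an `AttributeError`. A minimal standalone sketch of the same pattern is shown below; the `prepare_inputs` helper name and the example batch are illustrative and are not part of this commit.

```python
from typing import Any, Dict, Union

import torch


def prepare_inputs(
    inputs: Dict[str, Union[torch.Tensor, Any]], device: torch.device
) -> Dict[str, Union[torch.Tensor, Any]]:
    """Move only tensor values to `device`; leave non-tensor entries unchanged."""
    return {
        k: v.to(device) if isinstance(v, torch.Tensor) else v
        for k, v in inputs.items()
    }


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch = {
        "input_ids": torch.tensor([[101, 2023, 102]]),
        "attention_mask": torch.tensor([[1, 1, 1]]),
        "example_ids": ["doc-0"],  # non-tensor metadata is passed through as-is
    }
    batch = prepare_inputs(batch, device)
    print({k: type(v).__name__ for k, v in batch.items()})
```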