Unverified Commit 5c4eb4b1 authored by Lysandre Debut, committed by GitHub

Fixing FLOPS merge by checking if torch is available (#7013)



* Should check if `torch` is available

* fixed samples_count error, distributed_concat arguments

* style

* Import torch at beginning of file
Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
parent 01d340ad
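
The fix boils down to the lazy-import pattern visible in the hunks below: probe for torch once, import it only if it is actually installed, and raise a clear ImportError the moment a torch-only helper is called without it. Here is a minimal, self-contained sketch of that pattern; `torch_is_installed` and `concat_across_ranks` are illustrative stand-ins for the library's own `is_torch_available` and `distributed_concat`, not code from this commit.

import importlib.util
from typing import Optional


def torch_is_installed() -> bool:
    # Stand-in for transformers' is_torch_available(): True if torch can be imported.
    return importlib.util.find_spec("torch") is not None


if torch_is_installed():
    import torch


def concat_across_ranks(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> "torch.Tensor":
    # String annotations keep this module importable even when torch is absent.
    if not torch_is_installed():
        raise ImportError("Torch must be installed to use `concat_across_ranks`")
    output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(output_tensors, tensor)
    concat = torch.cat(output_tensors, dim=0)
    return concat if num_total_examples is None else concat[:num_total_examples]
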
src/transformers/trainer.py

@@ -1315,8 +1315,6 @@ class Trainer:
                 label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)
             if eval_losses is not None:
                 eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()
-            if samples_count is not None:
-                samples_count = sum(xm.mesh_reduce("samples_count", torch.tensor([samples_count]), torch.cat).tolist())

         # Finally, turn the aggregated tensors into numpy arrays.
         if preds is not None:
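
For context on the primitive this hunk relies on: `xm.mesh_reduce(tag, data, reduce_fn)` from torch_xla gathers one payload per TPU core and applies the reduce function to the gathered list. The sketch below mirrors the removed `samples_count` line to show how a per-shard scalar gets globally summed that way; it assumes a TPU runtime with torch_xla installed, and the variable names are illustrative.

import torch
import torch_xla.core.xla_model as xm  # TPU-only dependency

# Number of examples this shard actually processed (illustrative value).
samples_count = 128

# mesh_reduce hands each core's 1-element tensor to torch.cat; summing the
# concatenated values gives the global sample count across all cores.
global_samples = sum(xm.mesh_reduce("samples_count", torch.tensor([samples_count]), torch.cat).tolist())
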
src/transformers/trainer_utils.py

@@ -2,12 +2,15 @@ import random
 from typing import Any, Dict, List, NamedTuple, Optional, Union

 import numpy as np
-import torch

 from .file_utils import is_tf_available, is_torch_available
 from .tokenization_utils_base import ExplicitEnum


+if is_torch_available():
+    import torch
+
+
 def set_seed(seed: int):
     """
     Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf``
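
The docstring above belongs to `set_seed`, whose body lies outside this hunk; per the docstring it seeds `random`, `numpy`, `torch` and/or `tf`, gated by the same availability checks. A hedged, self-contained sketch of a helper along those lines (not the library's verbatim implementation):

import random

import numpy as np


def seed_everything(seed: int):
    # Framework-agnostic generators are seeded unconditionally...
    random.seed(seed)
    np.random.seed(seed)
    # ...framework-specific ones only when the framework is importable.
    try:
        import torch

        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # safe to call even when no GPU is present
    except ImportError:
        pass
    try:
        import tensorflow as tf

        tf.random.set_seed(seed)
    except ImportError:
        pass
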
@@ -129,9 +132,9 @@ default_hp_space = {
 }


-def distributed_concat(self, tensor: torch.Tensor, num_total_examples: Optional[int] = None) -> torch.Tensor:
-    assert self.args.local_rank != -1
-    output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
-    torch.distributed.all_gather(output_tensors, tensor)
-    concat = torch.cat(output_tensors, dim=0)
+def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> "torch.Tensor":
+    if is_torch_available():
+        try:
+            output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
+            torch.distributed.all_gather(output_tensors, tensor)
+            concat = torch.cat(output_tensors, dim=0)
@@ -140,13 +143,17 @@ def distributed_concat(self, tensor: torch.Tensor, num_total_examples: Optional[
-    if num_total_examples is not None:
-        concat = concat[:num_total_examples]
-    return concat
+            if num_total_examples is not None:
+                concat = concat[:num_total_examples]
+            return concat
+        except AssertionError:
+            raise AssertionError("Not currently using distributed training")
+    else:
+        raise ImportError("Torch must be installed to use `distributed_concat`")


 def distributed_broadcast_scalars(
-    self, scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
-) -> torch.Tensor:
-    assert self.args.local_rank != -1
-    tensorized_scalar = torch.Tensor(scalars).cuda()
-    output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
-    torch.distributed.all_gather(output_tensors, tensorized_scalar)
+    scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
+) -> "torch.Tensor":
+    if is_torch_available():
+        try:
+            tensorized_scalar = torch.Tensor(scalars).cuda()
+            output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
+            torch.distributed.all_gather(output_tensors, tensorized_scalar)
@@ -156,3 +163,7 @@ def distributed_broadcast_scalars(
-    if num_total_examples is not None:
-        concat = concat[:num_total_examples]
-    return concat
+            if num_total_examples is not None:
+                concat = concat[:num_total_examples]
+            return concat
+        except AssertionError:
+            raise AssertionError("Not currently using distributed training")
+    else:
+        raise ImportError("Torch must be installed to use `distributed_broadcast_scalars`")