"...resnet50_tensorflow.git" did not exist on "ec7265be6f1e5708747dcd89c87bb6aea3aec9b9"
Unverified Commit 5c4eb4b1 authored by Lysandre Debut, committed by GitHub

Fixing FLOPS merge by checking if torch is available (#7013)



* Should check if `torch` is available

* fixed samples_count error, distributed_concat arguments

* style

* Import torch at beginning of file
Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
parent 01d340ad
@@ -1315,8 +1315,6 @@ class Trainer:
                 label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)
             if eval_losses is not None:
                 eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()
-            if samples_count is not None:
-                samples_count = sum(xm.mesh_reduce("samples_count", torch.tensor([samples_count]), torch.cat).tolist())
 
         # Finally, turn the aggregated tensors into numpy arrays.
         if preds is not None:
...
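Note: the surviving `xm.mesh_reduce` calls in this hunk handle TPU-side aggregation during evaluation; the removed lines are the `samples_count` aggregation that the commit message flags as erroneous. A minimal sketch of that aggregation pattern, assuming `torch_xla` is installed and the code runs inside a TPU worker process (the helper name `gather_eval_losses` is illustrative, not part of the library):

```python
import torch
import torch_xla.core.xla_model as xm  # assumes torch_xla is installed


def gather_eval_losses(eval_losses):
    # Each TPU core passes its own list of per-batch losses; mesh_reduce
    # collects the tensors from every core, torch.cat concatenates them,
    # and .tolist() turns the result back into plain Python floats.
    return xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()
```

The remaining hunks below are in the trainer utilities module, where `set_seed`, `default_hp_space`, and the two distributed helpers are defined.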
@@ -2,12 +2,15 @@ import random
 from typing import Any, Dict, List, NamedTuple, Optional, Union
 
 import numpy as np
-import torch
 
 from .file_utils import is_tf_available, is_torch_available
 from .tokenization_utils_base import ExplicitEnum
 
 
+if is_torch_available():
+    import torch
+
+
 def set_seed(seed: int):
     """
     Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf``
...
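The conditional import above is the usual optional-dependency guard: `torch` is only imported when `is_torch_available()` reports that it is installed, so importing this module no longer hard-requires PyTorch. A minimal sketch of how such a guard can be implemented (the real check lives in `.file_utils` and may be implemented differently):

```python
# Sketch of an is_torch_available()-style guard, assuming importlib-based
# detection; the actual implementation in transformers.file_utils may differ.
import importlib.util

_torch_available = importlib.util.find_spec("torch") is not None


def is_torch_available() -> bool:
    return _torch_available


if is_torch_available():
    import torch  # only imported when the package actually exists
```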
@@ -129,9 +132,9 @@ default_hp_space = {
 }
 
 
-def distributed_concat(self, tensor: torch.Tensor, num_total_examples: Optional[int] = None) -> torch.Tensor:
-    assert self.args.local_rank != -1
-
-    output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
-    torch.distributed.all_gather(output_tensors, tensor)
-    concat = torch.cat(output_tensors, dim=0)
+def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> "torch.Tensor":
+    if is_torch_available():
+        try:
+            output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
+            torch.distributed.all_gather(output_tensors, tensor)
+            concat = torch.cat(output_tensors, dim=0)
@@ -140,13 +143,17 @@ def distributed_concat(self, tensor: torch.Tensor, num_total_examples: Optional[
-    if num_total_examples is not None:
-        concat = concat[:num_total_examples]
-    return concat
+            if num_total_examples is not None:
+                concat = concat[:num_total_examples]
+            return concat
+        except AssertionError:
+            raise AssertionError("Not currently using distributed training")
+    else:
+        raise ImportError("Torch must be installed to use `distributed_concat`")
 
 
 def distributed_broadcast_scalars(
-    self, scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
-) -> torch.Tensor:
-    assert self.args.local_rank != -1
-
-    tensorized_scalar = torch.Tensor(scalars).cuda()
-    output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
-    torch.distributed.all_gather(output_tensors, tensorized_scalar)
+    scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
+) -> "torch.Tensor":
+    if is_torch_available():
+        try:
+            tensorized_scalar = torch.Tensor(scalars).cuda()
+            output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
+            torch.distributed.all_gather(output_tensors, tensorized_scalar)
@@ -156,3 +163,7 @@ def distributed_broadcast_scalars(
-    if num_total_examples is not None:
-        concat = concat[:num_total_examples]
-    return concat
+            if num_total_examples is not None:
+                concat = concat[:num_total_examples]
+            return concat
+        except AssertionError:
+            raise AssertionError("Not currently using distributed training")
+    else:
+        raise ImportError("Torch must be installed to use `distributed_broadcast_scalars`")