Unverified Commit 3ccc2701 authored by Xiaohan Zou, committed by GitHub

Fix type annotations for `distributed_concat()` (#13746)

* Fix type annotations for `distributed_concat()`

* Use Any
parent e0d31a89
@@ -25,7 +25,7 @@ import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
 from logging import StreamHandler
-from typing import Dict, Iterator, List, Optional, Union
+from typing import Any, Dict, Iterator, List, Optional, Union
 import numpy as np
 import torch
@@ -157,7 +157,7 @@ def nested_xla_mesh_reduce(tensors, name):
         raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
-def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> torch.Tensor:
+def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) -> Any:
     try:
         if isinstance(tensor, (tuple, list)):
             return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
...
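For context, here is a minimal runnable sketch of the annotated function. Only the tuple/list branch is visible in the hunk above, so the `all_gather`/truncation logic below is an assumption about the surrounding code, not part of this commit. The sketch shows why `Any` is the appropriate annotation: the function mirrors its input's container type, so a tuple or list of tensors comes back as a tuple or list, not as a single `torch.Tensor`.

```python
# Hedged sketch of distributed_concat(); details beyond the visible hunk
# (all_gather, truncation, error handling) are assumptions, not diff content.
from typing import Any, Optional

import torch
import torch.distributed as dist


def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) -> Any:
    try:
        # Recurse into containers, preserving their type: a tuple in yields a
        # tuple out, a list yields a list -- hence `Any`, not `torch.Tensor`.
        if isinstance(tensor, (tuple, list)):
            return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
        # Gather this rank's tensor from every process and concatenate.
        output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
        dist.all_gather(output_tensors, tensor)
        concat = torch.cat(output_tensors, dim=0)
        # Truncate any dummy samples a distributed sampler padded in to make
        # batches evenly divisible across processes.
        if num_total_examples is not None:
            concat = concat[:num_total_examples]
        return concat
    except AssertionError:
        raise AssertionError("Not currently using distributed training")
```

With this shape, a call like `distributed_concat((logits, labels))` returns a tuple, which the old `-> torch.Tensor` annotation mis-described; `Any` covers both the tensor and container cases.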