Unverified Commit 3427a039 authored by msbaines's avatar msbaines Committed by GitHub
Browse files

[cleanup] get 100% coverage on oss.py (#38)


authored-by: default avatarMandeep Singh Baines <msb@fb.com>
parent fffd3c76
...@@ -13,7 +13,7 @@ from torch.optim import SGD, Optimizer ...@@ -13,7 +13,7 @@ from torch.optim import SGD, Optimizer
from .utils import broadcast_object, recursive_copy_to_device from .utils import broadcast_object, recursive_copy_to_device
if TYPE_CHECKING: if TYPE_CHECKING: # pragma: no cover
from torch.optim.optimizer import _params_t from torch.optim.optimizer import _params_t
else: else:
_params_t = Any _params_t = Any
......
...@@ -53,7 +53,7 @@ def broadcast_object( ...@@ -53,7 +53,7 @@ def broadcast_object(
if dist.get_rank() == src_rank: if dist.get_rank() == src_rank:
# Emit data # Emit data
buffer = io.BytesIO() buffer = io.BytesIO()
torch.save(obj, buffer) # type: ignore torch.save(obj, buffer)
data = bytearray(buffer.getbuffer()) data = bytearray(buffer.getbuffer())
length_tensor = torch.LongTensor([len(data)]).to(dist_device) length_tensor = torch.LongTensor([len(data)]).to(dist_device)
data_send_tensor = torch.ByteTensor(data).to(dist_device) data_send_tensor = torch.ByteTensor(data).to(dist_device)
...@@ -66,5 +66,5 @@ def broadcast_object( ...@@ -66,5 +66,5 @@ def broadcast_object(
data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=dist_device) data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=dist_device)
dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False) dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
buffer = io.BytesIO(data_recv_tensor.cpu().numpy()) buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
obj = torch.load(buffer, map_location=dist_device) # type: ignore obj = torch.load(buffer, map_location=dist_device)
return obj return obj
from typing import Any, BinaryIO, Union
def save(obj, f: Union[str, BinaryIO]) -> None: ...
def load(f: Union[str, BinaryIO], map_location) -> Any: ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment