Unverified Commit 3427a039 authored by msbaines's avatar msbaines Committed by GitHub
Browse files

[cleanup] get 100% coverage on oss.py (#38)


authored-by: default avatarMandeep Singh Baines <msb@fb.com>
parent fffd3c76
......@@ -13,7 +13,7 @@ from torch.optim import SGD, Optimizer
from .utils import broadcast_object, recursive_copy_to_device
if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: no cover
from torch.optim.optimizer import _params_t
else:
_params_t = Any
......
......@@ -53,7 +53,7 @@ def broadcast_object(
if dist.get_rank() == src_rank:
# Emit data
buffer = io.BytesIO()
torch.save(obj, buffer) # type: ignore
torch.save(obj, buffer)
data = bytearray(buffer.getbuffer())
length_tensor = torch.LongTensor([len(data)]).to(dist_device)
data_send_tensor = torch.ByteTensor(data).to(dist_device)
......@@ -66,5 +66,5 @@ def broadcast_object(
data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=dist_device)
dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
obj = torch.load(buffer, map_location=dist_device) # type: ignore
obj = torch.load(buffer, map_location=dist_device)
return obj
from typing import Any, BinaryIO, Union
def save(obj, f: Union[str, BinaryIO]) -> None: ...
def load(f: Union[str, BinaryIO], map_location) -> Any: ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment