Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
3427a039
Unverified
Commit
3427a039
authored
Aug 13, 2020
by
msbaines
Committed by
GitHub
Aug 13, 2020
Browse files
[cleanup] get 100% coverage on oss.py (#38)
authored-by:
Mandeep Singh Baines
<
msb@fb.com
>
parent
fffd3c76
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
8 additions
and
3 deletions
+8
-3
fairscale/optim/oss.py
fairscale/optim/oss.py
+1
-1
fairscale/optim/utils.py
fairscale/optim/utils.py
+2
-2
stubs/torch/serialization.pyi
stubs/torch/serialization.pyi
+5
-0
No files found.
fairscale/optim/oss.py
View file @
3427a039
...
...
@@ -13,7 +13,7 @@ from torch.optim import SGD, Optimizer
from
.utils
import
broadcast_object
,
recursive_copy_to_device
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
# pragma: no cover
from
torch.optim.optimizer
import
_params_t
else
:
_params_t
=
Any
...
...
fairscale/optim/utils.py
View file @
3427a039
...
...
@@ -53,7 +53,7 @@ def broadcast_object(
if
dist
.
get_rank
()
==
src_rank
:
# Emit data
buffer
=
io
.
BytesIO
()
torch
.
save
(
obj
,
buffer
)
# type: ignore
torch
.
save
(
obj
,
buffer
)
data
=
bytearray
(
buffer
.
getbuffer
())
length_tensor
=
torch
.
LongTensor
([
len
(
data
)]).
to
(
dist_device
)
data_send_tensor
=
torch
.
ByteTensor
(
data
).
to
(
dist_device
)
...
...
@@ -66,5 +66,5 @@ def broadcast_object(
data_recv_tensor
=
torch
.
empty
([
int
(
length_tensor
.
item
())],
dtype
=
torch
.
uint8
,
device
=
dist_device
)
dist
.
broadcast
(
data_recv_tensor
,
src
=
src_rank
,
group
=
group
,
async_op
=
False
)
buffer
=
io
.
BytesIO
(
data_recv_tensor
.
cpu
().
numpy
())
obj
=
torch
.
load
(
buffer
,
map_location
=
dist_device
)
# type: ignore
obj
=
torch
.
load
(
buffer
,
map_location
=
dist_device
)
return
obj
stubs/torch/serialization.pyi
0 → 100644
View file @
3427a039
from typing import Any, BinaryIO, Union
def save(obj, f: Union[str, BinaryIO]) -> None: ...
def load(f: Union[str, BinaryIO], map_location) -> Any: ...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment