fairscale · commit 65ca68a9 (unverified)
Authored by Siddharth Goyal on Jan 27, 2021; committed by GitHub on Jan 27, 2021

[fix] examples: fix naming style of helper functions (#334)

Parent: 73221557
Showing 5 changed files with 19 additions and 19 deletions.
examples/helpers.py                       +3 −3
examples/tutorial_oss.py                  +4 −4
examples/tutorial_pipe.py                 +4 −4
examples/tutorial_pipe_multiprocess.py    +4 −4
examples/tutorial_pipe_rpc.py             +4 −4
examples/helpers.py

@@ -10,13 +10,13 @@ def dist_init(rank, world_size):
     dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)
 
 
-def getModel():
+def get_model():
     return nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 5))
 
 
-def getData(n_batches=1):
+def get_data(n_batches=1):
     return [(torch.randn(20, 10), torch.randint(0, 2, size=(20, 1)).squeeze()) for i in range(n_batches)]
 
 
-def getLossFun():
+def get_loss_fun():
     return F.nll_loss
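For reference, here is a minimal sketch (not part of this commit) of how the renamed helpers fit together in an ordinary training step. Only get_model, get_data, and get_loss_fun come from examples/helpers.py; the learning rate, batch count, and the explicit log_softmax are illustrative assumptions.

# Illustrative sketch only; assumes examples/helpers.py (post-rename) is importable.
import torch
from helpers import get_data, get_loss_fun, get_model

model = get_model()        # Sequential(Linear(10, 10), ReLU(), Linear(10, 5))
loss_fn = get_loss_fun()   # F.nll_loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # lr is an arbitrary choice

for data, target in get_data(n_batches=4):  # (20, 10) inputs, (20,) class targets
    optimizer.zero_grad()
    # nll_loss expects log-probabilities, so take log_softmax of the raw outputs.
    loss = loss_fn(torch.log_softmax(model(data), dim=1), target)
    loss.backward()
    optimizer.step()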
examples/tutorial_oss.py

 import time
 from typing import Optional, Union, cast
 
-from helpers import dist_init, getData, getLossFun, getModel
+from helpers import dist_init, get_data, get_loss_fun, get_model
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -21,9 +21,9 @@ def train(rank: int, world_size: int, epochs: int, use_oss: bool):
     device = torch.device("cpu") if DEVICE == "cpu" else rank  # type:ignore
 
     # Problem statement
-    model = getModel().to(device)
-    dataloader = getData(n_batches=1)
-    loss_fn = getLossFun()
+    model = get_model().to(device)
+    dataloader = get_data(n_batches=1)
+    loss_fn = get_loss_fun()
     optimizer: Optional[Union[OSS, torch.optim.SGD]] = None
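The optimizer annotation above (Optional[Union[OSS, torch.optim.SGD]]) is the point of the OSS tutorial: the same model can be driven either by plain SGD or by fairscale's sharded OSS wrapper. A hedged sketch of the OSS path follows; the import path, the learning rate, and the assumption that dist_init has already set up the process group are mine, not part of this commit.

# Illustrative sketch only -- assumes torch.distributed is already initialized,
# e.g. via dist_init(rank, world_size) from examples/helpers.py.
import torch
from fairscale.optim import OSS  # import path assumed
from helpers import get_model

model = get_model()
# OSS shards optimizer state across ranks; any torch.optim class can be wrapped.
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3)  # lr is illustrative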
examples/tutorial_pipe.py

-from helpers import getData, getLossFun, getModel
+from helpers import get_data, get_loss_fun, get_model
 import torch
 import torch.optim as optim
@@ -7,9 +7,9 @@ import fairscale
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 RANK = 0  # example
 
-model = getModel()
-data, target = getData()[0]
-loss_fn = getLossFun()
+model = get_model()
+data, target = get_data()[0]
+loss_fn = get_loss_fun()
 
 model = fairscale.nn.Pipe(model, balance=[2, 1])
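For readers unfamiliar with Pipe, the balance=[2, 1] argument above splits the three-module Sequential returned by get_model() into two stages. A small sketch of the equivalent manual split, illustrative only and using plain torch.nn rather than fairscale:

# Illustrative sketch only -- shows what balance=[2, 1] means for get_model()'s Sequential.
import torch.nn as nn
from helpers import get_model

model = get_model()                                    # Linear(10, 10), ReLU(), Linear(10, 5)
stage_0 = nn.Sequential(*list(model.children())[:2])   # first two modules -> first partition
stage_1 = nn.Sequential(*list(model.children())[2:])   # last module -> second partition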
examples/tutorial_pipe_multiprocess.py

 import os
 
-from helpers import dist_init, getData, getLossFun, getModel
+from helpers import dist_init, get_data, get_loss_fun, get_model
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -20,9 +20,9 @@ def run(rank, world_size):
     dist.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
     initialize_model_parallel(1, world_size)
 
-    model = getModel()
-    data, target = getData()[0]
-    loss_fn = getLossFun()
+    model = get_model()
+    data, target = get_data()[0]
+    loss_fn = get_loss_fun()
 
     device = torch.device("cuda", rank) if DEVICE == "cuda" else torch.device("cpu")
examples/tutorial_pipe_rpc.py

@@ -3,7 +3,7 @@
 import os
 
-from helpers import dist_init, getData, getLossFun, getModel
+from helpers import dist_init, get_data, get_loss_fun, get_model
 import torch
 import torch.optim as optim
 import torch_pg
@@ -37,9 +37,9 @@ def run(rank, world_size):
         torch.distributed.rpc.shutdown()
         return
 
-    model = getModel()
-    data, target = getData()[0]
-    loss_fn = getLossFun()
+    model = get_model()
+    data, target = get_data()[0]
+    loss_fn = get_loss_fun()
 
     device = torch.device("cuda", rank)