OpenDAS / fairscale · Commit 34915bf8 (unverified)

[feat] OSS: adding a --profile option to the benchmark (#135)

Authored Oct 14, 2020 by Benjamin Lefaudeux; committed via GitHub on Oct 14, 2020
Parent: 37c686e7
1 changed file: benchmarks/oss.py, with 37 additions and 22 deletions (+37 -22)

benchmarks/oss.py (view file @ 34915bf8)
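In summary, the diff below replaces the separate use_oss/use_sdp boolean switches in train() with a single OptimType enum, moves that enum to module scope so that both train() and the argument parser can use it, and threads a new profile flag through the benchmark: when set, the first optimizer step of the run is wrapped in torch.autograd.profiler and the resulting trace is exported to a Chrome-format JSON file.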
@@ -9,6 +9,7 @@ from typing import Any, List, Optional, cast

 import numpy as np
 import torch
+import torch.autograd.profiler as profiler
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
@@ -49,6 +50,13 @@ def get_problem(rank, data_size, batch_size):
     return model, dataloader, loss_fn


+class OptimType(str, Enum):
+    vanilla = "pytorch"
+    oss = "oss"
+    oss_sdp = "oss_sdp"
+    everyone = "everyone"
+
+
 def train(
     rank: int,
     world_size: int,
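As an aside, the (str, Enum) mixin is what lets the same OptimType both round-trip through argparse and compare against plain strings. A minimal sketch of that behavior (not part of the diff):

from enum import Enum

class OptimType(str, Enum):
    vanilla = "pytorch"
    oss = "oss"
    oss_sdp = "oss_sdp"
    everyone = "everyone"

assert OptimType("oss") is OptimType.oss  # lookup by value, as argparse's type=OptimType does
assert OptimType.vanilla == "pytorch"     # the str mixin allows direct comparison with strings
print([o.value for o in OptimType])       # ['pytorch', 'oss', 'oss_sdp', 'everyone']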
@@ -56,14 +64,13 @@ def train(
     batch_size: int = 32,
     data_size: int = 200,
     backend: str = "gloo",
-    use_oss: bool = True,
-    use_sdp: bool = False,
+    optim_type: OptimType = OptimType.vanilla,
+    profile: bool = False,
     check_regression: bool = True,
     reference_speed: float = -1.0,
     reference_memory: float = -1.0,
     reference_loss: float = -1.0,
 ):
-    assert not use_sdp or (use_sdp and use_oss), "ShardedDataParallel requires OSS"
     # DDP
     dist_init(rank=rank, world_size=world_size, backend=backend)
@@ -82,7 +89,7 @@ def train(
     # Shard the optimizer
     optimizer: Optional[torch.optim.Optimizer] = None

-    if use_sdp:
+    if optim_type == OptimType.oss_sdp:
         ddp = ShardedDataParallel(
             module=model,
             optimizer=OPTIM,
@@ -97,7 +104,7 @@ def train(
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)  # type: ignore
         optimizer = (
             OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
-            if use_oss
+            if optim_type == OptimType.oss
             else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
         )
@@ -111,6 +118,7 @@ def train(
     measurements = []
     final_loss: Optional[float] = -1.0
+    need_profiling = profile

     for epoch in range(num_epochs):
         epoch_start = time.monotonic()
@@ -124,16 +132,29 @@ def train(
             loss /= world_size
             loss.backward()

-            if use_sdp:
+            if optim_type == OptimType.oss_sdp:
                 ddp.reduce()  # Send the gradients to the appropriate shards

             return loss

-        final_loss = optimizer.step(closure)
+        if need_profiling:
+            print("Profiling the run")
+            with profiler.profile(use_cuda=True) as prof:  # type: ignore
+                with profiler.record_function("batch"):
+                    final_loss = optimizer.step(closure)
+                    print("profiling done, final loss ", cast(float, final_loss))
+
+            if rank == 0:
+                prof.export_chrome_trace(f"{optim_type}_trace.json")
+
+            need_profiling = False  # only profile once
+        else:
+            final_loss = optimizer.step(closure)

         epoch_end = time.monotonic()

-        if use_oss:
+        if optim_type == OptimType.oss:
             # Check the checkpointing in the case of the OSS optimizer
             # Memory usage could spill over from there
             optimizer = cast(OSS, optimizer)
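The profiling block above follows the stock torch.autograd.profiler pattern. A self-contained sketch of the same pattern, using a stand-in model rather than the benchmark's (the exported JSON can be opened in chrome://tracing):

import torch
import torch.autograd.profiler as profiler

model = torch.nn.Linear(128, 128)  # stand-in workload, not the benchmark's model
x = torch.randn(32, 128)

with profiler.profile(use_cuda=torch.cuda.is_available()) as prof:
    with profiler.record_function("batch"):  # labels this span in the exported trace
        model(x).sum().backward()

prof.export_chrome_trace("example_trace.json")  # viewable in chrome://tracing
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))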
@@ -160,7 +181,7 @@ def train(
         std = math.sqrt(sum(diff) / (len(measurements) - 1))
         print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

-        if use_oss and check_regression and dist.get_rank() == 0:
+        if check_regression and dist.get_rank() == 0:
             assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
             assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
             assert abs(cast(float, final_loss) - reference_loss) < 1e-3, "Loss regression detected"
@@ -171,13 +192,6 @@ def train(

 if __name__ == "__main__":
-
-    class OptimType(str, Enum):
-        vanilla = "pytorch"
-        oss = "oss"
-        oss_sdp = "oss_sdp"
-        everyone = "everyone"
-
     parser = argparse.ArgumentParser(
         description="Benchmark the optimizer state sharding, on a typical computer vision workload"
     )
@@ -193,6 +207,7 @@ if __name__ == "__main__":
         "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
     )
     parser.add_argument("--gloo", action="store_true", default=False)
+    parser.add_argument("--profile", action="store_true", default=False)

     args = parser.parse_args()
     print(f"Benchmark arguments: {args}")
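With the flag registered, a profiled run can be launched as, for example, python benchmarks/oss.py --profile --optim_type oss; the remaining arguments, such as the world size, keep the defaults defined in the unchanged parts of the script and are not shown in this diff.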
@@ -209,8 +224,8 @@ if __name__ == "__main__":
                 args.batch_size,
                 args.data_size,
                 backend,
-                False,  # OSS
-                False,  # SDP
+                OptimType.vanilla,
+                args.profile,
                 False,  # no regression check
             ),
             nprocs=args.world_size,
@@ -227,8 +242,8 @@ if __name__ == "__main__":
                 args.batch_size,
                 args.data_size,
                 backend,
-                True,  # OSS
-                False,  # SDP
+                OptimType.oss,
+                args.profile,
                 args.check_regression,
                 args.reference_speed,
                 args.reference_memory,
@@ -248,8 +263,8 @@ if __name__ == "__main__":
                 args.batch_size,
                 args.data_size,
                 backend,
-                True,  # OSS
-                False,  # SDP
+                OptimType.oss_sdp,
+                args.profile,
                 False,  # FIXME: @lefaudeux - SDP should give the same results
                 -1,  # Not checking SDP for speed regression for now, still slower than OSS
                 args.reference_memory,