Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
8a49a748
Unverified
Commit
8a49a748
authored
Jan 21, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Jan 21, 2021
Browse files
[feat] Enabling ViT in OSS benchmarks (#322)
parent
dd441e9d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
9 deletions
+23
-9
benchmarks/oss.py
benchmarks/oss.py
+22
-9
requirements-test.txt
requirements-test.txt
+1
-0
No files found.
benchmarks/oss.py
View file @
8a49a748
...
@@ -21,7 +21,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
...
@@ -21,7 +21,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from
torch.utils.data
import
BatchSampler
,
DataLoader
,
Sampler
from
torch.utils.data
import
BatchSampler
,
DataLoader
,
Sampler
from
torch.utils.data.distributed
import
DistributedSampler
from
torch.utils.data.distributed
import
DistributedSampler
from
torchvision.datasets
import
MNIST
from
torchvision.datasets
import
MNIST
from
torchvision.transforms
import
ToTensor
from
torchvision.transforms
import
Compose
,
Resize
,
ToTensor
from
fairscale.nn.data_parallel
import
ShardedDataParallel
as
ShardedDDP
from
fairscale.nn.data_parallel
import
ShardedDataParallel
as
ShardedDDP
from
fairscale.optim
import
OSS
from
fairscale.optim
import
OSS
...
@@ -39,7 +39,11 @@ def dist_init(rank, world_size, backend):
...
@@ -39,7 +39,11 @@ def dist_init(rank, world_size, backend):
def
get_problem
(
rank
,
world_size
,
batch_size
,
device
,
model_name
:
str
):
def
get_problem
(
rank
,
world_size
,
batch_size
,
device
,
model_name
:
str
):
# Select the desired model on the fly
# Select the desired model on the fly
logging
.
info
(
f
"Using
{
model_name
}
for benchmarking"
)
logging
.
info
(
f
"Using
{
model_name
}
for benchmarking"
)
model
=
getattr
(
importlib
.
import_module
(
"torchvision.models"
),
model_name
)(
pretrained
=
False
).
to
(
device
)
try
:
model
=
getattr
(
importlib
.
import_module
(
"torchvision.models"
),
model_name
)(
pretrained
=
False
).
to
(
device
)
except
AttributeError
:
model
=
getattr
(
importlib
.
import_module
(
"timm.models"
),
model_name
)(
pretrained
=
False
).
to
(
device
)
# Data setup, duplicate the grey channels to get pseudo color
# Data setup, duplicate the grey channels to get pseudo color
def
collate
(
inputs
:
List
[
Any
]):
def
collate
(
inputs
:
List
[
Any
]):
...
@@ -48,7 +52,16 @@ def get_problem(rank, world_size, batch_size, device, model_name: str):
...
@@ -48,7 +52,16 @@ def get_problem(rank, world_size, batch_size, device, model_name: str):
"label"
:
torch
.
tensor
([
i
[
1
]
for
i
in
inputs
]).
to
(
device
),
"label"
:
torch
.
tensor
([
i
[
1
]
for
i
in
inputs
]).
to
(
device
),
}
}
dataset
=
MNIST
(
transform
=
ToTensor
(),
download
=
False
,
root
=
TEMPDIR
)
# Transforms
transforms
=
[]
if
model_name
.
startswith
(
"vit"
):
# ViT models are fixed size. Add an ad-hoc transform to resize the pictures accordingly
pic_size
=
int
(
model_name
.
split
(
"_"
)[
-
1
])
transforms
.
append
(
Resize
(
pic_size
))
transforms
.
append
(
ToTensor
())
dataset
=
MNIST
(
transform
=
Compose
(
transforms
),
download
=
False
,
root
=
TEMPDIR
)
sampler
:
Sampler
=
DistributedSampler
(
dataset
,
num_replicas
=
world_size
,
rank
=
rank
)
sampler
:
Sampler
=
DistributedSampler
(
dataset
,
num_replicas
=
world_size
,
rank
=
rank
)
batch_sampler
=
BatchSampler
(
sampler
,
batch_size
,
drop_last
=
True
)
batch_sampler
=
BatchSampler
(
sampler
,
batch_size
,
drop_last
=
True
)
dataloader
=
DataLoader
(
dataset
=
dataset
,
batch_sampler
=
batch_sampler
,
collate_fn
=
collate
)
dataloader
=
DataLoader
(
dataset
=
dataset
,
batch_sampler
=
batch_sampler
,
collate_fn
=
collate
)
...
@@ -88,7 +101,7 @@ def train(
...
@@ -88,7 +101,7 @@ def train(
torch
.
backends
.
cudnn
.
benchmark
=
False
torch
.
backends
.
cudnn
.
benchmark
=
False
device
=
torch
.
device
(
"cpu"
)
if
args
.
cpu
else
torch
.
device
(
rank
)
device
=
torch
.
device
(
"cpu"
)
if
args
.
cpu
else
torch
.
device
(
rank
)
model
,
dataloader
,
loss_fn
=
get_problem
(
rank
,
args
.
world_size
,
args
.
batch_size
,
device
,
args
.
torchvision_
model
)
model
,
dataloader
,
loss_fn
=
get_problem
(
rank
,
args
.
world_size
,
args
.
batch_size
,
device
,
args
.
model
)
# Shard the optimizer
# Shard the optimizer
optimizer
:
Optional
[
torch
.
optim
.
Optimizer
]
=
None
optimizer
:
Optional
[
torch
.
optim
.
Optimizer
]
=
None
...
@@ -259,7 +272,7 @@ if __name__ == "__main__":
...
@@ -259,7 +272,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--gloo"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--gloo"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--cpu"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--cpu"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--
torchvision_
model"
,
type
=
str
,
help
=
"Any torchvision model name (str)"
,
default
=
"resnet101"
)
parser
.
add_argument
(
"--model"
,
type
=
str
,
help
=
"Any torchvision
or timm
model name (str)"
,
default
=
"resnet101"
)
parser
.
add_argument
(
"--debug"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Display additional debug information"
)
parser
.
add_argument
(
"--debug"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Display additional debug information"
)
parser
.
add_argument
(
"--amp"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Activate torch AMP"
)
parser
.
add_argument
(
"--amp"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Activate torch AMP"
)
...
@@ -293,8 +306,8 @@ if __name__ == "__main__":
...
@@ -293,8 +306,8 @@ if __name__ == "__main__":
if
args
.
optim_type
==
OptimType
.
vanilla
or
args
.
optim_type
==
OptimType
.
everyone
:
if
args
.
optim_type
==
OptimType
.
vanilla
or
args
.
optim_type
==
OptimType
.
everyone
:
logging
.
info
(
"
\n
*** Benchmark vanilla optimizer"
)
logging
.
info
(
"
\n
*** Benchmark vanilla optimizer"
)
mp
.
spawn
(
mp
.
spawn
(
train
,
train
,
# type: ignore
args
=
(
args
,
BACKEND
,
OptimType
.
vanilla
,
False
,
),
# no regression check
args
=
(
args
,
BACKEND
,
OptimType
.
vanilla
,
False
),
# no regression check
nprocs
=
args
.
world_size
,
nprocs
=
args
.
world_size
,
join
=
True
,
join
=
True
,
)
)
...
@@ -302,13 +315,13 @@ if __name__ == "__main__":
...
@@ -302,13 +315,13 @@ if __name__ == "__main__":
if
args
.
optim_type
==
OptimType
.
oss_ddp
or
args
.
optim_type
==
OptimType
.
everyone
:
if
args
.
optim_type
==
OptimType
.
oss_ddp
or
args
.
optim_type
==
OptimType
.
everyone
:
logging
.
info
(
"
\n
*** Benchmark OSS with DDP"
)
logging
.
info
(
"
\n
*** Benchmark OSS with DDP"
)
mp
.
spawn
(
mp
.
spawn
(
train
,
args
=
(
args
,
BACKEND
,
OptimType
.
oss_ddp
,
args
.
check_regression
),
nprocs
=
args
.
world_size
,
join
=
True
,
train
,
args
=
(
args
,
BACKEND
,
OptimType
.
oss_ddp
,
args
.
check_regression
),
nprocs
=
args
.
world_size
,
join
=
True
,
# type: ignore
)
)
if
args
.
optim_type
==
OptimType
.
oss_sharded_ddp
or
args
.
optim_type
==
OptimType
.
everyone
:
if
args
.
optim_type
==
OptimType
.
oss_sharded_ddp
or
args
.
optim_type
==
OptimType
.
everyone
:
logging
.
info
(
"
\n
*** Benchmark OSS with ShardedDDP"
)
logging
.
info
(
"
\n
*** Benchmark OSS with ShardedDDP"
)
mp
.
spawn
(
mp
.
spawn
(
train
,
train
,
# type: ignore
args
=
(
args
=
(
args
,
args
,
BACKEND
,
BACKEND
,
...
...
requirements-test.txt
View file @
8a49a748
...
@@ -12,3 +12,4 @@ torch >= 1.5.1
...
@@ -12,3 +12,4 @@ torch >= 1.5.1
torchvision >= 0.6.0
torchvision >= 0.6.0
# NOTE(msb) not a dependency but needed by torch
# NOTE(msb) not a dependency but needed by torch
numpy == 1.17.4
numpy == 1.17.4
timm == 0.3.4
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment