Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
8a49a748
Unverified
Commit
8a49a748
authored
Jan 21, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Jan 21, 2021
Browse files
[feat] Enabling ViT in OSS benchmarks (#322)
parent
dd441e9d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
9 deletions
+23
-9
benchmarks/oss.py
benchmarks/oss.py
+22
-9
requirements-test.txt
requirements-test.txt
+1
-0
No files found.
benchmarks/oss.py
View file @
8a49a748
...
@@ -21,7 +21,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
...
@@ -21,7 +21,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from
torch.utils.data
import
BatchSampler
,
DataLoader
,
Sampler
from
torch.utils.data
import
BatchSampler
,
DataLoader
,
Sampler
from
torch.utils.data.distributed
import
DistributedSampler
from
torch.utils.data.distributed
import
DistributedSampler
from
torchvision.datasets
import
MNIST
from
torchvision.datasets
import
MNIST
from
torchvision.transforms
import
ToTensor
from
torchvision.transforms
import
Compose
,
Resize
,
ToTensor
from
fairscale.nn.data_parallel
import
ShardedDataParallel
as
ShardedDDP
from
fairscale.nn.data_parallel
import
ShardedDataParallel
as
ShardedDDP
from
fairscale.optim
import
OSS
from
fairscale.optim
import
OSS
...
@@ -39,7 +39,11 @@ def dist_init(rank, world_size, backend):
...
@@ -39,7 +39,11 @@ def dist_init(rank, world_size, backend):
def
get_problem
(
rank
,
world_size
,
batch_size
,
device
,
model_name
:
str
):
def
get_problem
(
rank
,
world_size
,
batch_size
,
device
,
model_name
:
str
):
# Select the desired model on the fly
# Select the desired model on the fly
logging
.
info
(
f
"Using
{
model_name
}
for benchmarking"
)
logging
.
info
(
f
"Using
{
model_name
}
for benchmarking"
)
model
=
getattr
(
importlib
.
import_module
(
"torchvision.models"
),
model_name
)(
pretrained
=
False
).
to
(
device
)
try
:
model
=
getattr
(
importlib
.
import_module
(
"torchvision.models"
),
model_name
)(
pretrained
=
False
).
to
(
device
)
except
AttributeError
:
model
=
getattr
(
importlib
.
import_module
(
"timm.models"
),
model_name
)(
pretrained
=
False
).
to
(
device
)
# Data setup, duplicate the grey channels to get pseudo color
# Data setup, duplicate the grey channels to get pseudo color
def
collate
(
inputs
:
List
[
Any
]):
def
collate
(
inputs
:
List
[
Any
]):
...
@@ -48,7 +52,16 @@ def get_problem(rank, world_size, batch_size, device, model_name: str):
...
@@ -48,7 +52,16 @@ def get_problem(rank, world_size, batch_size, device, model_name: str):
"label"
:
torch
.
tensor
([
i
[
1
]
for
i
in
inputs
]).
to
(
device
),
"label"
:
torch
.
tensor
([
i
[
1
]
for
i
in
inputs
]).
to
(
device
),
}
}
dataset
=
MNIST
(
transform
=
ToTensor
(),
download
=
False
,
root
=
TEMPDIR
)
# Transforms
transforms
=
[]
if
model_name
.
startswith
(
"vit"
):
# ViT models are fixed size. Add an ad-hoc transform to resize the pictures accordingly
pic_size
=
int
(
model_name
.
split
(
"_"
)[
-
1
])
transforms
.
append
(
Resize
(
pic_size
))
transforms
.
append
(
ToTensor
())
dataset
=
MNIST
(
transform
=
Compose
(
transforms
),
download
=
False
,
root
=
TEMPDIR
)
sampler
:
Sampler
=
DistributedSampler
(
dataset
,
num_replicas
=
world_size
,
rank
=
rank
)
sampler
:
Sampler
=
DistributedSampler
(
dataset
,
num_replicas
=
world_size
,
rank
=
rank
)
batch_sampler
=
BatchSampler
(
sampler
,
batch_size
,
drop_last
=
True
)
batch_sampler
=
BatchSampler
(
sampler
,
batch_size
,
drop_last
=
True
)
dataloader
=
DataLoader
(
dataset
=
dataset
,
batch_sampler
=
batch_sampler
,
collate_fn
=
collate
)
dataloader
=
DataLoader
(
dataset
=
dataset
,
batch_sampler
=
batch_sampler
,
collate_fn
=
collate
)
...
@@ -88,7 +101,7 @@ def train(
...
@@ -88,7 +101,7 @@ def train(
torch
.
backends
.
cudnn
.
benchmark
=
False
torch
.
backends
.
cudnn
.
benchmark
=
False
device
=
torch
.
device
(
"cpu"
)
if
args
.
cpu
else
torch
.
device
(
rank
)
device
=
torch
.
device
(
"cpu"
)
if
args
.
cpu
else
torch
.
device
(
rank
)
model
,
dataloader
,
loss_fn
=
get_problem
(
rank
,
args
.
world_size
,
args
.
batch_size
,
device
,
args
.
torchvision_
model
)
model
,
dataloader
,
loss_fn
=
get_problem
(
rank
,
args
.
world_size
,
args
.
batch_size
,
device
,
args
.
model
)
# Shard the optimizer
# Shard the optimizer
optimizer
:
Optional
[
torch
.
optim
.
Optimizer
]
=
None
optimizer
:
Optional
[
torch
.
optim
.
Optimizer
]
=
None
...
@@ -259,7 +272,7 @@ if __name__ == "__main__":
...
@@ -259,7 +272,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--gloo"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--gloo"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--cpu"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--cpu"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--
torchvision_
model"
,
type
=
str
,
help
=
"Any torchvision model name (str)"
,
default
=
"resnet101"
)
parser
.
add_argument
(
"--model"
,
type
=
str
,
help
=
"Any torchvision
or timm
model name (str)"
,
default
=
"resnet101"
)
parser
.
add_argument
(
"--debug"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Display additional debug information"
)
parser
.
add_argument
(
"--debug"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Display additional debug information"
)
parser
.
add_argument
(
"--amp"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Activate torch AMP"
)
parser
.
add_argument
(
"--amp"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Activate torch AMP"
)
...
@@ -293,8 +306,8 @@ if __name__ == "__main__":
...
@@ -293,8 +306,8 @@ if __name__ == "__main__":
if
args
.
optim_type
==
OptimType
.
vanilla
or
args
.
optim_type
==
OptimType
.
everyone
:
if
args
.
optim_type
==
OptimType
.
vanilla
or
args
.
optim_type
==
OptimType
.
everyone
:
logging
.
info
(
"
\n
*** Benchmark vanilla optimizer"
)
logging
.
info
(
"
\n
*** Benchmark vanilla optimizer"
)
mp
.
spawn
(
mp
.
spawn
(
train
,
train
,
# type: ignore
args
=
(
args
,
BACKEND
,
OptimType
.
vanilla
,
False
,
),
# no regression check
args
=
(
args
,
BACKEND
,
OptimType
.
vanilla
,
False
),
# no regression check
nprocs
=
args
.
world_size
,
nprocs
=
args
.
world_size
,
join
=
True
,
join
=
True
,
)
)
...
@@ -302,13 +315,13 @@ if __name__ == "__main__":
...
@@ -302,13 +315,13 @@ if __name__ == "__main__":
if
args
.
optim_type
==
OptimType
.
oss_ddp
or
args
.
optim_type
==
OptimType
.
everyone
:
if
args
.
optim_type
==
OptimType
.
oss_ddp
or
args
.
optim_type
==
OptimType
.
everyone
:
logging
.
info
(
"
\n
*** Benchmark OSS with DDP"
)
logging
.
info
(
"
\n
*** Benchmark OSS with DDP"
)
mp
.
spawn
(
mp
.
spawn
(
train
,
args
=
(
args
,
BACKEND
,
OptimType
.
oss_ddp
,
args
.
check_regression
),
nprocs
=
args
.
world_size
,
join
=
True
,
train
,
args
=
(
args
,
BACKEND
,
OptimType
.
oss_ddp
,
args
.
check_regression
),
nprocs
=
args
.
world_size
,
join
=
True
,
# type: ignore
)
)
if
args
.
optim_type
==
OptimType
.
oss_sharded_ddp
or
args
.
optim_type
==
OptimType
.
everyone
:
if
args
.
optim_type
==
OptimType
.
oss_sharded_ddp
or
args
.
optim_type
==
OptimType
.
everyone
:
logging
.
info
(
"
\n
*** Benchmark OSS with ShardedDDP"
)
logging
.
info
(
"
\n
*** Benchmark OSS with ShardedDDP"
)
mp
.
spawn
(
mp
.
spawn
(
train
,
train
,
# type: ignore
args
=
(
args
=
(
args
,
args
,
BACKEND
,
BACKEND
,
...
...
requirements-test.txt
View file @
8a49a748
...
@@ -12,3 +12,4 @@ torch >= 1.5.1
...
@@ -12,3 +12,4 @@ torch >= 1.5.1
torchvision >= 0.6.0
torchvision >= 0.6.0
# NOTE(msb) not a dependency but needed by torch
# NOTE(msb) not a dependency but needed by torch
numpy == 1.17.4
numpy == 1.17.4
timm == 0.3.4
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment