Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
0491715f
Unverified
Commit
0491715f
authored
Mar 03, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Mar 03, 2021
Browse files
[fix] Cache MNIST fetches, use alternative URLs (#465)
parent
7a3199b1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
74 additions
and
20 deletions
+74
-20
.circleci/config.yml
.circleci/config.yml
+22
-2
benchmarks/datasets/mnist.py
benchmarks/datasets/mnist.py
+50
-0
benchmarks/oss.py
benchmarks/oss.py
+2
-18
No files found.
.circleci/config.yml
View file @
0491715f
...
...
@@ -335,7 +335,6 @@ jobs:
-
restore_cache
:
keys
:
-
cache-key-cpu-py38-171-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-
<<
:
*install_dep_171
-
save_cache
:
...
...
@@ -401,7 +400,7 @@ jobs:
test_list_file
:
type
:
string
default
:
"
/dev/non_exist"
<<
:
*gpu
working_directory
:
~/fairscale
...
...
@@ -537,6 +536,11 @@ jobs:
keys
:
-
cache-key-benchmarks-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
# Cache the MNIST directory that contains benchmark data
-
restore_cache
:
keys
:
-
cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
-
<<
:
*install_dep_171
-
save_cache
:
...
...
@@ -556,6 +560,11 @@ jobs:
-
<<
:
*run_oss_gloo
-
save_cache
:
paths
:
-
/tmp/MNIST
key
:
cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
benchmarks_2
:
...
...
@@ -581,6 +590,12 @@ jobs:
keys
:
-
cache-key-benchmarks-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
# Cache the MNIST directory that contains benchmark data
-
restore_cache
:
keys
:
-
cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
-
<<
:
*install_dep_171
-
save_cache
:
...
...
@@ -592,6 +607,11 @@ jobs:
-
<<
:
*run_oss_benchmark
-
save_cache
:
paths
:
-
/tmp/MNIST
key
:
cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
workflows
:
version
:
2
...
...
benchmarks/datasets/mnist.py
0 → 100644
View file @
0491715f
import logging
import shutil
import sys
import tempfile
from pathlib import Path

from torchvision.datasets import MNIST
TEMPDIR
=
tempfile
.
gettempdir
()
def setup_cached_mnist():
    """Download the MNIST dataset into ``TEMPDIR``, retrying on bad data.

    The torchvision resource URLs are monkey-patched to a GitHub mirror to
    work around the original host possibly blacklisting CI traffic. The
    download is skipped automatically when a valid cached copy already
    exists (torchvision verifies the md5 checksums).

    Side effects:
        - Mutates ``MNIST.resources`` (module-wide monkey patch).
        - Writes the dataset under ``<TEMPDIR>/MNIST``.
        - Terminates the process (non-zero status) after 5 failed attempts.
    """
    done, tentatives = False, 0

    while not done and tentatives < 5:
        # Monkey patch the resource URLs to work around a possible blacklist
        MNIST.resources = [
            (
                "https://github.com/blefaudeux/mnist_dataset/raw/main/train-images-idx3-ubyte.gz",
                "f68b3c2dcbeaaa9fbdd348bbdeb94873",
            ),
            (
                "https://github.com/blefaudeux/mnist_dataset/raw/main/train-labels-idx1-ubyte.gz",
                "d53e105ee54ea40749a09fcbcd1e9432",
            ),
            (
                "https://github.com/blefaudeux/mnist_dataset/raw/main/t10k-images-idx3-ubyte.gz",
                "9fb629c4189551a2d022fa330f9573f3",
            ),
            (
                "https://github.com/blefaudeux/mnist_dataset/raw/main/t10k-labels-idx1-ubyte.gz",
                "ec29112dd5afa0611ce80d1b7f02629c",
            ),
        ]

        # This will automatically skip the download if the dataset is already there, and check the checksum
        try:
            _ = MNIST(transform=None, download=True, root=TEMPDIR)
            done = True
        # EOFError covers truncated downloads, RuntimeError covers checksum
        # mismatches (same pair the previous inline loop in benchmarks/oss.py
        # handled).
        except (RuntimeError, EOFError) as e:
            logging.warning(e)
            # Corrupted data, erase and restart. ignore_errors: the directory
            # may not exist if the download failed before anything was written.
            shutil.rmtree(str(Path(TEMPDIR) / "MNIST"), ignore_errors=True)
            tentatives += 1

    if not done:
        logging.error("Could not download MNIST dataset")
        # sys.exit instead of the site-provided exit() builtin; same status.
        sys.exit(-1)
    else:
        logging.info("Dataset downloaded")
benchmarks/oss.py
View file @
0491715f
...
...
@@ -5,7 +5,6 @@ import argparse
from
enum
import
Enum
import
importlib
import
logging
import
shutil
import
tempfile
import
time
from
typing
import
Any
,
List
,
Optional
,
cast
...
...
@@ -24,6 +23,7 @@ from torch.utils.data.distributed import DistributedSampler
from
torchvision.datasets
import
MNIST
from
torchvision.transforms
import
Compose
,
Resize
,
ToTensor
from
benchmarks.datasets.mnist
import
setup_cached_mnist
from
fairscale.nn.data_parallel
import
ShardedDataParallel
as
ShardedDDP
from
fairscale.optim
import
OSS
from
fairscale.optim.grad_scaler
import
ShardedGradScaler
...
...
@@ -302,23 +302,7 @@ if __name__ == "__main__":
BACKEND
=
"nccl"
if
(
not
args
.
gloo
or
not
torch
.
cuda
.
is_available
())
and
not
args
.
cpu
else
"gloo"
# Download dataset once for all processes
dataset
,
tentatives
=
None
,
0
while
dataset
is
None
and
tentatives
<
5
:
try
:
dataset
=
MNIST
(
transform
=
None
,
download
=
True
,
root
=
TEMPDIR
)
except
(
RuntimeError
,
EOFError
)
as
e
:
if
isinstance
(
e
,
RuntimeError
):
# Corrupted data, erase and restart
shutil
.
rmtree
(
TEMPDIR
+
"/MNIST"
)
logging
.
warning
(
"Failed loading dataset: %s "
%
e
)
tentatives
+=
1
if
dataset
is
None
:
logging
.
error
(
"Could not download MNIST dataset"
)
exit
(
-
1
)
else
:
logging
.
info
(
"Dataset downloaded"
)
setup_cached_mnist
()
# Benchmark the different configurations, via multiple processes
if
args
.
optim_type
==
OptimType
.
vanilla
or
args
.
optim_type
==
OptimType
.
everyone
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment