Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
0491715f
Unverified
Commit
0491715f
authored
Mar 03, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Mar 03, 2021
Browse files
[fix] Cache MNIST fetchs, use alternative URLs (#465)
parent
7a3199b1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
74 additions
and
20 deletions
+74
-20
.circleci/config.yml
.circleci/config.yml
+22
-2
benchmarks/datasets/mnist.py
benchmarks/datasets/mnist.py
+50
-0
benchmarks/oss.py
benchmarks/oss.py
+2
-18
No files found.
.circleci/config.yml
View file @
0491715f
...
@@ -335,7 +335,6 @@ jobs:
...
@@ -335,7 +335,6 @@ jobs:
-
restore_cache
:
-
restore_cache
:
keys
:
keys
:
-
cache-key-cpu-py38-171-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-
cache-key-cpu-py38-171-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-
<<
:
*install_dep_171
-
<<
:
*install_dep_171
-
save_cache
:
-
save_cache
:
...
@@ -401,7 +400,7 @@ jobs:
...
@@ -401,7 +400,7 @@ jobs:
test_list_file
:
test_list_file
:
type
:
string
type
:
string
default
:
"
/dev/non_exist"
default
:
"
/dev/non_exist"
<<
:
*gpu
<<
:
*gpu
working_directory
:
~/fairscale
working_directory
:
~/fairscale
...
@@ -537,6 +536,11 @@ jobs:
...
@@ -537,6 +536,11 @@ jobs:
keys
:
keys
:
-
cache-key-benchmarks-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-
cache-key-benchmarks-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
# Cache the MNIST directory that contains benchmark data
-
restore_cache
:
keys
:
-
cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
-
<<
:
*install_dep_171
-
<<
:
*install_dep_171
-
save_cache
:
-
save_cache
:
...
@@ -556,6 +560,11 @@ jobs:
...
@@ -556,6 +560,11 @@ jobs:
-
<<
:
*run_oss_gloo
-
<<
:
*run_oss_gloo
-
save_cache
:
paths
:
-
/tmp/MNIST
key
:
cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
benchmarks_2
:
benchmarks_2
:
...
@@ -581,6 +590,12 @@ jobs:
...
@@ -581,6 +590,12 @@ jobs:
keys
:
keys
:
-
cache-key-benchmarks-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-
cache-key-benchmarks-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
# Cache the MNIST directory that contains benchmark data
-
restore_cache
:
keys
:
-
cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
-
<<
:
*install_dep_171
-
<<
:
*install_dep_171
-
save_cache
:
-
save_cache
:
...
@@ -592,6 +607,11 @@ jobs:
...
@@ -592,6 +607,11 @@ jobs:
-
<<
:
*run_oss_benchmark
-
<<
:
*run_oss_benchmark
-
save_cache
:
paths
:
-
/tmp/MNIST
key
:
cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
workflows
:
workflows
:
version
:
2
version
:
2
...
...
benchmarks/datasets/mnist.py
0 → 100644
View file @
0491715f
import
logging
from
pathlib
import
Path
import
shutil
import
tempfile
from
torchvision.datasets
import
MNIST
TEMPDIR
=
tempfile
.
gettempdir
()
def
setup_cached_mnist
():
done
,
tentatives
=
False
,
0
while
not
done
and
tentatives
<
5
:
# Monkey patch the resource URLs to work around a possible blacklist
MNIST
.
resources
=
[
(
"https://github.com/blefaudeux/mnist_dataset/raw/main/train-images-idx3-ubyte.gz"
,
"f68b3c2dcbeaaa9fbdd348bbdeb94873"
,
),
(
"https://github.com/blefaudeux/mnist_dataset/raw/main/train-labels-idx1-ubyte.gz"
,
"d53e105ee54ea40749a09fcbcd1e9432"
,
),
(
"https://github.com/blefaudeux/mnist_dataset/raw/main/t10k-images-idx3-ubyte.gz"
,
"9fb629c4189551a2d022fa330f9573f3"
,
),
(
"https://github.com/blefaudeux/mnist_dataset/raw/main/t10k-labels-idx1-ubyte.gz"
,
"ec29112dd5afa0611ce80d1b7f02629c"
,
),
]
# This will automatically skip the download if the dataset is already there, and check the checksum
try
:
_
=
MNIST
(
transform
=
None
,
download
=
True
,
root
=
TEMPDIR
)
done
=
True
except
RuntimeError
as
e
:
logging
.
warning
(
e
)
mnist_root
=
Path
(
TEMPDIR
+
"/MNIST"
)
# Corrupted data, erase and restart
shutil
.
rmtree
(
str
(
mnist_root
))
tentatives
+=
1
if
done
is
False
:
logging
.
error
(
"Could not download MNIST dataset"
)
exit
(
-
1
)
else
:
logging
.
info
(
"Dataset downloaded"
)
benchmarks/oss.py
View file @
0491715f
...
@@ -5,7 +5,6 @@ import argparse
...
@@ -5,7 +5,6 @@ import argparse
from
enum
import
Enum
from
enum
import
Enum
import
importlib
import
importlib
import
logging
import
logging
import
shutil
import
tempfile
import
tempfile
import
time
import
time
from
typing
import
Any
,
List
,
Optional
,
cast
from
typing
import
Any
,
List
,
Optional
,
cast
...
@@ -24,6 +23,7 @@ from torch.utils.data.distributed import DistributedSampler
...
@@ -24,6 +23,7 @@ from torch.utils.data.distributed import DistributedSampler
from
torchvision.datasets
import
MNIST
from
torchvision.datasets
import
MNIST
from
torchvision.transforms
import
Compose
,
Resize
,
ToTensor
from
torchvision.transforms
import
Compose
,
Resize
,
ToTensor
from
benchmarks.datasets.mnist
import
setup_cached_mnist
from
fairscale.nn.data_parallel
import
ShardedDataParallel
as
ShardedDDP
from
fairscale.nn.data_parallel
import
ShardedDataParallel
as
ShardedDDP
from
fairscale.optim
import
OSS
from
fairscale.optim
import
OSS
from
fairscale.optim.grad_scaler
import
ShardedGradScaler
from
fairscale.optim.grad_scaler
import
ShardedGradScaler
...
@@ -302,23 +302,7 @@ if __name__ == "__main__":
...
@@ -302,23 +302,7 @@ if __name__ == "__main__":
BACKEND
=
"nccl"
if
(
not
args
.
gloo
or
not
torch
.
cuda
.
is_available
())
and
not
args
.
cpu
else
"gloo"
BACKEND
=
"nccl"
if
(
not
args
.
gloo
or
not
torch
.
cuda
.
is_available
())
and
not
args
.
cpu
else
"gloo"
# Download dataset once for all processes
# Download dataset once for all processes
dataset
,
tentatives
=
None
,
0
setup_cached_mnist
()
while
dataset
is
None
and
tentatives
<
5
:
try
:
dataset
=
MNIST
(
transform
=
None
,
download
=
True
,
root
=
TEMPDIR
)
except
(
RuntimeError
,
EOFError
)
as
e
:
if
isinstance
(
e
,
RuntimeError
):
# Corrupted data, erase and restart
shutil
.
rmtree
(
TEMPDIR
+
"/MNIST"
)
logging
.
warning
(
"Failed loading dataset: %s "
%
e
)
tentatives
+=
1
if
dataset
is
None
:
logging
.
error
(
"Could not download MNIST dataset"
)
exit
(
-
1
)
else
:
logging
.
info
(
"Dataset downloaded"
)
# Benchmark the different configurations, via multiple processes
# Benchmark the different configurations, via multiple processes
if
args
.
optim_type
==
OptimType
.
vanilla
or
args
.
optim_type
==
OptimType
.
everyone
:
if
args
.
optim_type
==
OptimType
.
vanilla
or
args
.
optim_type
==
OptimType
.
everyone
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment