Unverified Commit 0491715f authored by Benjamin Lefaudeux, committed by GitHub

[fix] Cache MNIST fetches, use alternative URLs (#465)

parent 7a3199b1
.circleci/config.yml:

@@ -335,7 +335,6 @@ jobs:
      - restore_cache:
          keys:
            - cache-key-cpu-py38-171-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
      - <<: *install_dep_171
      - save_cache:
@@ -401,7 +400,7 @@ jobs:
      test_list_file:
        type: string
        default: "/dev/non_exist"
    <<: *gpu
    working_directory: ~/fairscale
@@ -537,6 +536,11 @@ jobs:
          keys:
            - cache-key-benchmarks-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+      # Cache the MNIST directory that contains benchmark data
+      - restore_cache:
+          keys:
+            - cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
      - <<: *install_dep_171
      - save_cache:
@@ -556,6 +560,11 @@ jobs:
      - <<: *run_oss_gloo
+      - save_cache:
+          paths:
+            - /tmp/MNIST
+          key: cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
  benchmarks_2:
@@ -581,6 +590,12 @@ jobs:
          keys:
            - cache-key-benchmarks-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
+      # Cache the MNIST directory that contains benchmark data
+      - restore_cache:
+          keys:
+            - cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
      - <<: *install_dep_171
      - save_cache:
@@ -592,6 +607,11 @@ jobs:
      - <<: *run_oss_benchmark
+      - save_cache:
+          paths:
+            - /tmp/MNIST
+          key: cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
workflows:
  version: 2
...
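Both benchmark jobs restore and save /tmp/MNIST under a key derived from a checksum of benchmarks/datasets/mnist.py, so the cached dataset survives across builds and is only re-downloaded when the download logic (URLs or MD5 sums) changes. As a rough local analogue of that checksum-keyed caching idea (the hash function, key prefix, and cache directory below are illustrative assumptions, not what CircleCI uses internally):

import hashlib
from pathlib import Path


def cache_key(script: Path, prefix: str = "cache-key-benchmark-MNIST") -> str:
    # Key the cache on the file contents: any edit to the URLs or MD5 sums
    # in the download script yields a new key and therefore a fresh download.
    digest = hashlib.sha256(script.read_bytes()).hexdigest()[:16]
    return f"{prefix}-{digest}"


def is_cached(key: str, cache_root: Path = Path("/tmp/ci-cache")) -> bool:
    # Hypothetical on-disk layout: one directory per cache key.
    return (cache_root / key).is_dir()


if __name__ == "__main__":
    key = cache_key(Path("benchmarks/datasets/mnist.py"))
    print(key, "already cached:", is_cached(key))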
benchmarks/datasets/mnist.py (new file):

import logging
from pathlib import Path
import shutil
import tempfile

from torchvision.datasets import MNIST

TEMPDIR = tempfile.gettempdir()


def setup_cached_mnist():
    done, tentatives = False, 0

    while not done and tentatives < 5:
        # Monkey patch the resource URLs to work around a possible blacklist
        MNIST.resources = [
            (
                "https://github.com/blefaudeux/mnist_dataset/raw/main/train-images-idx3-ubyte.gz",
                "f68b3c2dcbeaaa9fbdd348bbdeb94873",
            ),
            (
                "https://github.com/blefaudeux/mnist_dataset/raw/main/train-labels-idx1-ubyte.gz",
                "d53e105ee54ea40749a09fcbcd1e9432",
            ),
            (
                "https://github.com/blefaudeux/mnist_dataset/raw/main/t10k-images-idx3-ubyte.gz",
                "9fb629c4189551a2d022fa330f9573f3",
            ),
            (
                "https://github.com/blefaudeux/mnist_dataset/raw/main/t10k-labels-idx1-ubyte.gz",
                "ec29112dd5afa0611ce80d1b7f02629c",
            ),
        ]

        # This will automatically skip the download if the dataset is already there, and check the checksum
        try:
            _ = MNIST(transform=None, download=True, root=TEMPDIR)
            done = True
        except RuntimeError as e:
            logging.warning(e)
            mnist_root = Path(TEMPDIR + "/MNIST")

            # Corrupted data, erase and restart
            shutil.rmtree(str(mnist_root))
            tentatives += 1

    if done is False:
        logging.error("Could not download MNIST dataset")
        exit(-1)
    else:
        logging.info("Dataset downloaded")
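A minimal sketch of how a benchmark might consume this helper: the retries and the monkey-patched URLs live entirely inside setup_cached_mnist(), so callers only need to point torchvision's MNIST at the same temporary root afterwards (the transform and batch size below are illustrative, not taken from this commit):

import tempfile

from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Resize, ToTensor

from benchmarks.datasets.mnist import setup_cached_mnist

# Fetch the dataset once (or reuse the copy already sitting in <tempdir>/MNIST).
setup_cached_mnist()

# The helper downloads into tempfile.gettempdir(), so reuse the same root here.
dataset = MNIST(
    root=tempfile.gettempdir(),
    download=False,
    transform=Compose([Resize(32), ToTensor()]),
)
loader = DataLoader(dataset, batch_size=64, shuffle=True)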
@@ -5,7 +5,6 @@ import argparse
from enum import Enum
import importlib
import logging
-import shutil
import tempfile
import time
from typing import Any, List, Optional, cast
@@ -24,6 +23,7 @@ from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Resize, ToTensor
+from benchmarks.datasets.mnist import setup_cached_mnist
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
@@ -302,23 +302,7 @@ if __name__ == "__main__":
    BACKEND = "nccl" if (not args.gloo or not torch.cuda.is_available()) and not args.cpu else "gloo"
    # Download dataset once for all processes
-    dataset, tentatives = None, 0
-    while dataset is None and tentatives < 5:
-        try:
-            dataset = MNIST(transform=None, download=True, root=TEMPDIR)
-        except (RuntimeError, EOFError) as e:
-            if isinstance(e, RuntimeError):
-                # Corrupted data, erase and restart
-                shutil.rmtree(TEMPDIR + "/MNIST")
-            logging.warning("Failed loading dataset: %s " % e)
-            tentatives += 1
-    if dataset is None:
-        logging.error("Could not download MNIST dataset")
-        exit(-1)
-    else:
-        logging.info("Dataset downloaded")
+    setup_cached_mnist()
    # Benchmark the different configurations, via multiple processes
    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
...
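With the retry loop moved into the shared helper, the benchmark keeps its "download once for all processes" pattern: the launcher process fetches (or reuses) the data before any workers start, and each spawned worker opens the already-present copy with download=False, so processes never race on the download. A minimal sketch of that pattern, assuming an illustrative worker body and world size:

import tempfile

import torch.multiprocessing as mp
from torchvision.datasets import MNIST

from benchmarks.datasets.mnist import setup_cached_mnist


def worker(rank: int, world_size: int) -> None:
    # Workers only read the copy the launcher already fetched; no download here.
    dataset = MNIST(root=tempfile.gettempdir(), download=False, transform=None)
    print(f"rank {rank}/{world_size}: {len(dataset)} samples")


if __name__ == "__main__":
    world_size = 2
    setup_cached_mnist()  # fetch exactly once, before spawning workers
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)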