Unverified Commit 0491715f authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[fix] Cache MNIST fetches, use alternative URLs (#465)

parent 7a3199b1
......@@ -335,7 +335,6 @@ jobs:
- restore_cache:
keys:
- cache-key-cpu-py38-171-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_dep_171
- save_cache:
......@@ -401,7 +400,7 @@ jobs:
test_list_file:
type: string
default: "/dev/non_exist"
<<: *gpu
working_directory: ~/fairscale
......@@ -537,6 +536,11 @@ jobs:
keys:
- cache-key-benchmarks-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
# Cache the MNIST directory that contains benchmark data
- restore_cache:
keys:
- cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
- <<: *install_dep_171
- save_cache:
......@@ -556,6 +560,11 @@ jobs:
- <<: *run_oss_gloo
- save_cache:
paths:
- /tmp/MNIST
key: cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
benchmarks_2:
......@@ -581,6 +590,12 @@ jobs:
keys:
- cache-key-benchmarks-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
# Cache the MNIST directory that contains benchmark data
- restore_cache:
keys:
- cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
- <<: *install_dep_171
- save_cache:
......@@ -592,6 +607,11 @@ jobs:
- <<: *run_oss_benchmark
- save_cache:
paths:
- /tmp/MNIST
key: cache-key-benchmark-MNIST-{{ checksum "benchmarks/datasets/mnist.py"}}
workflows:
version: 2
......
import logging
from pathlib import Path
import shutil
import tempfile
from torchvision.datasets import MNIST
# Root directory passed to torchvision's MNIST; the dataset folder is TEMPDIR/MNIST.
TEMPDIR = tempfile.gettempdir()
def setup_cached_mnist():
    """Make sure the MNIST dataset is available under ``TEMPDIR``.

    torchvision's ``MNIST.resources`` is replaced with a GitHub mirror to work
    around a possible blacklist of the default host, then the dataset is
    requested with ``download=True``. Up to five attempts are made; after a
    failed attempt the ``MNIST`` folder under ``TEMPDIR`` is erased so that
    the next attempt starts from a clean download.

    Raises:
        SystemExit: with status ``-1`` if every attempt failed.
    """
    max_attempts = 5
    mnist_root = Path(TEMPDIR) / "MNIST"

    # Monkey patch the resource URLs to work around a possible blacklist.
    # The patch is loop-invariant, so it is applied once up front.
    MNIST.resources = [
        (
            "https://github.com/blefaudeux/mnist_dataset/raw/main/train-images-idx3-ubyte.gz",
            "f68b3c2dcbeaaa9fbdd348bbdeb94873",
        ),
        (
            "https://github.com/blefaudeux/mnist_dataset/raw/main/train-labels-idx1-ubyte.gz",
            "d53e105ee54ea40749a09fcbcd1e9432",
        ),
        (
            "https://github.com/blefaudeux/mnist_dataset/raw/main/t10k-images-idx3-ubyte.gz",
            "9fb629c4189551a2d022fa330f9573f3",
        ),
        (
            "https://github.com/blefaudeux/mnist_dataset/raw/main/t10k-labels-idx1-ubyte.gz",
            "ec29112dd5afa0611ce80d1b7f02629c",
        ),
    ]

    for _ in range(max_attempts):
        try:
            # This will automatically skip the download if the dataset is
            # already there, and check the checksum
            MNIST(transform=None, download=True, root=TEMPDIR)
        except (RuntimeError, EOFError) as e:
            # RuntimeError: corrupted data reported by torchvision.
            # EOFError: a truncated archive hit while decompressing.
            logging.warning(e)
            # Corrupted data, erase and restart
            try:
                shutil.rmtree(str(mnist_root))
            except FileNotFoundError:
                # Nothing was written before the failure; there is nothing to
                # clean up and the retry must not be aborted because of it.
                pass
        else:
            logging.info("Dataset downloaded")
            return

    logging.error("Could not download MNIST dataset")
    # Same effect as the builtin exit(-1), without depending on the `site` module.
    raise SystemExit(-1)
......@@ -5,7 +5,6 @@ import argparse
from enum import Enum
import importlib
import logging
import shutil
import tempfile
import time
from typing import Any, List, Optional, cast
......@@ -24,6 +23,7 @@ from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Resize, ToTensor
from benchmarks.datasets.mnist import setup_cached_mnist
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
......@@ -302,23 +302,7 @@ if __name__ == "__main__":
BACKEND = "nccl" if (not args.gloo or not torch.cuda.is_available()) and not args.cpu else "gloo"
# Download dataset once for all processes
dataset, tentatives = None, 0
while dataset is None and tentatives < 5:
try:
dataset = MNIST(transform=None, download=True, root=TEMPDIR)
except (RuntimeError, EOFError) as e:
if isinstance(e, RuntimeError):
# Corrupted data, erase and restart
shutil.rmtree(TEMPDIR + "/MNIST")
logging.warning("Failed loading dataset: %s " % e)
tentatives += 1
if dataset is None:
logging.error("Could not download MNIST dataset")
exit(-1)
else:
logging.info("Dataset downloaded")
setup_cached_mnist()
# Benchmark the different configurations, via multiple processes
if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment