Unverified Commit d2343076 authored by ahmadsharif1's avatar ahmadsharif1 Committed by GitHub
Browse files

Add pathlib.Path support for download utils (#8196)


Co-authored-by: default avatarAhmad Sharif <ahmads@fb.com>
Co-authored-by: default avatarBrizar <1500595+bmmtstb@users.noreply.github.com>
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent 2afb7faf
......@@ -6,6 +6,7 @@ cleanly ignored in FB internal test infra.
"""
import os
import pathlib
from urllib.error import URLError
import pytest
......@@ -13,7 +14,10 @@ import torchvision.datasets.utils as utils
class TestDatasetUtils:
def test_download_url(self, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_download_url(self, tmpdir, use_pathlib):
if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
url = "http://github.com/pytorch/vision/archive/master.zip"
try:
utils.download_url(url, tmpdir)
......@@ -21,7 +25,10 @@ class TestDatasetUtils:
except URLError:
pytest.skip(f"could not download test file '{url}'")
def test_download_url_retry_http(self, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_download_url_retry_http(self, tmpdir, use_pathlib):
if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
url = "https://github.com/pytorch/vision/archive/master.zip"
try:
utils.download_url(url, tmpdir)
......@@ -29,12 +36,18 @@ class TestDatasetUtils:
except URLError:
pytest.skip(f"could not download test file '{url}'")
def test_download_url_dont_exist(self, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_download_url_dont_exist(self, tmpdir, use_pathlib):
if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
url = "http://github.com/pytorch/vision/archive/this_doesnt_exist.zip"
with pytest.raises(URLError):
utils.download_url(url, tmpdir)
def test_download_url_dispatch_download_from_google_drive(self, mocker, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_download_url_dispatch_download_from_google_drive(self, mocker, tmpdir, use_pathlib):
if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
url = "https://drive.google.com/file/d/1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV/view"
id = "1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV"
......@@ -44,7 +57,7 @@ class TestDatasetUtils:
mocked = mocker.patch("torchvision.datasets.utils.download_file_from_google_drive")
utils.download_url(url, tmpdir, filename, md5)
mocked.assert_called_once_with(id, tmpdir, filename, md5)
mocked.assert_called_once_with(id, os.path.expanduser(tmpdir), filename, md5)
if __name__ == "__main__":
......
......@@ -15,7 +15,7 @@ import urllib.error
import urllib.request
import warnings
import zipfile
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union
from urllib.parse import urlparse
import numpy as np
......@@ -104,7 +104,11 @@ def _get_google_drive_file_id(url: str) -> Optional[str]:
def download_url(
url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
url: str,
root: Union[str, pathlib.Path],
filename: Optional[str] = None,
md5: Optional[str] = None,
max_redirect_hops: int = 3,
) -> None:
"""Download a file from a url and place it in root.
......@@ -118,7 +122,7 @@ def download_url(
root = os.path.expanduser(root)
if not filename:
filename = os.path.basename(url)
fpath = os.path.join(root, filename)
fpath = os.fspath(os.path.join(root, filename))
os.makedirs(root, exist_ok=True)
......@@ -203,7 +207,9 @@ def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple
return api_response, content
def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
def download_file_from_google_drive(
file_id: str, root: Union[str, pathlib.Path], filename: Optional[str] = None, md5: Optional[str] = None
):
"""Download a Google Drive file from and place it in root.
Args:
......@@ -217,7 +223,7 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
root = os.path.expanduser(root)
if not filename:
filename = file_id
fpath = os.path.join(root, filename)
fpath = os.fspath(os.path.join(root, filename))
os.makedirs(root, exist_ok=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment