Unverified Commit d2343076 authored by ahmadsharif1's avatar ahmadsharif1 Committed by GitHub
Browse files

Add pathlib.Path support for download utils (#8196)


Co-authored-by: default avatarAhmad Sharif <ahmads@fb.com>
Co-authored-by: default avatarBrizar <1500595+bmmtstb@users.noreply.github.com>
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent 2afb7faf
...@@ -6,6 +6,7 @@ cleanly ignored in FB internal test infra. ...@@ -6,6 +6,7 @@ cleanly ignored in FB internal test infra.
""" """
import os import os
import pathlib
from urllib.error import URLError from urllib.error import URLError
import pytest import pytest
...@@ -13,7 +14,10 @@ import torchvision.datasets.utils as utils ...@@ -13,7 +14,10 @@ import torchvision.datasets.utils as utils
class TestDatasetUtils: class TestDatasetUtils:
def test_download_url(self, tmpdir): @pytest.mark.parametrize("use_pathlib", (True, False))
def test_download_url(self, tmpdir, use_pathlib):
if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
url = "http://github.com/pytorch/vision/archive/master.zip" url = "http://github.com/pytorch/vision/archive/master.zip"
try: try:
utils.download_url(url, tmpdir) utils.download_url(url, tmpdir)
...@@ -21,7 +25,10 @@ class TestDatasetUtils: ...@@ -21,7 +25,10 @@ class TestDatasetUtils:
except URLError: except URLError:
pytest.skip(f"could not download test file '{url}'") pytest.skip(f"could not download test file '{url}'")
def test_download_url_retry_http(self, tmpdir): @pytest.mark.parametrize("use_pathlib", (True, False))
def test_download_url_retry_http(self, tmpdir, use_pathlib):
if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
url = "https://github.com/pytorch/vision/archive/master.zip" url = "https://github.com/pytorch/vision/archive/master.zip"
try: try:
utils.download_url(url, tmpdir) utils.download_url(url, tmpdir)
...@@ -29,12 +36,18 @@ class TestDatasetUtils: ...@@ -29,12 +36,18 @@ class TestDatasetUtils:
except URLError: except URLError:
pytest.skip(f"could not download test file '{url}'") pytest.skip(f"could not download test file '{url}'")
def test_download_url_dont_exist(self, tmpdir): @pytest.mark.parametrize("use_pathlib", (True, False))
def test_download_url_dont_exist(self, tmpdir, use_pathlib):
if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
url = "http://github.com/pytorch/vision/archive/this_doesnt_exist.zip" url = "http://github.com/pytorch/vision/archive/this_doesnt_exist.zip"
with pytest.raises(URLError): with pytest.raises(URLError):
utils.download_url(url, tmpdir) utils.download_url(url, tmpdir)
def test_download_url_dispatch_download_from_google_drive(self, mocker, tmpdir): @pytest.mark.parametrize("use_pathlib", (True, False))
def test_download_url_dispatch_download_from_google_drive(self, mocker, tmpdir, use_pathlib):
if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
url = "https://drive.google.com/file/d/1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV/view" url = "https://drive.google.com/file/d/1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV/view"
id = "1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV" id = "1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV"
...@@ -44,7 +57,7 @@ class TestDatasetUtils: ...@@ -44,7 +57,7 @@ class TestDatasetUtils:
mocked = mocker.patch("torchvision.datasets.utils.download_file_from_google_drive") mocked = mocker.patch("torchvision.datasets.utils.download_file_from_google_drive")
utils.download_url(url, tmpdir, filename, md5) utils.download_url(url, tmpdir, filename, md5)
mocked.assert_called_once_with(id, tmpdir, filename, md5) mocked.assert_called_once_with(id, os.path.expanduser(tmpdir), filename, md5)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -15,7 +15,7 @@ import urllib.error ...@@ -15,7 +15,7 @@ import urllib.error
import urllib.request import urllib.request
import warnings import warnings
import zipfile import zipfile
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import numpy as np import numpy as np
...@@ -104,7 +104,11 @@ def _get_google_drive_file_id(url: str) -> Optional[str]: ...@@ -104,7 +104,11 @@ def _get_google_drive_file_id(url: str) -> Optional[str]:
def download_url( def download_url(
url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3 url: str,
root: Union[str, pathlib.Path],
filename: Optional[str] = None,
md5: Optional[str] = None,
max_redirect_hops: int = 3,
) -> None: ) -> None:
"""Download a file from a url and place it in root. """Download a file from a url and place it in root.
...@@ -118,7 +122,7 @@ def download_url( ...@@ -118,7 +122,7 @@ def download_url(
root = os.path.expanduser(root) root = os.path.expanduser(root)
if not filename: if not filename:
filename = os.path.basename(url) filename = os.path.basename(url)
fpath = os.path.join(root, filename) fpath = os.fspath(os.path.join(root, filename))
os.makedirs(root, exist_ok=True) os.makedirs(root, exist_ok=True)
...@@ -203,7 +207,9 @@ def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple ...@@ -203,7 +207,9 @@ def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple
return api_response, content return api_response, content
def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None): def download_file_from_google_drive(
file_id: str, root: Union[str, pathlib.Path], filename: Optional[str] = None, md5: Optional[str] = None
):
"""Download a Google Drive file from and place it in root. """Download a Google Drive file from and place it in root.
Args: Args:
...@@ -217,7 +223,7 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[ ...@@ -217,7 +223,7 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
root = os.path.expanduser(root) root = os.path.expanduser(root)
if not filename: if not filename:
filename = file_id filename = file_id
fpath = os.path.join(root, filename) fpath = os.fspath(os.path.join(root, filename))
os.makedirs(root, exist_ok=True) os.makedirs(root, exist_ok=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment