Unverified Commit 17393cb7 authored by Sofya Lipnitskaya's avatar Sofya Lipnitskaya Committed by GitHub
Browse files

Check if the file is located on Google Drive (#3245)



* Check if the file is located on Google Drive

download_file() detects if the provided URL contains a link to Google
Drive and passes the request through the download_file_from_google_drive()
if it is the case.

* remove lazy re import

* add guard clause

* use urlparse

* add tests

* Fixing lint.

* make id matching more safe

* Fixing type check.

* Remove AnyStr.
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
Co-authored-by: default avatarVasilis Vryniotis <vvryniotis@fb.com>
parent ad3bbef9
......@@ -3,6 +3,7 @@ import sys
import tempfile
import torchvision.datasets.utils as utils
import unittest
import unittest.mock
import zipfile
import tarfile
import gzip
......@@ -48,6 +49,18 @@ class Tester(unittest.TestCase):
with self.assertRaises(RecursionError):
utils._get_redirect_url(url, max_hops=0)
def test_get_google_drive_file_id(self):
url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
expected = "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45"
actual = utils._get_google_drive_file_id(url)
assert actual == expected
def test_get_google_drive_file_id_invalid_url(self):
url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
assert utils._get_google_drive_file_id(url) is None
def test_download_url(self):
with get_tmp_dir() as temp_dir:
url = "http://github.com/pytorch/vision/archive/master.zip"
......@@ -76,6 +89,19 @@ class Tester(unittest.TestCase):
with self.assertRaises(URLError):
utils.download_url(url, temp_dir)
@unittest.mock.patch("torchvision.datasets.utils.download_file_from_google_drive")
def test_download_url_dispatch_download_from_google_drive(self, mock):
url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
id = "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45"
filename = "filename"
md5 = "md5"
with get_tmp_dir() as root:
utils.download_url(url, root, filename, md5)
mock.assert_called_once_with(id, root, filename, md5)
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_extract_zip(self):
with get_tmp_dir() as temp_dir:
......
......@@ -2,8 +2,10 @@ import os
import os.path
import hashlib
import gzip
import re
import tarfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar
from urllib.parse import urlparse
import zipfile
import torch
......@@ -56,6 +58,19 @@ def _get_redirect_url(url: str, max_hops: int = 10) -> str:
raise RecursionError(f"Too many redirects: {max_hops + 1})")
def _get_google_drive_file_id(url: str) -> Optional[str]:
parts = urlparse(url)
if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
return None
match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
if match is None:
return None
return match.group("id")
def download_url(
url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
) -> None:
......@@ -85,6 +100,11 @@ def download_url(
# expand redirect chain if needed
url = _get_redirect_url(url, max_hops=max_redirect_hops)
# check if file is located on Google Drive
file_id = _get_google_drive_file_id(url)
if file_id is not None:
return download_file_from_google_drive(file_id, root, filename, md5)
# download the file
try:
print('Downloading ' + url + ' to ' + fpath)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment