You need to sign in or sign up before continuing.
Unverified Commit 17393cb7 authored by Sofya Lipnitskaya's avatar Sofya Lipnitskaya Committed by GitHub
Browse files

Check if the file is located on Google Drive (#3245)



* Check if the file is located on Google Drive

download_file() detects if the provided URL contains a link to Google
Drive and passes the request through the download_file_from_google_drive()
if it is the case.

* remove lazy re import

* add guard clause

* use urlparse

* add tests

* Fixing lint.

* make id matching more safe

* Fixing type check.

* Remove AnyStr.
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
Co-authored-by: default avatarVasilis Vryniotis <vvryniotis@fb.com>
parent ad3bbef9
......@@ -3,6 +3,7 @@ import sys
import tempfile
import torchvision.datasets.utils as utils
import unittest
import unittest.mock
import zipfile
import tarfile
import gzip
......@@ -48,6 +49,18 @@ class Tester(unittest.TestCase):
with self.assertRaises(RecursionError):
utils._get_redirect_url(url, max_hops=0)
def test_get_google_drive_file_id(self):
url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
expected = "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45"
actual = utils._get_google_drive_file_id(url)
assert actual == expected
def test_get_google_drive_file_id_invalid_url(self):
url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
assert utils._get_google_drive_file_id(url) is None
def test_download_url(self):
with get_tmp_dir() as temp_dir:
url = "http://github.com/pytorch/vision/archive/master.zip"
......@@ -76,6 +89,19 @@ class Tester(unittest.TestCase):
with self.assertRaises(URLError):
utils.download_url(url, temp_dir)
@unittest.mock.patch("torchvision.datasets.utils.download_file_from_google_drive")
def test_download_url_dispatch_download_from_google_drive(self, mock):
url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
id = "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45"
filename = "filename"
md5 = "md5"
with get_tmp_dir() as root:
utils.download_url(url, root, filename, md5)
mock.assert_called_once_with(id, root, filename, md5)
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_extract_zip(self):
with get_tmp_dir() as temp_dir:
......
......@@ -2,8 +2,10 @@ import os
import os.path
import hashlib
import gzip
import re
import tarfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar
from urllib.parse import urlparse
import zipfile
import torch
......@@ -56,6 +58,19 @@ def _get_redirect_url(url: str, max_hops: int = 10) -> str:
raise RecursionError(f"Too many redirects: {max_hops + 1})")
def _get_google_drive_file_id(url: str) -> Optional[str]:
parts = urlparse(url)
if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
return None
match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
if match is None:
return None
return match.group("id")
def download_url(
url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
) -> None:
......@@ -85,6 +100,11 @@ def download_url(
# expand redirect chain if needed
url = _get_redirect_url(url, max_hops=max_redirect_hops)
# check if file is located on Google Drive
file_id = _get_google_drive_file_id(url)
if file_id is not None:
return download_file_from_google_drive(file_id, root, filename, md5)
# download the file
try:
print('Downloading ' + url + ' to ' + fpath)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment