Unverified Commit 17393cb7 authored by Sofya Lipnitskaya's avatar Sofya Lipnitskaya Committed by GitHub
Browse files

Check if the file is located on Google Drive (#3245)



* Check if the file is located on Google Drive

download_file() detects if the provided URL contains a link to Google
Drive and passes the request through the download_file_from_google_drive()
if it is the case.

* remove lazy re import

* add guard clause

* use urlparse

* add tests

* Fixing lint.

* make id matching more safe

* Fixing type check.

* Remove AnyStr.
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
Co-authored-by: default avatarVasilis Vryniotis <vvryniotis@fb.com>
parent ad3bbef9
...@@ -3,6 +3,7 @@ import sys ...@@ -3,6 +3,7 @@ import sys
import tempfile import tempfile
import torchvision.datasets.utils as utils import torchvision.datasets.utils as utils
import unittest import unittest
import unittest.mock
import zipfile import zipfile
import tarfile import tarfile
import gzip import gzip
...@@ -48,6 +49,18 @@ class Tester(unittest.TestCase): ...@@ -48,6 +49,18 @@ class Tester(unittest.TestCase):
with self.assertRaises(RecursionError): with self.assertRaises(RecursionError):
utils._get_redirect_url(url, max_hops=0) utils._get_redirect_url(url, max_hops=0)
def test_get_google_drive_file_id(self):
url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
expected = "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45"
actual = utils._get_google_drive_file_id(url)
assert actual == expected
def test_get_google_drive_file_id_invalid_url(self):
url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
assert utils._get_google_drive_file_id(url) is None
def test_download_url(self): def test_download_url(self):
with get_tmp_dir() as temp_dir: with get_tmp_dir() as temp_dir:
url = "http://github.com/pytorch/vision/archive/master.zip" url = "http://github.com/pytorch/vision/archive/master.zip"
...@@ -76,6 +89,19 @@ class Tester(unittest.TestCase): ...@@ -76,6 +89,19 @@ class Tester(unittest.TestCase):
with self.assertRaises(URLError): with self.assertRaises(URLError):
utils.download_url(url, temp_dir) utils.download_url(url, temp_dir)
@unittest.mock.patch("torchvision.datasets.utils.download_file_from_google_drive")
def test_download_url_dispatch_download_from_google_drive(self, mock):
url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
id = "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45"
filename = "filename"
md5 = "md5"
with get_tmp_dir() as root:
utils.download_url(url, root, filename, md5)
mock.assert_called_once_with(id, root, filename, md5)
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_extract_zip(self): def test_extract_zip(self):
with get_tmp_dir() as temp_dir: with get_tmp_dir() as temp_dir:
......
...@@ -2,8 +2,10 @@ import os ...@@ -2,8 +2,10 @@ import os
import os.path import os.path
import hashlib import hashlib
import gzip import gzip
import re
import tarfile import tarfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar from typing import Any, Callable, List, Iterable, Optional, TypeVar
from urllib.parse import urlparse
import zipfile import zipfile
import torch import torch
...@@ -56,6 +58,19 @@ def _get_redirect_url(url: str, max_hops: int = 10) -> str: ...@@ -56,6 +58,19 @@ def _get_redirect_url(url: str, max_hops: int = 10) -> str:
raise RecursionError(f"Too many redirects: {max_hops + 1})") raise RecursionError(f"Too many redirects: {max_hops + 1})")
def _get_google_drive_file_id(url: str) -> Optional[str]:
parts = urlparse(url)
if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
return None
match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
if match is None:
return None
return match.group("id")
def download_url( def download_url(
url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3 url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
) -> None: ) -> None:
...@@ -85,6 +100,11 @@ def download_url( ...@@ -85,6 +100,11 @@ def download_url(
# expand redirect chain if needed # expand redirect chain if needed
url = _get_redirect_url(url, max_hops=max_redirect_hops) url = _get_redirect_url(url, max_hops=max_redirect_hops)
# check if file is located on Google Drive
file_id = _get_google_drive_file_id(url)
if file_id is not None:
return download_file_from_google_drive(file_id, root, filename, md5)
# download the file # download the file
try: try:
print('Downloading ' + url + ' to ' + fpath) print('Downloading ' + url + ' to ' + fpath)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment