Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
6a936e48
Unverified
Commit
6a936e48
authored
Feb 08, 2024
by
Philip Meier
Committed by
GitHub
Feb 08, 2024
Browse files
add gdown as optional requirement for dataset GDrive download (#8237)
parent
4c0f4414
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
30 additions
and
67 deletions
+30
-67
.github/workflows/tests-schedule.yml
.github/workflows/tests-schedule.yml
+1
-1
mypy.ini
mypy.ini
+4
-0
setup.py
setup.py
+0
-1
torchvision/datasets/caltech.py
torchvision/datasets/caltech.py
+4
-0
torchvision/datasets/celeba.py
torchvision/datasets/celeba.py
+4
-0
torchvision/datasets/pcam.py
torchvision/datasets/pcam.py
+4
-0
torchvision/datasets/utils.py
torchvision/datasets/utils.py
+9
-65
torchvision/datasets/widerface.py
torchvision/datasets/widerface.py
+4
-0
No files found.
.github/workflows/tests-schedule.yml
View file @
6a936e48
...
@@ -36,7 +36,7 @@ jobs:
...
@@ -36,7 +36,7 @@ jobs:
run
:
pip install --no-build-isolation --editable .
run
:
pip install --no-build-isolation --editable .
-
name
:
Install all optional dataset requirements
-
name
:
Install all optional dataset requirements
run
:
pip install scipy pycocotools lmdb
requests
run
:
pip install scipy pycocotools lmdb
gdown
-
name
:
Install tests requirements
-
name
:
Install tests requirements
run
:
pip install pytest
run
:
pip install pytest
...
...
mypy.ini
View file @
6a936e48
...
@@ -142,3 +142,7 @@ ignore_missing_imports = True
...
@@ -142,3 +142,7 @@ ignore_missing_imports = True
[mypy-h5py.*]
[mypy-h5py.*]
ignore_missing_imports
=
True
ignore_missing_imports
=
True
[mypy-gdown.*]
ignore_missing_imports
=
True
setup.py
View file @
6a936e48
...
@@ -59,7 +59,6 @@ if os.getenv("PYTORCH_VERSION"):
...
@@ -59,7 +59,6 @@ if os.getenv("PYTORCH_VERSION"):
requirements
=
[
requirements
=
[
"numpy"
,
"numpy"
,
"requests"
,
pytorch_dep
,
pytorch_dep
,
]
]
...
...
torchvision/datasets/caltech.py
View file @
6a936e48
...
@@ -30,6 +30,10 @@ class Caltech101(VisionDataset):
...
@@ -30,6 +30,10 @@ class Caltech101(VisionDataset):
download (bool, optional): If true, downloads the dataset from the internet and
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
downloaded again.
.. warning::
To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
"""
"""
def
__init__
(
def
__init__
(
...
...
torchvision/datasets/celeba.py
View file @
6a936e48
...
@@ -38,6 +38,10 @@ class CelebA(VisionDataset):
...
@@ -38,6 +38,10 @@ class CelebA(VisionDataset):
download (bool, optional): If true, downloads the dataset from the internet and
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
downloaded again.
.. warning::
To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
"""
"""
base_folder
=
"celeba"
base_folder
=
"celeba"
...
...
torchvision/datasets/pcam.py
View file @
6a936e48
...
@@ -25,6 +25,10 @@ class PCAM(VisionDataset):
...
@@ -25,6 +25,10 @@ class PCAM(VisionDataset):
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If
download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If
dataset is already downloaded, it is not downloaded again.
dataset is already downloaded, it is not downloaded again.
.. warning::
To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
"""
"""
_FILES
=
{
_FILES
=
{
...
...
torchvision/datasets/utils.py
View file @
6a936e48
import
bz2
import
bz2
import
contextlib
import
gzip
import
gzip
import
hashlib
import
hashlib
import
itertools
import
lzma
import
lzma
import
os
import
os
import
os.path
import
os.path
...
@@ -13,13 +11,11 @@ import tarfile
...
@@ -13,13 +11,11 @@ import tarfile
import
urllib
import
urllib
import
urllib.error
import
urllib.error
import
urllib.request
import
urllib.request
import
warnings
import
zipfile
import
zipfile
from
typing
import
Any
,
Callable
,
Dict
,
IO
,
Iterable
,
Iterator
,
List
,
Optional
,
Tuple
,
TypeVar
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
IO
,
Iterable
,
Iterator
,
List
,
Optional
,
Tuple
,
TypeVar
,
Union
from
urllib.parse
import
urlparse
from
urllib.parse
import
urlparse
import
numpy
as
np
import
numpy
as
np
import
requests
import
torch
import
torch
from
torch.utils.model_zoo
import
tqdm
from
torch.utils.model_zoo
import
tqdm
...
@@ -191,22 +187,6 @@ def list_files(root: Union[str, pathlib.Path], suffix: str, prefix: bool = False
...
@@ -191,22 +187,6 @@ def list_files(root: Union[str, pathlib.Path], suffix: str, prefix: bool = False
return
files
return
files
def
_extract_gdrive_api_response
(
response
,
chunk_size
:
int
=
32
*
1024
)
->
Tuple
[
bytes
,
Iterator
[
bytes
]]:
content
=
response
.
iter_content
(
chunk_size
)
first_chunk
=
None
# filter out keep-alive new chunks
while
not
first_chunk
:
first_chunk
=
next
(
content
)
content
=
itertools
.
chain
([
first_chunk
],
content
)
try
:
match
=
re
.
search
(
"<title>Google Drive - (?P<api_response>.+?)</title>"
,
first_chunk
.
decode
())
api_response
=
match
[
"api_response"
]
if
match
is
not
None
else
None
except
UnicodeDecodeError
:
api_response
=
None
return
api_response
,
content
def
download_file_from_google_drive
(
def
download_file_from_google_drive
(
file_id
:
str
,
file_id
:
str
,
root
:
Union
[
str
,
pathlib
.
Path
],
root
:
Union
[
str
,
pathlib
.
Path
],
...
@@ -221,7 +201,12 @@ def download_file_from_google_drive(
...
@@ -221,7 +201,12 @@ def download_file_from_google_drive(
filename (str, optional): Name to save the file under. If None, use the id of the file.
filename (str, optional): Name to save the file under. If None, use the id of the file.
md5 (str, optional): MD5 checksum of the download. If None, do not check
md5 (str, optional): MD5 checksum of the download. If None, do not check
"""
"""
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
try
:
import
gdown
except
ModuleNotFoundError
:
raise
RuntimeError
(
"To download files from GDrive, 'gdown' is required. You can install it with 'pip install gdown'."
)
root
=
os
.
path
.
expanduser
(
root
)
root
=
os
.
path
.
expanduser
(
root
)
if
not
filename
:
if
not
filename
:
...
@@ -234,51 +219,10 @@ def download_file_from_google_drive(
...
@@ -234,51 +219,10 @@ def download_file_from_google_drive(
print
(
f
"Using downloaded
{
'and verified '
if
md5
else
''
}
file:
{
fpath
}
"
)
print
(
f
"Using downloaded
{
'and verified '
if
md5
else
''
}
file:
{
fpath
}
"
)
return
return
url
=
"https://drive.google.com/uc"
gdown
.
download
(
id
=
file_id
,
output
=
fpath
,
quiet
=
False
,
user_agent
=
USER_AGENT
)
params
=
dict
(
id
=
file_id
,
export
=
"download"
)
with
requests
.
Session
()
as
session
:
response
=
session
.
get
(
url
,
params
=
params
,
stream
=
True
)
for
key
,
value
in
response
.
cookies
.
items
():
if
not
check_integrity
(
fpath
,
md5
):
if
key
.
startswith
(
"download_warning"
):
raise
RuntimeError
(
"File not found or corrupted."
)
token
=
value
break
else
:
api_response
,
content
=
_extract_gdrive_api_response
(
response
)
token
=
"t"
if
api_response
==
"Virus scan warning"
else
None
if
token
is
not
None
:
response
=
session
.
get
(
url
,
params
=
dict
(
params
,
confirm
=
token
),
stream
=
True
)
api_response
,
content
=
_extract_gdrive_api_response
(
response
)
if
api_response
==
"Quota exceeded"
:
raise
RuntimeError
(
f
"The daily quota of the file
{
filename
}
is exceeded and it "
f
"can't be downloaded. This is a limitation of Google Drive "
f
"and can only be overcome by trying again later."
)
_save_response_content
(
content
,
fpath
)
# In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
if
os
.
stat
(
fpath
).
st_size
<
10
*
1024
:
with
contextlib
.
suppress
(
UnicodeDecodeError
),
open
(
fpath
)
as
fh
:
text
=
fh
.
read
()
# Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
if
re
.
search
(
r
"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)"
,
text
):
warnings
.
warn
(
f
"We detected some HTML elements in the downloaded file. "
f
"This most likely means that the download triggered an unhandled API response by GDrive. "
f
"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
f
"the response:
\n\n
{
text
}
"
)
if
md5
and
not
check_md5
(
fpath
,
md5
):
raise
RuntimeError
(
f
"The MD5 checksum of the download file
{
fpath
}
does not match the one on record."
f
"Please delete the file and try again. "
f
"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
)
def
_extract_tar
(
def
_extract_tar
(
...
...
torchvision/datasets/widerface.py
View file @
6a936e48
...
@@ -34,6 +34,10 @@ class WIDERFace(VisionDataset):
...
@@ -34,6 +34,10 @@ class WIDERFace(VisionDataset):
puts it in root directory. If dataset is already downloaded, it is not
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
downloaded again.
.. warning::
To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
"""
"""
BASE_FOLDER
=
"widerface"
BASE_FOLDER
=
"widerface"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment