Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
24c0a147
"docs/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "688448db7547be90203440cfd105703d8a853f39"
Unverified
Commit
24c0a147
authored
Mar 09, 2022
by
Philip Meier
Committed by
GitHub
Mar 09, 2022
Browse files
use upstream torchdata datapipes in prototype datasets (#5570)
parent
3b0b6c01
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
83 deletions
+13
-83
torchvision/prototype/datasets/_builtin/imagenet.py
torchvision/prototype/datasets/_builtin/imagenet.py
+10
-3
torchvision/prototype/datasets/_builtin/mnist.py
torchvision/prototype/datasets/_builtin/mnist.py
+3
-19
torchvision/prototype/datasets/utils/_internal.py
torchvision/prototype/datasets/utils/_internal.py
+0
-61
No files found.
torchvision/prototype/datasets/_builtin/imagenet.py
View file @
24c0a147
...
@@ -3,8 +3,16 @@ import pathlib
...
@@ -3,8 +3,16 @@ import pathlib
import
re
import
re
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
BinaryIO
,
Match
,
cast
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
BinaryIO
,
Match
,
cast
from
torchdata.datapipes.iter
import
IterDataPipe
,
LineReader
,
IterKeyZipper
,
Mapper
,
Filter
,
Demultiplexer
from
torchdata.datapipes.iter
import
(
from
torchdata.datapipes.iter
import
TarArchiveReader
IterDataPipe
,
LineReader
,
IterKeyZipper
,
Mapper
,
Filter
,
Demultiplexer
,
TarArchiveReader
,
Enumerator
,
)
from
torchvision.prototype.datasets.utils
import
(
from
torchvision.prototype.datasets.utils
import
(
Dataset
,
Dataset
,
DatasetConfig
,
DatasetConfig
,
...
@@ -16,7 +24,6 @@ from torchvision.prototype.datasets.utils._internal import (
...
@@ -16,7 +24,6 @@ from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE
,
INFINITE_BUFFER_SIZE
,
BUILTIN_DIR
,
BUILTIN_DIR
,
path_comparator
,
path_comparator
,
Enumerator
,
getitem
,
getitem
,
read_mat
,
read_mat
,
hint_sharding
,
hint_sharding
,
...
...
torchvision/prototype/datasets/_builtin/mnist.py
View file @
24c0a147
...
@@ -6,25 +6,9 @@ import string
...
@@ -6,25 +6,9 @@ import string
from
typing
import
Any
,
Dict
,
Iterator
,
List
,
Optional
,
Tuple
,
cast
,
BinaryIO
,
Union
,
Sequence
from
typing
import
Any
,
Dict
,
Iterator
,
List
,
Optional
,
Tuple
,
cast
,
BinaryIO
,
Union
,
Sequence
import
torch
import
torch
from
torchdata.datapipes.iter
import
(
from
torchdata.datapipes.iter
import
IterDataPipe
,
Demultiplexer
,
Mapper
,
Zipper
,
Decompressor
IterDataPipe
,
from
torchvision.prototype.datasets.utils
import
Dataset
,
DatasetConfig
,
DatasetInfo
,
HttpResource
,
OnlineResource
Demultiplexer
,
from
torchvision.prototype.datasets.utils._internal
import
INFINITE_BUFFER_SIZE
,
hint_sharding
,
hint_shuffling
Mapper
,
Zipper
,
)
from
torchvision.prototype.datasets.utils
import
(
Dataset
,
DatasetConfig
,
DatasetInfo
,
HttpResource
,
OnlineResource
,
)
from
torchvision.prototype.datasets.utils._internal
import
(
Decompressor
,
INFINITE_BUFFER_SIZE
,
hint_sharding
,
hint_shuffling
,
)
from
torchvision.prototype.features
import
Image
,
Label
from
torchvision.prototype.features
import
Image
,
Label
from
torchvision.prototype.utils._internal
import
fromfile
from
torchvision.prototype.utils._internal
import
fromfile
...
...
torchvision/prototype/datasets/utils/_internal.py
View file @
24c0a147
import
enum
import
functools
import
functools
import
gzip
import
lzma
import
os
import
os.path
import
pathlib
import
pathlib
import
pickle
import
pickle
from
typing
import
BinaryIO
from
typing
import
BinaryIO
...
@@ -16,7 +11,6 @@ from typing import (
...
@@ -16,7 +11,6 @@ from typing import (
TypeVar
,
TypeVar
,
Iterator
,
Iterator
,
Dict
,
Dict
,
Optional
,
IO
,
IO
,
Sized
,
Sized
,
)
)
...
@@ -35,11 +29,9 @@ __all__ = [
...
@@ -35,11 +29,9 @@ __all__ = [
"BUILTIN_DIR"
,
"BUILTIN_DIR"
,
"read_mat"
,
"read_mat"
,
"MappingIterator"
,
"MappingIterator"
,
"Enumerator"
,
"getitem"
,
"getitem"
,
"path_accessor"
,
"path_accessor"
,
"path_comparator"
,
"path_comparator"
,
"Decompressor"
,
"read_flo"
,
"read_flo"
,
"hint_sharding"
,
"hint_sharding"
,
]
]
...
@@ -75,15 +67,6 @@ class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
...
@@ -75,15 +67,6 @@ class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
yield
from
iter
(
mapping
.
values
()
if
self
.
drop_key
else
mapping
.
items
())
yield
from
iter
(
mapping
.
values
()
if
self
.
drop_key
else
mapping
.
items
())
class
Enumerator
(
IterDataPipe
[
Tuple
[
int
,
D
]]):
def
__init__
(
self
,
datapipe
:
IterDataPipe
[
D
],
start
:
int
=
0
)
->
None
:
self
.
datapipe
=
datapipe
self
.
start
=
start
def
__iter__
(
self
)
->
Iterator
[
Tuple
[
int
,
D
]]:
yield
from
enumerate
(
self
.
datapipe
,
self
.
start
)
def
_getitem_closure
(
obj
:
Any
,
*
,
items
:
Sequence
[
Any
])
->
Any
:
def
_getitem_closure
(
obj
:
Any
,
*
,
items
:
Sequence
[
Any
])
->
Any
:
for
item
in
items
:
for
item
in
items
:
obj
=
obj
[
item
]
obj
=
obj
[
item
]
...
@@ -123,50 +106,6 @@ def path_comparator(getter: Union[str, Callable[[pathlib.Path], D]], value: D) -
...
@@ -123,50 +106,6 @@ def path_comparator(getter: Union[str, Callable[[pathlib.Path], D]], value: D) -
return
functools
.
partial
(
_path_comparator_closure
,
accessor
=
path_accessor
(
getter
),
value
=
value
)
return
functools
.
partial
(
_path_comparator_closure
,
accessor
=
path_accessor
(
getter
),
value
=
value
)
class
CompressionType
(
enum
.
Enum
):
GZIP
=
"gzip"
LZMA
=
"lzma"
class
Decompressor
(
IterDataPipe
[
Tuple
[
str
,
BinaryIO
]]):
types
=
CompressionType
_DECOMPRESSORS
:
Dict
[
CompressionType
,
Callable
[[
BinaryIO
],
BinaryIO
]]
=
{
types
.
GZIP
:
lambda
file
:
cast
(
BinaryIO
,
gzip
.
GzipFile
(
fileobj
=
file
)),
types
.
LZMA
:
lambda
file
:
cast
(
BinaryIO
,
lzma
.
LZMAFile
(
file
)),
}
def
__init__
(
self
,
datapipe
:
IterDataPipe
[
Tuple
[
str
,
BinaryIO
]],
*
,
type
:
Optional
[
Union
[
str
,
CompressionType
]]
=
None
,
)
->
None
:
self
.
datapipe
=
datapipe
if
isinstance
(
type
,
str
):
type
=
self
.
types
(
type
.
upper
())
self
.
type
=
type
def
_detect_compression_type
(
self
,
path
:
str
)
->
CompressionType
:
if
self
.
type
:
return
self
.
type
# TODO: this needs to be more elaborate
ext
=
os
.
path
.
splitext
(
path
)[
1
]
if
ext
==
".gz"
:
return
self
.
types
.
GZIP
elif
ext
==
".xz"
:
return
self
.
types
.
LZMA
else
:
raise
RuntimeError
(
"FIXME"
)
def
__iter__
(
self
)
->
Iterator
[
Tuple
[
str
,
BinaryIO
]]:
for
path
,
file
in
self
.
datapipe
:
type
=
self
.
_detect_compression_type
(
path
)
decompressor
=
self
.
_DECOMPRESSORS
[
type
]
yield
path
,
decompressor
(
file
)
class
PicklerDataPipe
(
IterDataPipe
):
class
PicklerDataPipe
(
IterDataPipe
):
def
__init__
(
self
,
source_datapipe
:
IterDataPipe
[
Tuple
[
str
,
IO
[
bytes
]]])
->
None
:
def
__init__
(
self
,
source_datapipe
:
IterDataPipe
[
Tuple
[
str
,
IO
[
bytes
]]])
->
None
:
self
.
source_datapipe
=
source_datapipe
self
.
source_datapipe
=
source_datapipe
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment