Commit 0fc002df authored by huchen

init the dlexamples new

parent 0e04b692
body {
font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
}
/* Default header fonts are ugly */
h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption {
font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
}
/* Use white for docs background */
.wy-side-nav-search {
background-color: #fff;
}
.wy-nav-content-wrap, .wy-menu li.current > a {
background-color: #fff;
}
@media screen and (min-width: 1400px) {
.wy-nav-content-wrap {
background-color: rgba(0, 0, 0, 0.0470588);
}
.wy-nav-content {
background-color: #fff;
}
}
/* Fixes for mobile */
.wy-nav-top {
background-color: #fff;
background-image: url('../img/pytorch-logo-dark.svg');
background-repeat: no-repeat;
background-position: center;
padding: 0;
margin: 0.4045em 0.809em;
color: #333;
}
.wy-nav-top > a {
display: none;
}
@media screen and (max-width: 768px) {
.wy-side-nav-search>a img.logo {
height: 60px;
}
}
/* This is needed to ensure that logo above search scales properly */
.wy-side-nav-search a {
display: block;
}
/* This ensures that multiple constructors will remain in separate lines. */
.rst-content dl:not(.docutils) dt {
display: table;
}
/* Use our red for literals (it's very similar to the original color) */
.rst-content tt.literal, .rst-content code.literal {
color: #F05732;
}
.rst-content tt.xref, a .rst-content tt,
.rst-content code.xref, a .rst-content code {
color: #404040;
}
/* Change link colors (except for the menu) */
a {
color: #F05732;
}
a:hover {
color: #F05732;
}
a:visited {
color: #D44D2C;
}
.wy-menu a {
color: #b3b3b3;
}
.wy-menu a:hover {
color: #b3b3b3;
}
/* Default footer text is quite big */
footer {
font-size: 80%;
}
footer .rst-footer-buttons {
font-size: 125%; /* revert footer settings - 1/80% = 125% */
}
footer p {
font-size: 100%;
}
/* For hidden headers that appear in TOC tree */
/* see http://stackoverflow.com/a/32363545/3343043 */
.rst-content .hidden-section {
display: none;
}
nav .hidden-section {
display: inherit;
}
.wy-side-nav-search>div.version {
color: #000;
}
<?xml version="1.0" encoding="utf-8"?>
<!-- Generator: Adobe Illustrator 21.0.0, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
viewBox="0 0 199.7 40.2" style="enable-background:new 0 0 199.7 40.2;" xml:space="preserve">
<style type="text/css">
.st0{fill:#F05732;}
.st1{fill:#9E529F;}
.st2{fill:#333333;}
</style>
<path class="st0" d="M102.7,12.2c-1.3-1-1.8,3.9-4.4,3.9c-3,0-4-13-6.3-13c-0.7,0-0.8-0.4-7.9,21.3c-2.9,9,4.4,15.8,11.8,15.8
c4.6,0,12.3-3,12.3-12.6C108.2,20.5,104.7,13.7,102.7,12.2z M95.8,35.3c-3.7,0-6.7-3.1-6.7-7c0-3.9,3-7,6.7-7s6.7,3.1,6.7,7
C102.5,32.1,99.5,35.3,95.8,35.3z"/>
<path class="st1" d="M99.8,0c-0.5,0-1.8,2.5-1.8,3.6c0,1.5,1,2,1.8,2c0.8,0,1.8-0.5,1.8-2C101.5,2.5,100.2,0,99.8,0z"/>
<path class="st2" d="M0,39.5V14.9h11.5c5.3,0,8.3,3.6,8.3,7.9c0,4.3-3,7.9-8.3,7.9H5.2v8.8H0z M14.4,22.8c0-2.1-1.6-3.3-3.7-3.3H5.2
v6.6h5.5C12.8,26.1,14.4,24.8,14.4,22.8z"/>
<path class="st2" d="M35.2,39.5V29.4l-9.4-14.5h6l6.1,9.8l6.1-9.8h5.9l-9.4,14.5v10.1H35.2z"/>
<path class="st2" d="M63.3,39.5v-20h-7.2v-4.6h19.6v4.6h-7.2v20H63.3z"/>
<path class="st2" d="M131.4,39.5l-4.8-8.7h-3.8v8.7h-5.2V14.9H129c5.1,0,8.3,3.4,8.3,7.9c0,4.3-2.8,6.7-5.4,7.3l5.6,9.4H131.4z
M131.9,22.8c0-2-1.6-3.3-3.7-3.3h-5.5v6.6h5.5C130.3,26.1,131.9,24.9,131.9,22.8z"/>
<path class="st2" d="M145.6,27.2c0-7.6,5.7-12.7,13.1-12.7c5.4,0,8.5,2.9,10.3,6l-4.5,2.2c-1-2-3.2-3.6-5.8-3.6
c-4.5,0-7.7,3.4-7.7,8.1c0,4.6,3.2,8.1,7.7,8.1c2.5,0,4.7-1.6,5.8-3.6l4.5,2.2c-1.7,3.1-4.9,6-10.3,6
C151.3,39.9,145.6,34.7,145.6,27.2z"/>
<path class="st2" d="M194.5,39.5V29.1h-11.6v10.4h-5.2V14.9h5.2v9.7h11.6v-9.7h5.3v24.6H194.5z"/>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
xmlns:dc="http://purl.org/dc/elements/1.1/"
xmlns:cc="http://creativecommons.org/ns#"
xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
xmlns:svg="http://www.w3.org/2000/svg"
xmlns="http://www.w3.org/2000/svg"
height="40.200001"
width="40.200001"
xml:space="preserve"
viewBox="0 0 40.200002 40.2"
y="0px"
x="0px"
id="Layer_1"
version="1.1"><metadata
id="metadata4717"><rdf:RDF><cc:Work
rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
id="defs4715" /><style
id="style4694"
type="text/css">
.st0{fill:#F05732;}
.st1{fill:#9E529F;}
.st2{fill:#333333;}
</style><path
style="fill:#f05732"
id="path4696"
d="m 26.975479,12.199999 c -1.3,-1 -1.8,3.9 -4.4,3.9 -3,0 -4,-12.9999998 -6.3,-12.9999998 -0.7,0 -0.8,-0.4 -7.9000003,21.2999998 -2.9000001,9 4.4000003,15.8 11.8000003,15.8 4.6,0 12.3,-3 12.3,-12.6 0,-7.1 -3.5,-13.9 -5.5,-15.4 z m -6.9,23.1 c -3.7,0 -6.7,-3.1 -6.7,-7 0,-3.9 3,-7 6.7,-7 3.7,0 6.7,3.1 6.7,7 0,3.8 -3,7 -6.7,7 z"
class="st0" /><path
style="fill:#9e529f"
id="path4698"
d="m 24.075479,-7.6293945e-7 c -0.5,0 -1.8,2.49999996293945 -1.8,3.59999996293945 0,1.5 1,2 1.8,2 0.8,0 1.8,-0.5 1.8,-2 -0.1,-1.1 -1.4,-3.59999996293945 -1.8,-3.59999996293945 z"
class="st1" /></svg>
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# PyTorch documentation build configuration file, created by
# sphinx-quickstart on Fri Dec 23 13:31:47 2016.
#
# This file is execfile()d with the current directory set to its
# containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
import torch
import torchvision
import pytorch_sphinx_theme
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.coverage',
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinxcontrib.googleanalytics',
]
napoleon_use_ivar = True
googleanalytics_id = 'UA-90545585-1'
googleanalytics_enabled = True
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
# The master toctree document.
master_doc = 'index'
# General information about the project.
project = 'Torchvision'
copyright = '2017, Torch Contributors'
author = 'Torch Contributors'
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
# TODO: change to [:2] at v1.0
version = 'master (' + torchvision.__version__ + ')'
# The full version, including alpha/beta/rc tags.
# TODO: verify this works as expected
release = 'master'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = []
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
# -- Options for HTML output ----------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'pytorch_sphinx_theme'
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
html_theme_options = {
'collapse_navigation': False,
'display_version': True,
'logo_only': True,
'pytorch_project': 'docs',
}
html_logo = '_static/img/pytorch-logo-dark.svg'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# html_style_path = 'css/pytorch_theme.css'
# html_context = {
# 'css_files': [
# 'https://fonts.googleapis.com/css?family=Lato',
# '_static/css/pytorch_theme.css'
# ],
# }
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'PyTorchdoc'
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'pytorch.tex', 'torchvision Documentation',
'Torch Contributors', 'manual'),
]
# -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'torchvision', 'torchvision Documentation',
[author], 1)
]
# -- Options for Texinfo output -------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'torchvision', 'torchvision Documentation',
author, 'torchvision', 'One line description of project.',
'Miscellaneous'),
]
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
'python': ('https://docs.python.org/', None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
}
# -- A patch that prevents Sphinx from cross-referencing ivar tags -------
# See http://stackoverflow.com/a/41184353/3343043
from docutils import nodes
from sphinx.util.docfields import TypedField
from sphinx import addnodes
def patched_make_field(self, types, domain, items, **kw):
# `kw` catches `env=None` needed for newer sphinx while maintaining
# backwards compatibility when passed along further down!
# type: (list, unicode, tuple) -> nodes.field # noqa: F821
def handle_item(fieldarg, content):
par = nodes.paragraph()
par += addnodes.literal_strong('', fieldarg) # Patch: this line added
# par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
# addnodes.literal_strong))
if fieldarg in types:
par += nodes.Text(' (')
# NOTE: using .pop() here to prevent a single type node to be
# inserted twice into the doctree, which leads to
# inconsistencies later when references are resolved
fieldtype = types.pop(fieldarg)
if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text):
typename = u''.join(n.astext() for n in fieldtype)
typename = typename.replace('int', 'python:int')
typename = typename.replace('long', 'python:long')
typename = typename.replace('float', 'python:float')
typename = typename.replace('type', 'python:type')
par.extend(self.make_xrefs(self.typerolename, domain, typename,
addnodes.literal_emphasis, **kw))
else:
par += fieldtype
par += nodes.Text(')')
par += nodes.Text(' -- ')
par += content
return par
fieldname = nodes.field_name('', self.label)
if len(items) == 1 and self.can_collapse:
fieldarg, content = items[0]
bodynode = handle_item(fieldarg, content)
else:
bodynode = self.list_type()
for fieldarg, content in items:
bodynode += nodes.list_item('', handle_item(fieldarg, content))
fieldbody = nodes.field_body('', bodynode)
return nodes.field('', fieldname, fieldbody)
TypedField.make_field = patched_make_field
torchvision.datasets
====================
All datasets are subclasses of :class:`torch.utils.data.Dataset`
i.e., they have ``__getitem__`` and ``__len__`` methods implemented.
Hence, they can all be passed to a :class:`torch.utils.data.DataLoader`
which can load multiple samples in parallel using ``torch.multiprocessing`` workers.
For example: ::
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=4,
shuffle=True,
num_workers=args.nThreads)
The following datasets are available:
.. contents:: Datasets
:local:
All the datasets have almost the same API. They all have two common arguments:
``transform`` and ``target_transform`` to transform the input and target respectively.
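For illustration, a minimal sketch (the dataset root path and the ``download=True`` flag are placeholders) of passing both arguments:
.. code:: python
import torchvision.transforms as T
from torchvision import datasets
# convert the PIL image to a tensor and cast the integer label to float (purely illustrative)
dataset = datasets.CIFAR10('path/to/cifar_root/', download=True,
transform=T.ToTensor(),
target_transform=lambda y: float(y))
img, target = dataset[0]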
.. currentmodule:: torchvision.datasets
CelebA
~~~~~~
.. autoclass:: CelebA
:members: __getitem__
:special-members:
CIFAR
~~~~~
.. autoclass:: CIFAR10
:members: __getitem__
:special-members:
.. autoclass:: CIFAR100
Cityscapes
~~~~~~~~~~
.. note ::
Requires the Cityscapes dataset to be downloaded.
.. autoclass:: Cityscapes
:members: __getitem__
:special-members:
COCO
~~~~
.. note ::
These require the `COCO API to be installed`_
.. _COCO API to be installed: https://github.com/pdollar/coco/tree/master/PythonAPI
Captions
^^^^^^^^
.. autoclass:: CocoCaptions
:members: __getitem__
:special-members:
Detection
^^^^^^^^^
.. autoclass:: CocoDetection
:members: __getitem__
:special-members:
DatasetFolder
~~~~~~~~~~~~~
.. autoclass:: DatasetFolder
:members: __getitem__
:special-members:
EMNIST
~~~~~~
.. autoclass:: EMNIST
FakeData
~~~~~~~~
.. autoclass:: FakeData
Fashion-MNIST
~~~~~~~~~~~~~
.. autoclass:: FashionMNIST
Flickr
~~~~~~
.. autoclass:: Flickr8k
:members: __getitem__
:special-members:
.. autoclass:: Flickr30k
:members: __getitem__
:special-members:
HMDB51
~~~~~~~
.. autoclass:: HMDB51
:members: __getitem__
:special-members:
ImageFolder
~~~~~~~~~~~
.. autoclass:: ImageFolder
:members: __getitem__
:special-members:
ImageNet
~~~~~~~~~~~
.. autoclass:: ImageNet
.. note ::
This requires `scipy` to be installed
Kinetics-400
~~~~~~~~~~~~
.. autoclass:: Kinetics400
:members: __getitem__
:special-members:
KMNIST
~~~~~~~~~~~~~
.. autoclass:: KMNIST
LSUN
~~~~
.. autoclass:: LSUN
:members: __getitem__
:special-members:
MNIST
~~~~~
.. autoclass:: MNIST
Omniglot
~~~~~~~~
.. autoclass:: Omniglot
PhotoTour
~~~~~~~~~
.. autoclass:: PhotoTour
:members: __getitem__
:special-members:
Places365
~~~~~~~~~
.. autoclass:: Places365
:members: __getitem__
:special-members:
QMNIST
~~~~~~
.. autoclass:: QMNIST
SBD
~~~~~~
.. autoclass:: SBDataset
:members: __getitem__
:special-members:
SBU
~~~
.. autoclass:: SBU
:members: __getitem__
:special-members:
STL10
~~~~~
.. autoclass:: STL10
:members: __getitem__
:special-members:
SVHN
~~~~~
.. autoclass:: SVHN
:members: __getitem__
:special-members:
UCF101
~~~~~~~
.. autoclass:: UCF101
:members: __getitem__
:special-members:
USPS
~~~~~
.. autoclass:: USPS
:members: __getitem__
:special-members:
VOC
~~~~~~
.. autoclass:: VOCSegmentation
:members: __getitem__
:special-members:
.. autoclass:: VOCDetection
:members: __getitem__
:special-members:
torchvision
===========
This library is part of the `PyTorch
<http://pytorch.org/>`_ project. PyTorch is an open source
machine learning framework.
Features described in this documentation are classified by release status:
*Stable:* These features will be maintained long-term and there should generally
be no major performance limitations or gaps in documentation.
We also expect to maintain backwards compatibility (although
breaking changes can happen and notice will be given one release ahead
of time).
*Beta:* Features are tagged as Beta because the API may change based on
user feedback, because the performance needs to improve, or because
coverage across operators is not yet complete. For Beta features, we are
committing to seeing the feature through to the Stable classification.
We are not, however, committing to backwards compatibility.
*Prototype:* These features are typically not available as part of
binary distributions like PyPI or Conda, except sometimes behind run-time
flags, and are at an early stage for feedback and testing.
The :mod:`torchvision` package consists of popular datasets, model
architectures, and common image transformations for computer vision.
.. toctree::
:maxdepth: 2
:caption: Package Reference
datasets
io
models
ops
transforms
utils
.. automodule:: torchvision
:members:
.. toctree::
:maxdepth: 1
:caption: PyTorch Libraries
PyTorch <https://pytorch.org/docs>
torchaudio <https://pytorch.org/audio>
torchtext <https://pytorch.org/text>
torchvision <https://pytorch.org/vision>
TorchElastic <https://pytorch.org/elastic/>
TorchServe <https://pytorch.org/serve>
PyTorch on XLA Devices <http://pytorch.org/xla/>
torchvision.io
==============
.. currentmodule:: torchvision.io
The :mod:`torchvision.io` package provides functions for performing IO
operations. They are currently specific to reading and writing video and
images.
Video
-----
.. autofunction:: read_video
.. autofunction:: read_video_timestamps
.. autofunction:: write_video
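As a rough illustration (the video path below is a placeholder), ``read_video`` returns the video frames, the audio frames and a metadata dictionary:
.. code:: python
from torchvision.io import read_video
# vframes: uint8 Tensor[T, H, W, C], aframes: Tensor[K, L], info: dict (e.g. "video_fps")
vframes, aframes, info = read_video("path/to/video.mp4", pts_unit="sec")
print(vframes.shape, info)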
Fine-grained video API
----------------------
In addition to the :func:`read_video` function, we provide a high-performance,
lower-level API that offers more fine-grained control over decoding,
while fully supporting torchscript.
.. autoclass:: VideoReader
:members: __next__, get_metadata, set_current_stream, seek
Example of inspecting a video:
.. code:: python
import torchvision
video_path = "path to a test video"
# Constructor allocates memory and a threaded decoder
# instance per video. At the moment it takes two arguments:
# the path to the video file and the desired stream.
reader = torchvision.io.VideoReader(video_path, "video")
# The information about the video can be retrieved using the
# `get_metadata()` method. It returns a dictionary for every stream, with
# duration and other relevant metadata (often frame rate)
reader_md = reader.get_metadata()
# metadata is structured as a dict of dicts with following structure
# {"stream_type": {"attribute": [attribute per stream]}}
#
# the following prints the list of frame rates for every video stream present
print(reader_md["video"]["fps"])
# we explicitly select the stream we would like to operate on. In
# the constructor we select a default video stream, but
# in practice, we can set whichever stream we would like
reader.set_current_stream("video:0")
Image
-----
.. autofunction:: read_image
.. autofunction:: decode_image
.. autofunction:: encode_jpeg
.. autofunction:: write_jpeg
.. autofunction:: encode_png
.. autofunction:: write_png
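A minimal sketch of reading and re-encoding an image (the file names are placeholders):
.. code:: python
from torchvision.io import read_image, write_jpeg
img = read_image("path/to/image.jpg")  # uint8 Tensor[C, H, W]
write_jpeg(img, "path/to/copy.jpg", quality=90)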
torchvision.models
##################
The models subpackage contains definitions of models for addressing
different tasks, including: image classification, pixelwise semantic
segmentation, object detection, instance segmentation, person
keypoint detection and video classification.
Classification
==============
The models subpackage contains definitions for the following model
architectures for image classification:
- `AlexNet`_
- `VGG`_
- `ResNet`_
- `SqueezeNet`_
- `DenseNet`_
- `Inception`_ v3
- `GoogLeNet`_
- `ShuffleNet`_ v2
- `MobileNet`_ v2
- `ResNeXt`_
- `Wide ResNet`_
- `MNASNet`_
You can construct a model with random weights by calling its constructor:
.. code:: python
import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.shufflenet_v2_x1_0()
mobilenet = models.mobilenet_v2()
resnext50_32x4d = models.resnext50_32x4d()
wide_resnet50_2 = models.wide_resnet50_2()
mnasnet = models.mnasnet1_0()
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
These can be constructed by passing ``pretrained=True``:
.. code:: python
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet = models.mobilenet_v2(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)
Instantiating a pre-trained model will download its weights to a cache directory.
This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
:func:`torch.utils.model_zoo.load_url` for details.
Some models use modules which have different training and evaluation
behavior, such as batch normalization. To switch between these modes, use
``model.train()`` or ``model.eval()`` as appropriate. See
:meth:`~torch.nn.Module.train` or :meth:`~torch.nn.Module.eval` for details.
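For example, a common inference pattern (a sketch that uses a random tensor in place of a real, normalized image batch):
.. code:: python
import torch
import torchvision.models as models
model = models.resnet18(pretrained=True)
model.eval()  # switch layers such as batch norm and dropout to evaluation behavior
with torch.no_grad():
    logits = model(torch.rand(1, 3, 224, 224))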
All pre-trained models expect input images normalized in the same way,
i.e. mini-batches of 3-channel RGB images of shape (3 x H x W),
where H and W are expected to be at least 224.
The images have to be loaded into a range of [0, 1] and then normalized
using ``mean = [0.485, 0.456, 0.406]`` and ``std = [0.229, 0.224, 0.225]``.
You can use the following transform to normalize::
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
An example of such normalization can be found in the ImageNet example
`here <https://github.com/pytorch/examples/blob/42e5b996718797e45c46a25c55b031e6768f8440/imagenet/main.py#L89-L101>`_.
The process for obtaining the values of `mean` and `std` is roughly equivalent
to::
import torch
from torchvision import datasets, transforms as T
transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
dataset = datasets.ImageNet(".", split="train", transform=transform)
means = []
stds = []
for img in subset(dataset):
means.append(torch.mean(img))
stds.append(torch.std(img))
mean = torch.mean(torch.tensor(means))
std = torch.mean(torch.tensor(stds))
Unfortunately, the concrete `subset` that was used is lost. For more
information see `this discussion <https://github.com/pytorch/vision/issues/1439>`_
or `these experiments <https://github.com/pytorch/vision/pull/1965>`_.
ImageNet 1-crop error rates (224x224)
================================ ============= =============
Network                          Top-1 error   Top-5 error
================================ ============= =============
AlexNet                          43.45         20.91
VGG-11                           30.98         11.37
VGG-13                           30.07         10.75
VGG-16                           28.41         9.62
VGG-19                           27.62         9.12
VGG-11 with batch normalization  29.62         10.19
VGG-13 with batch normalization  28.45         9.63
VGG-16 with batch normalization  26.63         8.50
VGG-19 with batch normalization  25.76         8.15
ResNet-18                        30.24         10.92
ResNet-34                        26.70         8.58
ResNet-50                        23.85         7.13
ResNet-101                       22.63         6.44
ResNet-152                       21.69         5.94
SqueezeNet 1.0                   41.90         19.58
SqueezeNet 1.1                   41.81         19.38
Densenet-121                     25.35         7.83
Densenet-169                     24.00         7.00
Densenet-201                     22.80         6.43
Densenet-161                     22.35         6.20
Inception v3                     22.55         6.44
GoogleNet                        30.22         10.47
ShuffleNet V2                    30.64         11.68
MobileNet V2                     28.12         9.71
ResNeXt-50-32x4d                 22.38         6.30
ResNeXt-101-32x8d                20.69         5.47
Wide ResNet-50-2                 21.49         5.91
Wide ResNet-101-2                21.16         5.72
MNASNet 1.0                      26.49         8.456
================================ ============= =============
.. _AlexNet: https://arxiv.org/abs/1404.5997
.. _VGG: https://arxiv.org/abs/1409.1556
.. _ResNet: https://arxiv.org/abs/1512.03385
.. _SqueezeNet: https://arxiv.org/abs/1602.07360
.. _DenseNet: https://arxiv.org/abs/1608.06993
.. _Inception: https://arxiv.org/abs/1512.00567
.. _GoogLeNet: https://arxiv.org/abs/1409.4842
.. _ShuffleNet: https://arxiv.org/abs/1807.11164
.. _MobileNet: https://arxiv.org/abs/1801.04381
.. _ResNeXt: https://arxiv.org/abs/1611.05431
.. _MNASNet: https://arxiv.org/abs/1807.11626
.. currentmodule:: torchvision.models
Alexnet
-------
.. autofunction:: alexnet
VGG
---
.. autofunction:: vgg11
.. autofunction:: vgg11_bn
.. autofunction:: vgg13
.. autofunction:: vgg13_bn
.. autofunction:: vgg16
.. autofunction:: vgg16_bn
.. autofunction:: vgg19
.. autofunction:: vgg19_bn
ResNet
------
.. autofunction:: resnet18
.. autofunction:: resnet34
.. autofunction:: resnet50
.. autofunction:: resnet101
.. autofunction:: resnet152
SqueezeNet
----------
.. autofunction:: squeezenet1_0
.. autofunction:: squeezenet1_1
DenseNet
---------
.. autofunction:: densenet121
.. autofunction:: densenet169
.. autofunction:: densenet161
.. autofunction:: densenet201
Inception v3
------------
.. autofunction:: inception_v3
.. note ::
This requires `scipy` to be installed
GoogLeNet
------------
.. autofunction:: googlenet
.. note ::
This requires `scipy` to be installed
ShuffleNet v2
-------------
.. autofunction:: shufflenet_v2_x0_5
.. autofunction:: shufflenet_v2_x1_0
.. autofunction:: shufflenet_v2_x1_5
.. autofunction:: shufflenet_v2_x2_0
MobileNet v2
-------------
.. autofunction:: mobilenet_v2
ResNext
-------
.. autofunction:: resnext50_32x4d
.. autofunction:: resnext101_32x8d
Wide ResNet
-----------
.. autofunction:: wide_resnet50_2
.. autofunction:: wide_resnet101_2
MNASNet
--------
.. autofunction:: mnasnet0_5
.. autofunction:: mnasnet0_75
.. autofunction:: mnasnet1_0
.. autofunction:: mnasnet1_3
Semantic Segmentation
=====================
The models subpackage contains definitions for the following model
architectures for semantic segmentation:
- `FCN ResNet50, ResNet101 <https://arxiv.org/abs/1411.4038>`_
- `DeepLabV3 ResNet50, ResNet101 <https://arxiv.org/abs/1706.05587>`_
As with image classification models, all pre-trained models expect input images normalized in the same way.
The images have to be loaded into a range of ``[0, 1]`` and then normalized using
``mean = [0.485, 0.456, 0.406]`` and ``std = [0.229, 0.224, 0.225]``.
They have been trained on images resized such that their minimum size is 520.
The pre-trained models have been trained on a subset of COCO train2017, on the 20 categories that are
present in the Pascal VOC dataset. You can see more information on how the subset has been selected in
``references/segmentation/coco_utils.py``. The classes that the pre-trained model outputs are the following,
in order:
.. code-block:: python
['__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
The accuracies of the pre-trained models evaluated on COCO val2017 are as follows:
================================ ============= ====================
Network                          mean IoU      global pixelwise acc
================================ ============= ====================
FCN ResNet50                     60.5          91.4
FCN ResNet101                    63.7          91.9
DeepLabV3 ResNet50               66.4          92.4
DeepLabV3 ResNet101              67.4          92.4
================================ ============= ====================
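As an illustration only (a random tensor stands in for a real, normalized image batch), the segmentation models return a dictionary whose ``'out'`` entry holds one score map per class:
.. code:: python
import torch
from torchvision import models
model = models.segmentation.fcn_resnet50(pretrained=True).eval()
batch = torch.rand(1, 3, 520, 520)  # in practice: a normalized image batch
with torch.no_grad():
    out = model(batch)['out']  # shape [1, 21, 520, 520], one channel per class
pred = out.argmax(dim=1)  # [1, 520, 520] per-pixel class indices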
Fully Convolutional Networks
----------------------------
.. autofunction:: torchvision.models.segmentation.fcn_resnet50
.. autofunction:: torchvision.models.segmentation.fcn_resnet101
DeepLabV3
---------
.. autofunction:: torchvision.models.segmentation.deeplabv3_resnet50
.. autofunction:: torchvision.models.segmentation.deeplabv3_resnet101
Object Detection, Instance Segmentation and Person Keypoint Detection
=====================================================================
The models subpackage contains definitions for the following model
architectures for detection:
- `Faster R-CNN ResNet-50 FPN <https://arxiv.org/abs/1506.01497>`_
- `Mask R-CNN ResNet-50 FPN <https://arxiv.org/abs/1703.06870>`_
The pre-trained models for detection, instance segmentation and
keypoint detection are initialized with the classification models
in torchvision.
The models expect a list of ``Tensor[C, H, W]``, in the range ``0-1``.
The models internally resize the images so that they have a minimum size
of ``800``. This option can be changed by passing the option ``min_size``
to the constructor of the models.
For object detection and instance segmentation, the pre-trained
models return the predictions of the following classes:
.. code-block:: python
COCO_INSTANCE_CATEGORY_NAMES = [
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
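A minimal inference sketch (random tensors stand in for real images, and the 0.5 confidence threshold is an arbitrary choice):
.. code:: python
import torch
from torchvision import models
model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True).eval()
images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]  # list of C x H x W tensors in [0, 1]
with torch.no_grad():
    outputs = model(images)  # one dict per image with 'boxes', 'labels' and 'scores'
keep = outputs[0]['scores'] > 0.5
boxes = outputs[0]['boxes'][keep]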
Here is a summary of the accuracies for the models trained on
the instances set of COCO train2017 and evaluated on COCO val2017.
================================ ======= ======== ===========
Network                          box AP  mask AP  keypoint AP
================================ ======= ======== ===========
Faster R-CNN ResNet-50 FPN       37.0    -        -
RetinaNet ResNet-50 FPN          36.4    -        -
Mask R-CNN ResNet-50 FPN         37.9    34.6     -
================================ ======= ======== ===========
For person keypoint detection, the accuracies for the pre-trained
models are as follows:
================================ ======= ======== ===========
Network                          box AP  mask AP  keypoint AP
================================ ======= ======== ===========
Keypoint R-CNN ResNet-50 FPN     54.6    -        65.0
================================ ======= ======== ===========
For person keypoint detection, the pre-trained model returns the
keypoints in the following order:
.. code-block:: python
COCO_PERSON_KEYPOINT_NAMES = [
'nose',
'left_eye',
'right_eye',
'left_ear',
'right_ear',
'left_shoulder',
'right_shoulder',
'left_elbow',
'right_elbow',
'left_wrist',
'right_wrist',
'left_hip',
'right_hip',
'left_knee',
'right_knee',
'left_ankle',
'right_ankle'
]
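For illustration (again with a random tensor in place of a real image), the keypoint model returns these keypoints as an ``[N, 17, 3]`` tensor of ``(x, y, visibility)`` triplets:
.. code:: python
import torch
from torchvision import models
model = models.detection.keypointrcnn_resnet50_fpn(pretrained=True).eval()
with torch.no_grad():
    out = model([torch.rand(3, 300, 400)])[0]
# out['keypoints'] has shape [num_detections, 17, 3], alongside 'boxes', 'labels' and 'scores'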
Runtime characteristics
-----------------------
The implementations of the models for object detection, instance segmentation
and keypoint detection are efficient.
The results in the following table were obtained with 8 V100 GPUs, CUDA 10.0
and cuDNN 7.4. During training, we use a batch size of 2 per GPU, and
during testing a batch size of 1 is used.
For test time, we report the time for the model evaluation and postprocessing
(including mask pasting in image), but not the time for computing the
precision-recall.
============================== =================== ================== ===========
Network                        train time (s / it) test time (s / it) memory (GB)
============================== =================== ================== ===========
Faster R-CNN ResNet-50 FPN     0.2288              0.0590             5.2
RetinaNet ResNet-50 FPN        0.2514              0.0939             4.1
Mask R-CNN ResNet-50 FPN       0.2728              0.0903             5.4
Keypoint R-CNN ResNet-50 FPN   0.3789              0.1242             6.8
============================== =================== ================== ===========
Faster R-CNN
------------
.. autofunction:: torchvision.models.detection.fasterrcnn_resnet50_fpn
RetinaNet
------------
.. autofunction:: torchvision.models.detection.retinanet_resnet50_fpn
Mask R-CNN
----------
.. autofunction:: torchvision.models.detection.maskrcnn_resnet50_fpn
Keypoint R-CNN
--------------
.. autofunction:: torchvision.models.detection.keypointrcnn_resnet50_fpn
Video classification
====================
We provide models for action recognition pre-trained on Kinetics-400.
They have all been trained with the scripts provided in ``references/video_classification``.
All pre-trained models expect input images normalized in the same way,
i.e. mini-batches of 3-channel RGB videos of shape (3 x T x H x W),
where H and W are expected to be 112, and T is the number of video frames in a clip.
The images have to be loaded into a range of [0, 1] and then normalized
using ``mean = [0.43216, 0.394666, 0.37645]`` and ``std = [0.22803, 0.22145, 0.216989]``.
.. note::
The normalization parameters are different from the image classification ones, and correspond
to the mean and std from Kinetics-400.
.. note::
For now, normalization code can be found in ``references/video_classification/transforms.py``,
see the ``Normalize`` function there. Note that it differs from standard normalization for
images because it assumes the video is 4d.
Kinetics 1-crop accuracies for clip length 16 (16x112x112)
================================ ============= =============
Network                          Clip acc@1    Clip acc@5
================================ ============= =============
ResNet 3D 18                     52.75         75.45
ResNet MC 18                     53.90         76.29
ResNet (2+1)D                    57.50         78.81
================================ ============= =============
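A rough sketch of running one of these models on a random clip shaped as described above:
.. code:: python
import torch
from torchvision import models
model = models.video.r3d_18(pretrained=True).eval()
clip = torch.rand(1, 3, 16, 112, 112)  # (B, C, T, H, W); normalize as described above in practice
with torch.no_grad():
    scores = model(clip)  # [1, 400] Kinetics-400 class scores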
ResNet 3D
----------
.. autofunction:: torchvision.models.video.r3d_18
ResNet Mixed Convolution
------------------------
.. autofunction:: torchvision.models.video.mc3_18
ResNet (2+1)D
-------------
.. autofunction:: torchvision.models.video.r2plus1d_18
torchvision.ops
===============
.. currentmodule:: torchvision.ops
:mod:`torchvision.ops` implements operators that are specific to computer vision.
.. note::
All operators have native support for TorchScript.
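For instance, a small ``nms`` sketch with made-up boxes and scores:
.. code:: python
import torch
from torchvision.ops import nms
boxes = torch.tensor([[0., 0., 100., 100.],
[10., 10., 110., 110.],
[200., 200., 300., 300.]])  # (x1, y1, x2, y2)
scores = torch.tensor([0.9, 0.8, 0.7])
keep = nms(boxes, scores, iou_threshold=0.5)  # indices of the kept boxes: tensor([0, 2])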
.. autofunction:: nms
.. autofunction:: batched_nms
.. autofunction:: remove_small_boxes
.. autofunction:: clip_boxes_to_image
.. autofunction:: box_convert
.. autofunction:: box_area
.. autofunction:: box_iou
.. autofunction:: generalized_box_iou
.. autofunction:: roi_align
.. autofunction:: ps_roi_align
.. autofunction:: roi_pool
.. autofunction:: ps_roi_pool
.. autofunction:: deform_conv2d
.. autoclass:: RoIAlign
.. autoclass:: PSRoIAlign
.. autoclass:: RoIPool
.. autoclass:: PSRoIPool
.. autoclass:: DeformConv2d
.. autoclass:: MultiScaleRoIAlign
.. autoclass:: FeaturePyramidNetwork
torchvision.transforms
======================
.. currentmodule:: torchvision.transforms
Transforms are common image transformations. They can be chained together using :class:`Compose`.
Additionally, there is the :mod:`torchvision.transforms.functional` module.
Functional transforms give fine-grained control over the transformations.
This is useful if you have to build a more complex transformation pipeline
(e.g. in the case of segmentation tasks).
All transformations accept PIL Image, Tensor Image or batch of Tensor Images as input. A Tensor Image is a tensor with
``(C, H, W)`` shape, where ``C`` is the number of channels and ``H`` and ``W`` are the image height and width. A batch of
Tensor Images is a tensor of ``(B, C, H, W)`` shape, where ``B`` is the number of images in the batch. Deterministic or
random transformations applied to a batch of Tensor Images transform all the images of the batch identically.
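For illustration, a sketch of a composed transform applied to a random batch of Tensor Images:
.. code:: python
import torch
import torchvision.transforms as T
transform = T.Compose([
T.RandomCrop(224),
T.RandomHorizontalFlip(),
T.ConvertImageDtype(torch.float),
])
batch = torch.randint(0, 256, (4, 3, 256, 256), dtype=torch.uint8)  # B x C x H x W
out = transform(batch)  # the same random parameters are applied to every image in the batch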
.. warning::
Since v0.8.0, all random transformations use the torch default random generator to sample random parameters.
This is a backwards-incompatible change, and the user should set the random state as follows:
.. code:: python
# Previous versions
# import random
# random.seed(12)
# Now
import torch
torch.manual_seed(17)
Please keep in mind that the same seed for the torch random generator and the Python random generator will not
produce the same results.
Scriptable transforms
---------------------
In order to script the transformations, please use ``torch.nn.Sequential`` instead of :class:`Compose`.
.. code:: python
transforms = torch.nn.Sequential(
transforms.CenterCrop(10),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)
scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not require
``lambda`` functions or ``PIL.Image``.
For any custom transformations to be used with ``torch.jit.script``, they should be derived from ``torch.nn.Module``.
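For instance, a minimal sketch of a custom transform written as a ``torch.nn.Module`` so that it can be scripted (the ``ClampTransform`` name and behavior are purely illustrative):
.. code:: python
import torch
import torch.nn as nn
class ClampTransform(nn.Module):
    """Illustrative custom transform: clamps pixel values to a range."""
    def __init__(self, lo: float, hi: float):
        super().__init__()
        self.lo = lo
        self.hi = hi
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.clamp(self.lo, self.hi)
scripted = torch.jit.script(ClampTransform(0.0, 1.0))
out = scripted(torch.rand(3, 224, 224))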
Compositions of transforms
--------------------------
.. autoclass:: Compose
Transforms on PIL Image and torch.\*Tensor
------------------------------------------
.. autoclass:: CenterCrop
:members:
.. autoclass:: ColorJitter
:members:
.. autoclass:: FiveCrop
:members:
.. autoclass:: Grayscale
:members:
.. autoclass:: Pad
:members:
.. autoclass:: RandomAffine
:members:
.. autoclass:: RandomApply
.. autoclass:: RandomCrop
:members:
.. autoclass:: RandomGrayscale
:members:
.. autoclass:: RandomHorizontalFlip
:members:
.. autoclass:: RandomPerspective
:members:
.. autoclass:: RandomResizedCrop
:members:
.. autoclass:: RandomRotation
:members:
.. autoclass:: RandomSizedCrop
:members:
.. autoclass:: RandomVerticalFlip
:members:
.. autoclass:: Resize
:members:
.. autoclass:: Scale
:members:
.. autoclass:: TenCrop
:members:
.. autoclass:: GaussianBlur
:members:
Transforms on PIL Image only
----------------------------
.. autoclass:: RandomChoice
.. autoclass:: RandomOrder
Transforms on torch.\*Tensor only
---------------------------------
.. autoclass:: LinearTransformation
:members:
.. autoclass:: Normalize
:members:
.. autoclass:: RandomErasing
:members:
.. autoclass:: ConvertImageDtype
Conversion Transforms
---------------------
.. autoclass:: ToPILImage
:members:
.. autoclass:: ToTensor
:members:
Generic Transforms
------------------
.. autoclass:: Lambda
:members:
Functional Transforms
---------------------
Functional transforms give you fine-grained control of the transformation pipeline.
As opposed to the transformations above, functional transforms don't contain a random number
generator for their parameters.
That means you have to specify/generate all parameters, but you can reuse the functional transform.
Example:
you can apply a functional transform with the same parameters to multiple images like this:
.. code:: python
import torchvision.transforms.functional as TF
import random
def my_segmentation_transforms(image, segmentation):
if random.random() > 0.5:
angle = random.randint(-30, 30)
image = TF.rotate(image, angle)
segmentation = TF.rotate(segmentation, angle)
# more transforms ...
return image, segmentation
Example:
you can use a functional transform to build transform classes with custom behavior:
.. code:: python
import torchvision.transforms.functional as TF
import random
class MyRotationTransform:
"""Rotate by one of the given angles."""
def __init__(self, angles):
self.angles = angles
def __call__(self, x):
angle = random.choice(self.angles)
return TF.rotate(x, angle)
rotation_transform = MyRotationTransform(angles=[-30, -15, 0, 15, 30])
.. automodule:: torchvision.transforms.functional
:members:
torchvision.utils
=================
.. currentmodule:: torchvision.utils
.. autofunction:: make_grid
.. autofunction:: save_image
cmake_minimum_required(VERSION 3.10)
project(hello-world)
# The first thing to do is to tell CMake to find the TorchVision library.
# The package pulls in all the necessary torch libraries,
# so there is no need to also add `find_package(Torch)` here.
find_package(TorchVision REQUIRED)
add_executable(hello-world main.cpp)
# We now need to link the TorchVision library to our executable.
# We can do that by using the TorchVision::TorchVision target,
# which also adds all the necessary torch dependencies.
target_compile_features(hello-world PUBLIC cxx_range_for)
target_link_libraries(hello-world TorchVision::TorchVision)
set_property(TARGET hello-world PROPERTY CXX_STANDARD 14)
Hello World!
============
This is a minimal example of getting TorchVision to work in C++ with CMake.
In order to successfully compile this example, make sure you have both ``LibTorch`` and
``TorchVision`` installed.
Once both dependencies are sorted, we can start the CMake fun:
1) Create a ``build`` directory inside the current one.
2) From within the ``build`` directory, run the following commands:
- | ``cmake -DCMAKE_PREFIX_PATH="<PATH_TO_LIBTORCH>;<PATH_TO_TORCHVISION>" ..``
| where ``<PATH_TO_LIBTORCH>`` and ``<PATH_TO_TORCHVISION>`` are the paths to the libtorch and torchvision installations.
- ``cmake --build .``
| That's it!
| You should now have a ``hello-world`` executable in your ``build`` folder.
Running it will output a (fairly long) tensor of random values to your terminal.
#include <iostream>
#include <torch/torch.h>
#include <torchvision/models/resnet.h>
int main()
{
auto model = vision::models::ResNet18();
model->eval();
// Create a random input tensor and run it through the model.
auto in = torch::rand({1, 3, 10, 10});
auto out = model->forward(in);
std::cout << out.sizes();
if (torch::cuda::is_available()) {
// Move model and inputs to GPU
model->to(torch::kCUDA);
auto gpu_in = in.to(torch::kCUDA);
auto gpu_out = model->forward(gpu_in);
std::cout << gpu_out.sizes();
}
}
# Python examples
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb)
[Examples of Tensor Images transformations](https://github.com/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/video_api.ipynb)
[Example of VideoAPI](https://github.com/pytorch/vision/blob/master/examples/python/video_api.ipynb)
Prior to v0.8.0, transforms in torchvision were PIL-centric and presented multiple limitations due to
that. Since v0.8.0, the transforms implementations are Tensor and PIL compatible, and we can achieve the following new
features:
- transform multi-band torch tensor images (with more than 3-4 channels)
- torchscript transforms together with your model for deployment
- support for GPU acceleration
- batched transformation such as for videos
- read and decode data directly as torch tensor with torchscript support (for PNG and JPEG image formats)
Furthermore, we previously provided a very high-level API for video decoding which left little control to the user. We're now expanding that API (and eventually replacing it) with a lower-level API that gives the user frame-based access to a video.
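A minimal sketch of the tensor-based workflow these notebooks walk through (the file name is illustrative):

```python
import torch
import torchvision.transforms as T
from torchvision.io import read_image

img = read_image("test-image.jpg")  # uint8 tensor, C x H x W
transforms = torch.nn.Sequential(
    T.CenterCrop(224),
    T.ConvertImageDtype(torch.float),
)
scripted = torch.jit.script(transforms)  # tensor transforms can be torchscripted
out = scripted(img)
```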
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "vjAC2mZnb4nz"
},
"source": [
"# Image transformations\n",
"\n",
"This notebook shows new features of torchvision image transformations. \n",
"\n",
"Prior to v0.8.0, transforms in torchvision have traditionally been PIL-centric and presented multiple limitations due to that. Now, since v0.8.0, transforms implementations are Tensor and PIL compatible and we can achieve the following new \n",
"features:\n",
"- transform multi-band torch tensor images (with more than 3-4 channels) \n",
"- torchscript transforms together with your model for deployment\n",
"- support for GPU acceleration\n",
"- batched transformation such as for videos\n",
"- read and decode data directly as torch tensor with torchscript support (for PNG and JPEG image formats)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "btaDWPDbgIyW",
"outputId": "8a83d408-f643-42da-d247-faf3a1bd3ae0"
},
"outputs": [],
"source": [
"import torch, torchvision\n",
"torch.__version__, torchvision.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9Vj9draNb4oA"
},
"source": [
"## Transforms on CPU/CUDA tensor images\n",
"\n",
"Let's show how to apply transformations on images opened directly as a torch tensors.\n",
"Now, torchvision provides image reading functions for PNG and JPG images with torchscript support. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Epp3hCy0b4oD"
},
"outputs": [],
"source": [
"from torchvision.datasets.utils import download_url\n",
"\n",
"download_url(\"https://farm1.static.flickr.com/152/434505223_8d1890e1e2.jpg\", \".\", \"test-image.jpg\")\n",
"download_url(\"https://farm3.static.flickr.com/2142/1896267403_24939864ba.jpg\", \".\", \"test-image2.jpg\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y-m7lYDPb4oK"
},
"outputs": [],
"source": [
"import matplotlib.pylab as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 303
},
"id": "5bi8Q7L3b4oc",
"outputId": "e5de5c73-e16d-4992-ebee-94c7ddf0bf54"
},
"outputs": [],
"source": [
"from torchvision.io.image import read_image\n",
"\n",
"tensor_image = read_image(\"test-image.jpg\")\n",
"\n",
"print(\"tensor image info: \", tensor_image.shape, tensor_image.dtype)\n",
"\n",
"plt.imshow(tensor_image.numpy().transpose((1, 2, 0)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def to_rgb_image(tensor):\n",
" \"\"\"Helper method to get RGB numpy array for plotting\"\"\"\n",
" np_img = tensor.cpu().numpy().transpose((1, 2, 0))\n",
" m1, m2 = np_img.min(axis=(0, 1)), np_img.max(axis=(0, 1))\n",
" return (255.0 * (np_img - m1) / (m2 - m1)).astype(\"uint8\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 322
},
"id": "PgWpjxQ3b4pF",
"outputId": "e9a138e8-b45c-4f75-d849-3b41de0e5472"
},
"outputs": [],
"source": [
"import torchvision.transforms as T\n",
"\n",
"# to fix random seed is now:\n",
"torch.manual_seed(12)\n",
"\n",
"transforms = T.Compose([\n",
" T.RandomCrop(224),\n",
" T.RandomHorizontalFlip(p=0.3),\n",
" T.ConvertImageDtype(torch.float),\n",
" T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
"])\n",
"\n",
"out_image = transforms(tensor_image)\n",
"print(\"output tensor image info: \", out_image.shape, out_image.dtype)\n",
"\n",
"plt.imshow(to_rgb_image(out_image))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LmYQB4cxb4pI"
},
"source": [
"Tensor images can be on GPU"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 322
},
"id": "S6syYJGEb4pN",
"outputId": "86bddb64-e648-45f2-c216-790d43cfc26d"
},
"outputs": [],
"source": [
"out_image = transforms(tensor_image.to(\"cuda\"))\n",
"print(\"output tensor image info: \", out_image.shape, out_image.dtype, out_image.device)\n",
"\n",
"plt.imshow(to_rgb_image(out_image))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jg9TQd7ajfyn"
},
"source": [
"## Scriptable transforms for easier deployment via torchscript\n",
"\n",
"Next, we show how to combine input transformations and model's forward pass and use `torch.jit.script` to obtain a single scripted module.\n",
"\n",
"**Note:** we have to use only scriptable transformations that should be derived from `torch.nn.Module`. \n",
"Since v0.8.0, all transformations are scriptable except `Compose`, `RandomChoice`, `RandomOrder`, `Lambda` and those applied on PIL images.\n",
"The transformations like `Compose` are kept for backward compatibility and can be easily replaced by existing torch modules, like `nn.Sequential`.\n",
"\n",
"Let's define a module `Predictor` that transforms input tensor and applies ImageNet pretrained resnet18 model on it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NSDOJ3RajfvO"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torchvision.transforms as T\n",
"from torchvision.io.image import read_image\n",
"from torchvision.models import resnet18\n",
"\n",
"\n",
"class Predictor(nn.Module):\n",
"\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.resnet18 = resnet18(pretrained=True).eval()\n",
" self.transforms = nn.Sequential(\n",
" T.Resize([256, ]), # We use single int value inside a list due to torchscript type restrictions\n",
" T.CenterCrop(224),\n",
" T.ConvertImageDtype(torch.float),\n",
" T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
" )\n",
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" with torch.no_grad():\n",
" x = self.transforms(x)\n",
" y_pred = self.resnet18(x)\n",
" return y_pred.argmax(dim=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZZKDovqej5vA"
},
"source": [
"Now, let's define scripted and non-scripted instances of `Predictor` and apply on multiple tensor images of the same size"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GBBMSo7vjfr0"
},
"outputs": [],
"source": [
"from torchvision.io.image import read_image\n",
"\n",
"predictor = Predictor().to(\"cuda\")\n",
"scripted_predictor = torch.jit.script(predictor).to(\"cuda\")\n",
"\n",
"\n",
"tensor_image1 = read_image(\"test-image.jpg\")\n",
"tensor_image2 = read_image(\"test-image2.jpg\")\n",
"batch = torch.stack([tensor_image1[:, -320:, :], tensor_image2[:, -320:, :]]).to(\"cuda\")\n",
"\n",
"res1 = scripted_predictor(batch)\n",
"res2 = predictor(batch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 501
},
"id": "Dmi9r_p-oKsk",
"outputId": "b9c55e7d-5db1-4975-c485-fecc4075bf47"
},
"outputs": [],
"source": [
"import json\n",
"from torchvision.datasets.utils import download_url\n",
"\n",
"\n",
"download_url(\"https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json\", \".\", \"imagenet_class_index.json\")\n",
"\n",
"\n",
"with open(\"imagenet_class_index.json\", \"r\") as h:\n",
" labels = json.load(h)\n",
"\n",
"\n",
"plt.figure(figsize=(12, 7))\n",
"for i, p in enumerate(res1):\n",
" plt.subplot(1, 2, i + 1)\n",
" plt.title(\"Scripted predictor:\\n{label})\".format(label=labels[str(p.item())]))\n",
" plt.imshow(batch[i, ...].cpu().numpy().transpose((1, 2, 0)))\n",
"\n",
"\n",
"plt.figure(figsize=(12, 7))\n",
"for i, p in enumerate(res2):\n",
" plt.subplot(1, 2, i + 1)\n",
" plt.title(\"Original predictor:\\n{label})\".format(label=labels[str(p.item())]))\n",
" plt.imshow(batch[i, ...].cpu().numpy().transpose((1, 2, 0)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7IYsjzpFqcK8"
},
"source": [
"We save and reload scripted predictor in Python or C++ and use it for inference:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
},
"id": "0kk9LLw5jfol",
"outputId": "05ea6db7-7fcf-4b74-a763-5f117c14cc00"
},
"outputs": [],
"source": [
"scripted_predictor.save(\"scripted_predictor.pt\")\n",
"\n",
"scripted_predictor = torch.jit.load(\"scripted_predictor.pt\")\n",
"res1 = scripted_predictor(batch)\n",
"\n",
"for i, p in enumerate(res1):\n",
" print(\"Scripted predictor: {label})\".format(label=labels[str(p.item())]))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Data reading and decoding functions also support torch script and therefore can be part of the model as well:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class AnotherPredictor(Predictor):\n",
"\n",
" def forward(self, path: str) -> int:\n",
" with torch.no_grad():\n",
" x = read_image(path).unsqueeze(0)\n",
" x = self.transforms(x)\n",
" y_pred = self.resnet18(x)\n",
" return int(y_pred.argmax(dim=1).item())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-cMwTs3Yjffy"
},
"outputs": [],
"source": [
"scripted_predictor2 = torch.jit.script(AnotherPredictor())\n",
"\n",
"res = scripted_predictor2(\"test-image.jpg\")\n",
"\n",
"print(\"Scripted another predictor: {label})\".format(label=labels[str(res)]))"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "torchvision_scriptable_transforms.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Welcome to torchvision's new video API\n",
"\n",
"Here, we're going to examine the capabilities of the new video API, together with the examples on how to build datasets and more. \n",
"\n",
"### Table of contents\n",
"1. Introduction: building a new video object and examining the properties\n",
"2. Building a sample `read_video` function\n",
"3. Building an example dataset (can be applied to e.g. kinetics400)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('1.7.0a0+f5c95d5', '0.8.0a0+a2f405d')"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch, torchvision\n",
"torch.__version__, torchvision.__version__"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading https://github.com/pytorch/vision/blob/master/test/assets/videos/WUzgd7C1pWA.mp4?raw=true to ./WUzgd7C1pWA.mp4\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100.4%"
]
}
],
"source": [
"# download the sample video\n",
"from torchvision.datasets.utils import download_url\n",
"download_url(\"https://github.com/pytorch/vision/blob/master/test/assets/videos/WUzgd7C1pWA.mp4?raw=true\", \".\", \"WUzgd7C1pWA.mp4\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Introduction: building a new video object and examining the properties\n",
"\n",
"First we select a video to test the object out. For the sake of argument we're using one from Kinetics400 dataset. To create it, we need to define the path and the stream we want to use. See inline comments for description. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import torch, torchvision\n",
"\"\"\"\n",
"chosen video statistics:\n",
"WUzgd7C1pWA.mp4\n",
" - source: kinetics-400\n",
" - video: H-264 - MPEG-4 AVC (part 10) (avc1)\n",
" - fps: 29.97\n",
" - audio: MPEG AAC audio (mp4a)\n",
" - sample rate: 48K Hz\n",
"\"\"\"\n",
"video_path = \"./WUzgd7C1pWA.mp4\"\n",
"\n",
"\"\"\"\n",
"streams are defined in a similar fashion as torch devices. We encode them as strings in a form\n",
"of `stream_type:stream_id` where stream_type is a string and stream_id a long int. \n",
"\n",
"The constructor accepts passing a stream_type only, in which case the stream is auto-discovered.\n",
"\"\"\"\n",
"stream = \"video\"\n",
"\n",
"\n",
"\n",
"video = torchvision.io.VideoReader(video_path, stream)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's get the metadata for our particular video:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'video': {'duration': [10.9109], 'fps': [29.97002997002997]},\n",
" 'audio': {'duration': [10.9], 'framerate': [48000.0]}}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"video.get_metadata()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we can see that video has two streams - a video and an audio stream. \n",
"\n",
"Let's read all the frames from the video stream."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of frames: 327\n",
"We can expect approx: 327.0\n",
"Tensor size: torch.Size([3, 256, 340])\n"
]
}
],
"source": [
"# first we select the video stream \n",
"metadata = video.get_metadata()\n",
"video.set_current_stream(\"video:0\")\n",
"\n",
"frames = [] # we are going to save the frames here.\n",
"for frame, pts in video:\n",
" frames.append(frame)\n",
" \n",
"print(\"Total number of frames: \", len(frames))\n",
"approx_nf = metadata['video']['duration'][0] * metadata['video']['fps'][0]\n",
"print(\"We can expect approx: \", approx_nf)\n",
"print(\"Tensor size: \", frames[0].size())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that selecting zero video stream is equivalent to selecting video stream automatically. I.e. `video:0` and `video` will end up with same results in this case. \n",
"\n",
"Let's try this for audio"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of frames: 511\n",
"Approx total number of datapoints we can expect: 523200.0\n",
"Read data size: 523264\n"
]
}
],
"source": [
"metadata = video.get_metadata()\n",
"video.set_current_stream(\"audio\")\n",
"\n",
"frames = [] # we are going to save the frames here.\n",
"for frame, pts in video:\n",
" frames.append(frame)\n",
" \n",
"print(\"Total number of frames: \", len(frames))\n",
"approx_nf = metadata['audio']['duration'][0] * metadata['audio']['framerate'][0]\n",
"print(\"Approx total number of datapoints we can expect: \", approx_nf)\n",
"print(\"Read data size: \", frames[0].size(0) * len(frames))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But what if we only want to read certain time segment of the video?\n",
"\n",
"That can be done easily using the combination of our seek function, and the fact that each call to next returns the presentation timestamp of the returned frame in seconds. Given that our implementation relies on python iterators, we can leverage `itertools` to simplify the process and make it more pythonic. \n",
"\n",
"For example, if we wanted to read ten frames from second second:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of frames: 10\n"
]
}
],
"source": [
"import itertools\n",
"video.set_current_stream(\"video\")\n",
"\n",
"frames = [] # we are going to save the frames here.\n",
"\n",
"# we seek into a second second of the video\n",
"# and use islice to get 10 frames since\n",
"for frame, pts in itertools.islice(video.seek(2), 10):\n",
" frames.append(frame)\n",
" \n",
"print(\"Total number of frames: \", len(frames))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Or if we wanted to read from 2nd to 5th second:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of frames: 90\n",
"We can expect approx: 89.91008991008991\n",
"Tensor size: torch.Size([3, 256, 340])\n"
]
}
],
"source": [
"video.set_current_stream(\"video\")\n",
"\n",
"frames = [] # we are going to save the frames here.\n",
"\n",
"# we seek into a second second of the video\n",
"video = video.seek(2)\n",
"# then we utilize the itertools takewhile to get the \n",
"# correct number of frames\n",
"for frame, pts in itertools.takewhile(lambda x: x[1] <= 5, video):\n",
" frames.append(frame)\n",
"\n",
"print(\"Total number of frames: \", len(frames))\n",
"approx_nf = (5-2) * video.get_metadata()['video']['fps'][0]\n",
"print(\"We can expect approx: \", approx_nf)\n",
"print(\"Tensor size: \", frames[0].size())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Building a sample `read_video` function\n",
"\n",
"We can utilize the methods above to build the read video function that follows the same API to the existing `read_video` function "
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def example_read_video(video_object, start=0, end=None, read_video=True, read_audio=True):\n",
"\n",
" if end is None:\n",
" end = float(\"inf\")\n",
" if end < start:\n",
" raise ValueError(\n",
" \"end time should be larger than start time, got \"\n",
" \"start time={} and end time={}\".format(s, e)\n",
" )\n",
" \n",
" video_frames = torch.empty(0)\n",
" video_pts = []\n",
" if read_video:\n",
" video_object.set_current_stream(\"video\")\n",
" frames = []\n",
" for t, pts in itertools.takewhile(lambda x: x[1] <= end, video_object.seek(start)):\n",
" frames.append(t)\n",
" video_pts.append(pts)\n",
" if len(frames) > 0:\n",
" video_frames = torch.stack(frames, 0)\n",
"\n",
" audio_frames = torch.empty(0)\n",
" audio_pts = []\n",
" if read_audio:\n",
" video_object.set_current_stream(\"audio\")\n",
" frames = []\n",
" for t, pts in itertools.takewhile(lambda x: x[1] <= end, video_object.seek(start)):\n",
" frames.append(t)\n",
" video_pts.append(pts)\n",
" if len(frames) > 0:\n",
" audio_frames = torch.cat(frames, 0)\n",
"\n",
" return video_frames, audio_frames, (video_pts, audio_pts), video_object.get_metadata()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([327, 3, 256, 340]) torch.Size([523264, 1])\n"
]
}
],
"source": [
"vf, af, info, meta = example_read_video(video)\n",
"# total number of frames should be 327 for video and 523264 datapoints for audio\n",
"print(vf.size(), af.size())"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([523264, 1])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# you can also get the sequence of audio frames as well\n",
"af.size()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Building an example randomly sampled dataset (can be applied to training dataest of kinetics400)\n",
"\n",
"Cool, so now we can use the same principle to make the sample dataset. We suggest trying out iterable dataset for this purpose. \n",
"\n",
"Here, we are going to build\n",
"\n",
"a. an example dataset that reads randomly selected 10 frames of video"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# make sample dataest\n",
"import os\n",
"os.makedirs(\"./dataset\", exist_ok=True)\n",
"os.makedirs(\"./dataset/1\", exist_ok=True)\n",
"os.makedirs(\"./dataset/2\", exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading https://github.com/pytorch/vision/blob/master/test/assets/videos/WUzgd7C1pWA.mp4?raw=true to ./dataset/1/WUzgd7C1pWA.mp4\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100.4%"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading https://github.com/pytorch/vision/blob/master/test/assets/videos/RATRACE_wave_f_nm_np1_fr_goo_37.avi?raw=true to ./dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"102.5%"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading https://github.com/pytorch/vision/blob/master/test/assets/videos/SOX5yA1l24A.mp4?raw=true to ./dataset/2/SOX5yA1l24A.mp4\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100.9%"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading https://github.com/pytorch/vision/blob/master/test/assets/videos/v_SoccerJuggling_g23_c01.avi?raw=true to ./dataset/2/v_SoccerJuggling_g23_c01.avi\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"101.5%"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading https://github.com/pytorch/vision/blob/master/test/assets/videos/v_SoccerJuggling_g24_c01.avi?raw=true to ./dataset/2/v_SoccerJuggling_g24_c01.avi\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"101.3%"
]
}
],
"source": [
"# download the videos \n",
"from torchvision.datasets.utils import download_url\n",
"download_url(\"https://github.com/pytorch/vision/blob/master/test/assets/videos/WUzgd7C1pWA.mp4?raw=true\", \"./dataset/1\", \"WUzgd7C1pWA.mp4\")\n",
"download_url(\"https://github.com/pytorch/vision/blob/master/test/assets/videos/RATRACE_wave_f_nm_np1_fr_goo_37.avi?raw=true\", \"./dataset/1\", \"RATRACE_wave_f_nm_np1_fr_goo_37.avi\")\n",
"download_url(\"https://github.com/pytorch/vision/blob/master/test/assets/videos/SOX5yA1l24A.mp4?raw=true\", \"./dataset/2\", \"SOX5yA1l24A.mp4\")\n",
"download_url(\"https://github.com/pytorch/vision/blob/master/test/assets/videos/v_SoccerJuggling_g23_c01.avi?raw=true\", \"./dataset/2\", \"v_SoccerJuggling_g23_c01.avi\")\n",
"download_url(\"https://github.com/pytorch/vision/blob/master/test/assets/videos/v_SoccerJuggling_g24_c01.avi?raw=true\", \"./dataset/2\", \"v_SoccerJuggling_g24_c01.avi\")"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# housekeeping and utilities\n",
"import os\n",
"import random\n",
"\n",
"import torch\n",
"from torchvision.datasets.folder import make_dataset\n",
"from torchvision import transforms as t\n",
"\n",
"def _find_classes(dir):\n",
" classes = [d.name for d in os.scandir(dir) if d.is_dir()]\n",
" classes.sort()\n",
" class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}\n",
" return classes, class_to_idx\n",
"\n",
"def get_samples(root, extensions=(\".mp4\", \".avi\")):\n",
" _, class_to_idx = _find_classes(root)\n",
" return make_dataset(root, class_to_idx, extensions=extensions)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We are going to define the dataset and some basic arguments. We asume the structure of the FolderDataset, and add the following parameters:\n",
" \n",
"1. frame transform: with this API, we can chose to apply transforms on every frame of the video\n",
"2. videotransform: equally, we can also apply transform to a 4D tensor\n",
"3. length of the clip: do we want a single or multiple frames?\n",
"\n",
"Note that we actually add `epoch size` as using `IterableDataset` class allows us to naturally oversample clips or images from each video if needed. "
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"class RandomDataset(torch.utils.data.IterableDataset):\n",
" def __init__(self, root, epoch_size=None, frame_transform=None, video_transform=None, clip_len=16):\n",
" super(RandomDataset).__init__()\n",
" \n",
" self.samples = get_samples(root)\n",
" \n",
" # allow for temporal jittering\n",
" if epoch_size is None:\n",
" epoch_size = len(self.samples)\n",
" self.epoch_size = epoch_size\n",
" \n",
" self.clip_len = clip_len # length of a clip in frames\n",
" self.frame_transform = frame_transform # transform for every frame individually\n",
" self.video_transform = video_transform # transform on a video sequence\n",
"\n",
" def __iter__(self):\n",
" for i in range(self.epoch_size):\n",
" # get random sample\n",
" path, target = random.choice(self.samples)\n",
" # get video object\n",
" vid = torchvision.io.VideoReader(path, \"video\")\n",
" metadata = vid.get_metadata()\n",
" video_frames = [] # video frame buffer \n",
" # seek and return frames\n",
" \n",
" max_seek = metadata[\"video\"]['duration'][0] - (self.clip_len / metadata[\"video\"]['fps'][0])\n",
" start = random.uniform(0., max_seek)\n",
" for frame, current_pts in itertools.islice(vid.seek(start), self.clip_len):\n",
" video_frames.append(self.frame_transform(frame))\n",
" # stack it into a tensor\n",
" video = torch.stack(video_frames, 0)\n",
" if self.video_transform:\n",
" video = self.video_transform(video)\n",
" output = {\n",
" 'path': path,\n",
" 'video': video,\n",
" 'target': target,\n",
" 'start': start,\n",
" 'end': current_pts}\n",
" yield output"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Given a path of videos in a folder structure, i.e:\n",
"```\n",
"dataset:\n",
" -class 1:\n",
" file 0\n",
" file 1\n",
" ...\n",
" - class 2:\n",
" file 0\n",
" file 1\n",
" ...\n",
" - ...\n",
"```\n",
"We can generate a dataloader and test the dataset. \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"from torchvision import transforms as t\n",
"transforms = [t.Resize((112, 112))]\n",
"frame_transform = t.Compose(transforms)\n",
"\n",
"ds = RandomDataset(\"./dataset\", epoch_size=None, frame_transform=frame_transform)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"loader = DataLoader(ds, batch_size=12)\n",
"d = {\"video\":[], 'start':[], 'end':[], 'tensorsize':[]}\n",
"for b in loader:\n",
" for i in range(len(b['path'])):\n",
" d['video'].append(b['path'][i])\n",
" d['start'].append(b['start'][i].item())\n",
" d['end'].append(b['end'][i].item())\n",
" d['tensorsize'].append(b['video'][i].size())"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'video': ['./dataset/1/WUzgd7C1pWA.mp4',\n",
" './dataset/1/WUzgd7C1pWA.mp4',\n",
" './dataset/2/v_SoccerJuggling_g23_c01.avi',\n",
" './dataset/2/v_SoccerJuggling_g23_c01.avi',\n",
" './dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi'],\n",
" 'start': [8.97932147319667,\n",
" 9.421856461438313,\n",
" 2.1301381796579437,\n",
" 5.514273689529127,\n",
" 0.31979853297913124],\n",
" 'end': [9.5095, 9.943266999999999, 2.635967, 6.0393669999999995, 0.833333],\n",
" 'tensorsize': [torch.Size([16, 3, 112, 112]),\n",
" torch.Size([16, 3, 112, 112]),\n",
" torch.Size([16, 3, 112, 112]),\n",
" torch.Size([16, 3, 112, 112]),\n",
" torch.Size([16, 3, 112, 112])]}"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"d"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"## Cleanup\n",
"import os, shutil\n",
"os.remove(\"./WUzgd7C1pWA.mp4\")\n",
"shutil.rmtree(\"./dataset\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
# Optional list of dependencies required by the package
dependencies = ['torch']
# classification
from torchvision.models.alexnet import alexnet
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
from torchvision.models.inception import inception_v3
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152,\
resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from torchvision.models.googlenet import googlenet
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from torchvision.models.mobilenet import mobilenet_v2
from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
mnasnet1_3
# segmentation
from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \
deeplabv3_resnet50, deeplabv3_resnet101
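# Example usage (a minimal sketch): with the entry points above exposed through
# hubconf.py, any of these models can be instantiated via the torch.hub API, e.g.
#
#   import torch
#   model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)
#   model.eval()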