OpenDAS / vision · Commits

Commit fbb4cc54 (unverified)
Authored Sep 25, 2023 by Philip Meier; committed by GitHub on Sep 25, 2023
Parent: a90e5846

remove torchvision.prototype module and related tests / CI from release branch (#7983)

The full commit changes 71 files; this page shows 11 changed files, with 0 additions and 1564 deletions.
+0 −1564

torchvision/prototype/models/depth/stereo/raft_stereo.py   +0 −843
torchvision/prototype/transforms/__init__.py               +0 −6
torchvision/prototype/transforms/_augment.py               +0 −205
torchvision/prototype/transforms/_geometry.py              +0 −134
torchvision/prototype/transforms/_misc.py                  +0 −68
torchvision/prototype/transforms/_presets.py               +0 −80
torchvision/prototype/transforms/_type_conversion.py       +0 −29
torchvision/prototype/tv_tensors/__init__.py               +0 −1
torchvision/prototype/tv_tensors/_label.py                 +0 −71
torchvision/prototype/utils/__init__.py                    +0 −1
torchvision/prototype/utils/_internal.py                   +0 −126
torchvision/prototype/models/depth/stereo/raft_stereo.py  deleted 100644 → 0
from functools import partial
from typing import Callable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.optical_flow.raft as raft
from torch import Tensor
from torchvision.models._api import register_model, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface
from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow
from torchvision.models.optical_flow.raft import FlowHead, MotionEncoder, ResidualBlock
from torchvision.ops import Conv2dNormActivation
from torchvision.prototype.transforms._presets import StereoMatching
from torchvision.utils import _log_api_usage_once


__all__ = (
    "RaftStereo",
    "raft_stereo_base",
    "raft_stereo_realtime",
    "Raft_Stereo_Base_Weights",
    "Raft_Stereo_Realtime_Weights",
)


class BaseEncoder(raft.FeatureEncoder):
    """Base encoder for FeatureEncoder and ContextEncoder, in which weights may be shared.

    See section 4.6 of the RAFT-Stereo paper on the backbone.
    """

    def __init__(
        self,
        *,
        block: Callable[..., nn.Module] = ResidualBlock,
        layers: Tuple[int, int, int, int] = (64, 64, 96, 128),
        strides: Tuple[int, int, int, int] = (2, 1, 2, 2),
        norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d,
    ):
        # We use layers + (256,) because raft.FeatureEncoder requires 5 layers,
        # but here we set the last conv layer to identity
        super().__init__(block=block, layers=layers + (256,), strides=strides, norm_layer=norm_layer)
        # The base encoder doesn't have the last conv layer of the feature encoder
        self.conv = nn.Identity()

        self.output_dim = layers[3]
        num_downsampling = sum([x - 1 for x in strides])
        self.downsampling_ratio = 2 ** (num_downsampling)


class FeatureEncoder(nn.Module):
    """Feature Encoder for RAFT-Stereo (see paper section 3.1) that may share weights with the Context Encoder.

    The FeatureEncoder takes the concatenation of the left and right image as input. It produces feature
    embeddings that are later used to construct the correlation volume.
    """

    def __init__(
        self,
        base_encoder: BaseEncoder,
        output_dim: int = 256,
        shared_base: bool = False,
        block: Callable[..., nn.Module] = ResidualBlock,
    ):
        super().__init__()
        self.base_encoder = base_encoder
        self.base_downsampling_ratio = base_encoder.downsampling_ratio
        base_dim = base_encoder.output_dim

        if not shared_base:
            self.residual_block: nn.Module = nn.Identity()
            self.conv = nn.Conv2d(base_dim, output_dim, kernel_size=1)
        else:
            # If we share the base encoder weights between the Feature and Context Encoder,
            # we need to add a residual block with InstanceNorm2d and change the kernel size of the conv layer,
            # see: https://github.com/princeton-vl/RAFT-Stereo/blob/main/core/raft_stereo.py#L35-L37
            self.residual_block = block(base_dim, base_dim, norm_layer=nn.InstanceNorm2d, stride=1)
            self.conv = nn.Conv2d(base_dim, output_dim, kernel_size=3, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        x = self.base_encoder(x)
        x = self.residual_block(x)
        x = self.conv(x)
        return x


class MultiLevelContextEncoder(nn.Module):
    """Context Encoder for RAFT-Stereo (see paper section 3.1) that may share weights with the Feature Encoder.

    The ContextEncoder takes the left image as input and outputs the concatenation of hidden states and contexts.
    RAFT-Stereo uses multi-level GRUs, so this context encoder produces multiple outputs (a list of Tensors),
    one corresponding to each GRU level.
    Note that the length of the ``out_with_blocks`` parameter determines the number of GRU levels.

    Args:
        base_encoder (nn.Module): The base encoder part that can share weights with feature_encoder's
            base_encoder because they have the same architecture.
        out_with_blocks (List[bool]): Its length determines the number of GRU levels (the length of the output);
            if an element is True, the output layer at that position gets an additional block.
        output_dim (int): The dimension of the output on each level (default: 256).
        block (Callable[..., nn.Module]): The type of basic block used for the downsampling and output layers
            (default: ResidualBlock).
    """

    def __init__(
        self,
        base_encoder: nn.Module,
        out_with_blocks: List[bool],
        output_dim: int = 256,
        block: Callable[..., nn.Module] = ResidualBlock,
    ):
        super().__init__()
        self.num_level = len(out_with_blocks)
        self.base_encoder = base_encoder
        self.base_downsampling_ratio = base_encoder.downsampling_ratio
        base_dim = base_encoder.output_dim

        self.downsample_and_out_layers = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "downsampler": self._make_downsampler(block, base_dim, base_dim) if i > 0 else nn.Identity(),
                        "out_hidden_state": self._make_out_layer(
                            base_dim, output_dim // 2, with_block=out_with_blocks[i], block=block
                        ),
                        "out_context": self._make_out_layer(
                            base_dim, output_dim // 2, with_block=out_with_blocks[i], block=block
                        ),
                    }
                )
                for i in range(self.num_level)
            ]
        )

    def _make_out_layer(self, in_channels, out_channels, with_block=True, block=ResidualBlock):
        layers = []
        if with_block:
            layers.append(block(in_channels, in_channels, norm_layer=nn.BatchNorm2d, stride=1))
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        return nn.Sequential(*layers)

    def _make_downsampler(self, block, in_channels, out_channels):
        block1 = block(in_channels, out_channels, norm_layer=nn.BatchNorm2d, stride=2)
        block2 = block(out_channels, out_channels, norm_layer=nn.BatchNorm2d, stride=1)
        return nn.Sequential(block1, block2)

    def forward(self, x: Tensor) -> List[Tensor]:
        x = self.base_encoder(x)
        outs = []
        for layer_dict in self.downsample_and_out_layers:
            x = layer_dict["downsampler"](x)
            outs.append(torch.cat([layer_dict["out_hidden_state"](x), layer_dict["out_context"](x)], dim=1))
        return outs


class ConvGRU(raft.ConvGRU):
    """Convolutional GRU unit."""

    # Modified from raft.ConvGRU to accept pre-convolved contexts,
    # see: https://github.com/princeton-vl/RAFT-Stereo/blob/main/core/update.py#L23
    def forward(self, h: Tensor, x: Tensor, context: List[Tensor]) -> Tensor:  # type: ignore[override]
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz(hx) + context[0])
        r = torch.sigmoid(self.convr(hx) + context[1])
        q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)) + context[2])
        h = (1 - z) * h + z * q
        return h


class MultiLevelUpdateBlock(nn.Module):
    """The update block, which contains the motion encoder and the GRUs.

    It must expose a ``hidden_dims`` attribute, which is the hidden dimension size of its GRU blocks.
    """

    def __init__(self, *, motion_encoder: MotionEncoder, hidden_dims: List[int]):
        super().__init__()
        self.motion_encoder = motion_encoder

        # The GRU input size is the previous level's hidden_dim plus the next level's hidden_dim;
        # for the first GRU, we replace the previous level with the motion encoder's output channels,
        # and for the last GRU, we don't add the next level's hidden_dim
        gru_input_dims = []
        for i in range(len(hidden_dims)):
            input_dim = hidden_dims[i - 1] if i > 0 else motion_encoder.out_channels
            if i < len(hidden_dims) - 1:
                input_dim += hidden_dims[i + 1]
            gru_input_dims.append(input_dim)

        self.grus = nn.ModuleList(
            [
                ConvGRU(input_size=gru_input_dims[i], hidden_size=hidden_dims[i], kernel_size=3, padding=1)
                # Ideally we would reverse the direction during forward to use the GRU with the smallest resolution
                # first; however, there is currently no way to reverse a ModuleList that is jit-script compatible,
                # hence we reverse the ordering of self.grus in the constructor instead,
                # see: https://github.com/pytorch/pytorch/issues/31772
                for i in reversed(list(range(len(hidden_dims))))
            ]
        )

        self.hidden_dims = hidden_dims

    def forward(
        self,
        hidden_states: List[Tensor],
        contexts: List[List[Tensor]],
        corr_features: Tensor,
        disparity: Tensor,
        level_processed: List[bool],
    ) -> List[Tensor]:
        # We call it reverse_i because it has a reversed ordering compared to hidden_states;
        # see self.grus in the constructor for more detail
        for reverse_i, gru in enumerate(self.grus):
            i = len(self.grus) - 1 - reverse_i
            if level_processed[i]:
                # X is the concatenation of the 2x downsampled hidden_dim (or the motion features if there is
                # no bigger level) with the 2x upsampled hidden_dim (or nothing if it doesn't exist).
                if i == 0:
                    features = self.motion_encoder(disparity, corr_features)
                else:
                    # 2x downsampled features from the larger hidden states
                    features = F.avg_pool2d(hidden_states[i - 1], kernel_size=3, stride=2, padding=1)

                if i < len(self.grus) - 1:
                    # Concat with 2x upsampled features from the smaller hidden states
                    _, _, h, w = hidden_states[i + 1].shape
                    features = torch.cat(
                        [
                            features,
                            F.interpolate(
                                hidden_states[i + 1], size=(2 * h, 2 * w), mode="bilinear", align_corners=True
                            ),
                        ],
                        dim=1,
                    )

                hidden_states[i] = gru(hidden_states[i], features, contexts[i])

        # NOTE: For the slow-fast GRU, we don't always want to calculate the delta disparity on every call of the
        # update block; hence we move the delta disparity calculation to the RAFT-Stereo main forward
        return hidden_states


class MaskPredictor(raft.MaskPredictor):
    """Mask predictor to be used when upsampling the predicted disparity."""

    # We add out_channels compared to raft.MaskPredictor
    def __init__(self, *, in_channels: int, hidden_size: int, out_channels: int, multiplier: float = 0.25):
        super(raft.MaskPredictor, self).__init__()
        self.convrelu = Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3)
        self.conv = nn.Conv2d(hidden_size, out_channels, kernel_size=1, padding=0)
        self.multiplier = multiplier


class CorrPyramid1d(nn.Module):
    """Row-wise correlation pyramid.

    Create a row-wise correlation pyramid with ``num_levels`` levels from the outputs of the feature encoder;
    this correlation pyramid is later indexed to create correlation features using CorrBlock1d.
    """

    def __init__(self, num_levels: int = 4):
        super().__init__()
        self.num_levels = num_levels

    def forward(self, fmap1: Tensor, fmap2: Tensor) -> List[Tensor]:
        """Build the correlation pyramid from two feature maps.

        The correlation volume is first computed as the dot product of each pair (pixel_in_fmap1, pixel_in_fmap2)
        on the same row. The last 2 dimensions of the correlation volume are then pooled num_levels times at
        different resolutions to build the correlation pyramid.
        """
        torch._assert(
            fmap1.shape == fmap2.shape,
            f"Input feature maps should have the same shape, instead got {fmap1.shape} (fmap1.shape) != {fmap2.shape} (fmap2.shape)",
        )

        batch_size, num_channels, h, w = fmap1.shape
        fmap1 = fmap1.view(batch_size, num_channels, h, w)
        fmap2 = fmap2.view(batch_size, num_channels, h, w)

        corr = torch.einsum("aijk,aijh->ajkh", fmap1, fmap2)
        corr = corr.view(batch_size, h, w, 1, w)
        corr_volume = corr / torch.sqrt(torch.tensor(num_channels, device=corr.device))

        corr_volume = corr_volume.reshape(batch_size * h * w, 1, 1, w)
        corr_pyramid = [corr_volume]
        for _ in range(self.num_levels - 1):
            corr_volume = F.avg_pool2d(corr_volume, kernel_size=(1, 2), stride=(1, 2))
            corr_pyramid.append(corr_volume)
        return corr_pyramid


class CorrBlock1d(nn.Module):
    """The row-wise correlation block.

    Uses the indexes from the correlation pyramid to create correlation features.
    The "indexing" of a given centroid pixel x' is done by concatenating its surrounding row neighbours
    within a given radius.
    """

    def __init__(self, *, num_levels: int = 4, radius: int = 4):
        super().__init__()
        self.radius = radius
        self.out_channels = num_levels * (2 * radius + 1)

    def forward(self, centroids_coords: Tensor, corr_pyramid: List[Tensor]) -> Tensor:
        """Return correlation features by indexing from the pyramid."""
        neighborhood_side_len = 2 * self.radius + 1  # see note in __init__ about out_channels
        di = torch.linspace(-self.radius, self.radius, neighborhood_side_len, device=centroids_coords.device)
        di = di.view(1, 1, neighborhood_side_len, 1).to(centroids_coords.device)

        batch_size, _, h, w = centroids_coords.shape  # _ = 2 but we only use the first one
        # We only consider 1d and take the first dim only
        centroids_coords = centroids_coords[:, :1].permute(0, 2, 3, 1).reshape(batch_size * h * w, 1, 1, 1)

        indexed_pyramid = []
        for corr_volume in corr_pyramid:
            x0 = centroids_coords + di  # end shape is (batch_size * h * w, 1, side_len, 1)
            y0 = torch.zeros_like(x0)
            sampling_coords = torch.cat([x0, y0], dim=-1)
            indexed_corr_volume = grid_sample(corr_volume, sampling_coords, align_corners=True, mode="bilinear").view(
                batch_size, h, w, -1
            )
            indexed_pyramid.append(indexed_corr_volume)
            centroids_coords = centroids_coords / 2

        corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous()

        expected_output_shape = (batch_size, self.out_channels, h, w)
        torch._assert(
            corr_features.shape == expected_output_shape,
            f"Output shape of index pyramid is incorrect. Should be {expected_output_shape}, got {corr_features.shape}",
        )
        return corr_features


class RaftStereo(nn.Module):
    def __init__(
        self,
        *,
        feature_encoder: FeatureEncoder,
        context_encoder: MultiLevelContextEncoder,
        corr_pyramid: CorrPyramid1d,
        corr_block: CorrBlock1d,
        update_block: MultiLevelUpdateBlock,
        disparity_head: nn.Module,
        mask_predictor: Optional[nn.Module] = None,
        slow_fast: bool = False,
    ):
        """RAFT-Stereo model from
        `RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_.

        Args:
            feature_encoder (FeatureEncoder): The feature encoder. Its input is the concatenation of ``left_image`` and ``right_image``.
            context_encoder (MultiLevelContextEncoder): The context encoder. Its input is ``left_image``.
                It has a multi-level output and each level has 2 parts:

                - one part is used as the actual "context", passed to the recurrent unit of the ``update_block``
                - one part is used to initialize the hidden state of the recurrent unit of the ``update_block``

            corr_pyramid (CorrPyramid1d): Module that builds the correlation pyramid from the feature encoder output.
            corr_block (CorrBlock1d): The correlation block, which uses the correlation pyramid indexes
                to create correlation features. It takes the coordinates of the centroid pixels and the correlation
                pyramid as input and returns the correlation features.
                It must expose an ``out_channels`` attribute.
            update_block (MultiLevelUpdateBlock): The update block, which contains the motion encoder and the recurrent unit.
                It takes as input the hidden state of its recurrent unit, the context, the correlation
                features, and the current predicted disparity. It outputs an updated hidden state.
            disparity_head (nn.Module): The disparity head block converts the hidden state into changes in disparity.
            mask_predictor (nn.Module, optional): Predicts the mask that will be used to upsample the predicted flow.
                If ``None`` (default), the flow is upsampled using interpolation.
            slow_fast (bool): Whether to use the slow-fast GRU. See section 3.4 of the RAFT-Stereo paper
                for more detail.
        """
        super().__init__()
        _log_api_usage_once(self)

        # This indicates that the disparity output will only have 1 channel (representing the horizontal axis).
        # We need this because some stereo matching models like CREStereo have 2 channels in the output
        self.output_channels = 1

        self.feature_encoder = feature_encoder
        self.context_encoder = context_encoder
        self.base_downsampling_ratio = feature_encoder.base_downsampling_ratio
        self.num_level = self.context_encoder.num_level
        self.corr_pyramid = corr_pyramid
        self.corr_block = corr_block
        self.update_block = update_block
        self.disparity_head = disparity_head
        self.mask_predictor = mask_predictor

        hidden_dims = self.update_block.hidden_dims
        # Follow the original implementation and pre-convolve the context,
        # see: https://github.com/princeton-vl/RAFT-Stereo/blob/main/core/raft_stereo.py#L32
        self.context_convs = nn.ModuleList(
            [nn.Conv2d(hidden_dims[i], hidden_dims[i] * 3, kernel_size=3, padding=1) for i in range(self.num_level)]
        )
        self.slow_fast = slow_fast

    def forward(
        self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 12
    ) -> List[Tensor]:
        """
        Return the disparity predictions on every iteration as a list of Tensors.

        Args:
            left_image (Tensor): The input left image with layout B, C, H, W
            right_image (Tensor): The input right image with layout B, C, H, W
            flow_init (Optional[Tensor]): Initial estimate for the disparity. Default: None
            num_iters (int): Number of update block iterations on the largest resolution. Default: 12
        """
        batch_size, _, h, w = left_image.shape
        torch._assert(
            (h, w) == right_image.shape[-2:],
            f"input images should have the same shape, instead got ({h}, {w}) != {right_image.shape[-2:]}",
        )
        torch._assert(
            (h % self.base_downsampling_ratio == 0 and w % self.base_downsampling_ratio == 0),
            f"input image H and W should be divisible by {self.base_downsampling_ratio}, instead got H={h} and W={w}",
        )

        fmaps = self.feature_encoder(torch.cat([left_image, right_image], dim=0))
        fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
        torch._assert(
            fmap1.shape[-2:] == (h // self.base_downsampling_ratio, w // self.base_downsampling_ratio),
            f"The feature encoder should downsample H and W by {self.base_downsampling_ratio}",
        )

        corr_pyramid = self.corr_pyramid(fmap1, fmap2)

        # Multi level contexts
        context_outs = self.context_encoder(left_image)

        hidden_dims = self.update_block.hidden_dims
        context_out_channels = [context_outs[i].shape[1] - hidden_dims[i] for i in range(len(context_outs))]
        hidden_states: List[Tensor] = []
        contexts: List[List[Tensor]] = []
        for i, context_conv in enumerate(self.context_convs):
            # As in the original paper, the actual output of the context encoder is split in 2 parts:
            # - one part is used to initialize the hidden state of the recurrent units of the update block
            # - the rest is the "actual" context.
            hidden_state, context = torch.split(context_outs[i], [hidden_dims[i], context_out_channels[i]], dim=1)
            hidden_states.append(torch.tanh(hidden_state))
            contexts.append(
                # mypy is technically correct here. The return type of `torch.split` was incorrectly annotated with
                # `List[int]` although it should have been `Tuple[Tensor, ...]`. However, the latter is not supported by
                # JIT and thus we have to keep the wrong annotation here and silence mypy.
                torch.split(  # type: ignore[arg-type]
                    context_conv(F.relu(context)), [hidden_dims[i], hidden_dims[i], hidden_dims[i]], dim=1
                )
            )

        _, Cf, Hf, Wf = fmap1.shape
        coords0 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device)
        coords1 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device)

        # We use flow_init for cascade inference
        if flow_init is not None:
            coords1 = coords1 + flow_init

        disparity_predictions = []
        for _ in range(num_iters):
            coords1 = coords1.detach()  # Don't backpropagate gradients through this branch, see paper
            corr_features = self.corr_block(centroids_coords=coords1, corr_pyramid=corr_pyramid)

            disparity = coords1 - coords0
            if self.slow_fast:
                # Using the slow-fast GRU (see paper section 3.4). The lower resolutions are processed more often
                for i in range(1, self.num_level):
                    # We only process the smallest i levels
                    level_processed = [False] * (self.num_level - i) + [True] * i
                    hidden_states = self.update_block(
                        hidden_states, contexts, corr_features, disparity, level_processed=level_processed
                    )
            hidden_states = self.update_block(
                hidden_states, contexts, corr_features, disparity, level_processed=[True] * self.num_level
            )

            # Take the largest hidden_state to get the disparity
            hidden_state = hidden_states[0]
            delta_disparity = self.disparity_head(hidden_state)
            # in stereo mode, project disparity onto epipolar
            delta_disparity[:, 1] = 0.0

            coords1 = coords1 + delta_disparity
            up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state)
            upsampled_disparity = upsample_flow(
                (coords1 - coords0), up_mask=up_mask, factor=self.base_downsampling_ratio
            )
            disparity_predictions.append(upsampled_disparity[:, :1])

        return disparity_predictions


def _raft_stereo(
    *,
    weights: Optional[WeightsEnum],
    progress: bool,
    shared_encoder_weight: bool,
    # Feature encoder
    feature_encoder_layers: Tuple[int, int, int, int, int],
    feature_encoder_strides: Tuple[int, int, int, int],
    feature_encoder_block: Callable[..., nn.Module],
    # Context encoder
    context_encoder_layers: Tuple[int, int, int, int, int],
    context_encoder_strides: Tuple[int, int, int, int],
    # if an element of `out_with_blocks` of the context_encoder is True, then
    # the output at that level position gets an additional `context_encoder_block` layer
    context_encoder_out_with_blocks: List[bool],
    context_encoder_block: Callable[..., nn.Module],
    # Correlation block
    corr_num_levels: int,
    corr_radius: int,
    # Motion encoder
    motion_encoder_corr_layers: Tuple[int, int],
    motion_encoder_flow_layers: Tuple[int, int],
    motion_encoder_out_channels: int,
    # Update block
    update_block_hidden_dims: List[int],
    # Flow Head
    flow_head_hidden_size: int,
    # Mask predictor
    mask_predictor_hidden_size: int,
    use_mask_predictor: bool,
    slow_fast: bool,
    **kwargs,
):
    if len(context_encoder_out_with_blocks) != len(update_block_hidden_dims):
        raise ValueError(
            "Length of context_encoder_out_with_blocks and update_block_hidden_dims must be the same "
            "because both of them represent the number of GRU levels"
        )

    if shared_encoder_weight:
        if (
            feature_encoder_layers[:-1] != context_encoder_layers[:-1]
            or feature_encoder_strides != context_encoder_strides
        ):
            raise ValueError(
                "If shared_encoder_weight is True, then feature_encoder_layers[:-1]"
                " and feature_encoder_strides must be the same as context_encoder_layers[:-1] and context_encoder_strides!"
            )

        base_encoder = kwargs.pop("base_encoder", None) or BaseEncoder(
            block=context_encoder_block,
            layers=context_encoder_layers[:-1],
            strides=context_encoder_strides,
            norm_layer=nn.BatchNorm2d,
        )
        feature_base_encoder = base_encoder
        context_base_encoder = base_encoder
    else:
        feature_base_encoder = BaseEncoder(
            block=feature_encoder_block,
            layers=feature_encoder_layers[:-1],
            strides=feature_encoder_strides,
            norm_layer=nn.InstanceNorm2d,
        )
        context_base_encoder = BaseEncoder(
            block=context_encoder_block,
            layers=context_encoder_layers[:-1],
            strides=context_encoder_strides,
            norm_layer=nn.BatchNorm2d,
        )

    feature_encoder = kwargs.pop("feature_encoder", None) or FeatureEncoder(
        feature_base_encoder,
        output_dim=feature_encoder_layers[-1],
        shared_base=shared_encoder_weight,
        block=feature_encoder_block,
    )
    context_encoder = kwargs.pop("context_encoder", None) or MultiLevelContextEncoder(
        context_base_encoder,
        out_with_blocks=context_encoder_out_with_blocks,
        output_dim=context_encoder_layers[-1],
        block=context_encoder_block,
    )

    feature_downsampling_ratio = feature_encoder.base_downsampling_ratio

    corr_pyramid = kwargs.pop("corr_pyramid", None) or CorrPyramid1d(num_levels=corr_num_levels)
    corr_block = kwargs.pop("corr_block", None) or CorrBlock1d(num_levels=corr_num_levels, radius=corr_radius)

    motion_encoder = kwargs.pop("motion_encoder", None) or MotionEncoder(
        in_channels_corr=corr_block.out_channels,
        corr_layers=motion_encoder_corr_layers,
        flow_layers=motion_encoder_flow_layers,
        out_channels=motion_encoder_out_channels,
    )
    update_block = kwargs.pop("update_block", None) or MultiLevelUpdateBlock(
        motion_encoder=motion_encoder, hidden_dims=update_block_hidden_dims
    )

    # We use the largest scale hidden_dims of the update_block to get the predicted disparity
    disparity_head = kwargs.pop("disparity_head", None) or FlowHead(
        in_channels=update_block_hidden_dims[0],
        hidden_size=flow_head_hidden_size,
    )

    mask_predictor = kwargs.pop("mask_predictor", None)
    if use_mask_predictor:
        mask_predictor = MaskPredictor(
            in_channels=update_block.hidden_dims[0],
            hidden_size=mask_predictor_hidden_size,
            out_channels=9 * feature_downsampling_ratio * feature_downsampling_ratio,
        )
    else:
        mask_predictor = None

    model = RaftStereo(
        feature_encoder=feature_encoder,
        context_encoder=context_encoder,
        corr_pyramid=corr_pyramid,
        corr_block=corr_block,
        update_block=update_block,
        disparity_head=disparity_head,
        mask_predictor=mask_predictor,
        slow_fast=slow_fast,
        **kwargs,  # not really needed, all params should be consumed by now
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


class Raft_Stereo_Realtime_Weights(WeightsEnum):
    SCENEFLOW_V1 = Weights(
        # Weights ported from https://github.com/princeton-vl/RAFT-Stereo
        url="https://download.pytorch.org/models/raft_stereo_realtime-cf345ccb.pth",
        transforms=partial(StereoMatching, resize_size=(224, 224)),
        meta={
            "num_params": 8077152,
            "recipe": "https://github.com/princeton-vl/RAFT-Stereo",
            "_metrics": {
                # Following metrics from paper: https://arxiv.org/abs/2109.07547
                "Kitty2015": {
                    "3px": 0.9409,
                }
            },
        },
    )

    DEFAULT = SCENEFLOW_V1


class Raft_Stereo_Base_Weights(WeightsEnum):
    SCENEFLOW_V1 = Weights(
        # Weights ported from https://github.com/princeton-vl/RAFT-Stereo
        url="https://download.pytorch.org/models/raft_stereo_base_sceneflow-eff3f2e6.pth",
        transforms=partial(StereoMatching, resize_size=(224, 224)),
        meta={
            "num_params": 11116176,
            "recipe": "https://github.com/princeton-vl/RAFT-Stereo",
            "_metrics": {
                # Following metrics from paper: https://arxiv.org/abs/2109.07547
                # Using standard metrics for each dataset
                "Kitty2015": {
                    # Ratio of pixels with a difference of less than 3px from ground truth
                    "3px": 0.9426,
                },
                # For Middlebury, ratio of pixels with a difference of less than 2px from ground truth
                # on full, half, and quarter image resolution
                "Middlebury2014-val-full": {
                    "2px": 0.8167,
                },
                "Middlebury2014-val-half": {
                    "2px": 0.8741,
                },
                "Middlebury2014-val-quarter": {
                    "2px": 0.9064,
                },
                "ETH3D-val": {
                    # Ratio of pixels with a difference of less than 1px from ground truth
                    "1px": 0.9672,
                },
            },
        },
    )

    MIDDLEBURY_V1 = Weights(
        # Weights ported from https://github.com/princeton-vl/RAFT-Stereo
        url="https://download.pytorch.org/models/raft_stereo_base_middlebury-afa9d252.pth",
        transforms=partial(StereoMatching, resize_size=(224, 224)),
        meta={
            "num_params": 11116176,
            "recipe": "https://github.com/princeton-vl/RAFT-Stereo",
            "_metrics": {
                # Following metrics from paper: https://arxiv.org/abs/2109.07547
                "Middlebury-test": {
                    "mae": 1.27,
                    "1px": 0.9063,
                    "2px": 0.9526,
                    "5px": 0.9725,
                }
            },
        },
    )

    ETH3D_V1 = Weights(
        # Weights ported from https://github.com/princeton-vl/RAFT-Stereo
        url="https://download.pytorch.org/models/raft_stereo_base_eth3d-d4830f22.pth",
        transforms=partial(StereoMatching, resize_size=(224, 224)),
        meta={
            "num_params": 11116176,
            "recipe": "https://github.com/princeton-vl/RAFT-Stereo",
            "_metrics": {
                # Following metrics from paper: https://arxiv.org/abs/2109.07547
                "ETH3D-test": {
                    "mae": 0.18,
                    "1px": 0.9756,
                    "2px": 0.9956,
                }
            },
        },
    )

    DEFAULT = MIDDLEBURY_V1


@register_model()
@handle_legacy_interface(weights=("pretrained", None))
def raft_stereo_realtime(
    *, weights: Optional[Raft_Stereo_Realtime_Weights] = None, progress=True, **kwargs
) -> RaftStereo:
    """RAFT-Stereo model from
    `RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_.
    This is the realtime variant of the RAFT-Stereo model, described in section 4.7 of the paper.

    Please see the example below for a tutorial on how to use this model.

    Args:
        weights(:class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Realtime_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Realtime_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.prototype.models.depth.stereo.raft_stereo.RaftStereo``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
            for more details about this class.

    .. autoclass:: torchvision.prototype.models.depth.stereo.Raft_Stereo_Realtime_Weights
        :members:
    """
    weights = Raft_Stereo_Realtime_Weights.verify(weights)

    return _raft_stereo(
        weights=weights,
        progress=progress,
        shared_encoder_weight=True,
        # Feature encoder
        feature_encoder_layers=(64, 64, 96, 128, 256),
        feature_encoder_strides=(2, 1, 2, 2),
        feature_encoder_block=ResidualBlock,
        # Context encoder
        context_encoder_layers=(64, 64, 96, 128, 256),
        context_encoder_strides=(2, 1, 2, 2),
        context_encoder_out_with_blocks=[True, True],
        context_encoder_block=ResidualBlock,
        # Correlation block
        corr_num_levels=4,
        corr_radius=4,
        # Motion encoder
        motion_encoder_corr_layers=(64, 64),
        motion_encoder_flow_layers=(64, 64),
        motion_encoder_out_channels=128,
        # Update block
        update_block_hidden_dims=[128, 128],
        # Flow head
        flow_head_hidden_size=256,
        # Mask predictor
        mask_predictor_hidden_size=256,
        use_mask_predictor=True,
        slow_fast=True,
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", None))
def raft_stereo_base(*, weights: Optional[Raft_Stereo_Base_Weights] = None, progress=True, **kwargs) -> RaftStereo:
    """RAFT-Stereo model from
    `RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_.

    Please see the example below for a tutorial on how to use this model.

    Args:
        weights(:class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Base_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Base_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.prototype.models.depth.stereo.raft_stereo.RaftStereo``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
            for more details about this class.

    .. autoclass:: torchvision.prototype.models.depth.stereo.Raft_Stereo_Base_Weights
        :members:
    """
    weights = Raft_Stereo_Base_Weights.verify(weights)

    return _raft_stereo(
        weights=weights,
        progress=progress,
        shared_encoder_weight=False,
        # Feature encoder
        feature_encoder_layers=(64, 64, 96, 128, 256),
        feature_encoder_strides=(1, 1, 2, 2),
        feature_encoder_block=ResidualBlock,
        # Context encoder
        context_encoder_layers=(64, 64, 96, 128, 256),
        context_encoder_strides=(1, 1, 2, 2),
        context_encoder_out_with_blocks=[True, True, False],
        context_encoder_block=ResidualBlock,
        # Correlation block
        corr_num_levels=4,
        corr_radius=4,
        # Motion encoder
        motion_encoder_corr_layers=(64, 64),
        motion_encoder_flow_layers=(64, 64),
        motion_encoder_out_channels=128,
        # Update block
        update_block_hidden_dims=[128, 128, 128],
        # Flow head
        flow_head_hidden_size=256,
        # Mask predictor
        mask_predictor_hidden_size=256,
        use_mask_predictor=True,
        slow_fast=False,
        **kwargs,
    )
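For reference, a minimal usage sketch of the removed builder, assuming a torchvision build from before this commit (the prototype module no longer exists afterwards); the input size is illustrative and only needs to be divisible by the model's downsampling ratio (4 for raft_stereo_base, given its (1, 1, 2, 2) strides):

import torch

# Assumes a pre-removal torchvision build where torchvision.prototype still ships this module.
from torchvision.prototype.models.depth.stereo.raft_stereo import raft_stereo_base

model = raft_stereo_base(weights=None).eval()

# H and W must be divisible by the model's base downsampling ratio (4 here).
left = torch.rand(1, 3, 256, 512)
right = torch.rand(1, 3, 256, 512)

with torch.no_grad():
    # One disparity prediction per update iteration; the last one is the most refined.
    predictions = model(left, right, num_iters=12)

print(len(predictions), predictions[-1].shape)  # 12 torch.Size([1, 1, 256, 512])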
torchvision/prototype/transforms/__init__.py  deleted 100644 → 0
from ._presets import StereoMatching  # usort: skip

from ._augment import SimpleCopyPaste
from ._geometry import FixedSizeCrop
from ._misc import PermuteDimensions, TransposeDimensions
from ._type_conversion import LabelToOneHot
torchvision/prototype/transforms/_augment.py  deleted 100644 → 0
from typing import Any, cast, Dict, List, Optional, Tuple, Union

import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import tv_tensors
from torchvision.ops import masks_to_boxes
from torchvision.prototype import tv_tensors as proto_tv_tensors
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
from torchvision.transforms.v2._utils import is_pure_tensor
from torchvision.transforms.v2.functional._geometry import _check_interpolation


class SimpleCopyPaste(Transform):
    def __init__(
        self,
        blending: bool = True,
        resize_interpolation: Union[int, InterpolationMode] = F.InterpolationMode.BILINEAR,
        antialias: Optional[bool] = None,
    ) -> None:
        super().__init__()
        self.resize_interpolation = _check_interpolation(resize_interpolation)
        self.blending = blending
        self.antialias = antialias

    def _copy_paste(
        self,
        image: Union[torch.Tensor, tv_tensors.Image],
        target: Dict[str, Any],
        paste_image: Union[torch.Tensor, tv_tensors.Image],
        paste_target: Dict[str, Any],
        random_selection: torch.Tensor,
        blending: bool,
        resize_interpolation: F.InterpolationMode,
        antialias: Optional[bool],
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:

        paste_masks = tv_tensors.wrap(paste_target["masks"][random_selection], like=paste_target["masks"])
        paste_boxes = tv_tensors.wrap(paste_target["boxes"][random_selection], like=paste_target["boxes"])
        paste_labels = tv_tensors.wrap(paste_target["labels"][random_selection], like=paste_target["labels"])

        masks = target["masks"]

        # We resize source and paste data if they have different sizes.
        # This is something we introduced here that differs from the TF implementation, as
        # originally the algorithm works on equal-sized data
        # (for example, coming from LSJ data augmentations)
        size1 = cast(List[int], image.shape[-2:])
        size2 = paste_image.shape[-2:]
        if size1 != size2:
            paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation, antialias=antialias)
            paste_masks = F.resize(paste_masks, size=size1)
            paste_boxes = F.resize(paste_boxes, size=size1)

        paste_alpha_mask = paste_masks.sum(dim=0) > 0

        if blending:
            paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0])

        inverse_paste_alpha_mask = paste_alpha_mask.logical_not()
        # Copy-paste images:
        image = image.mul(inverse_paste_alpha_mask).add_(paste_image.mul(paste_alpha_mask))

        # Copy-paste masks:
        masks = masks * inverse_paste_alpha_mask
        non_all_zero_masks = masks.sum((-1, -2)) > 0
        masks = masks[non_all_zero_masks]

        # Do a shallow copy of the target dict
        out_target = {k: v for k, v in target.items()}

        out_target["masks"] = torch.cat([masks, paste_masks])

        # Copy-paste boxes and labels
        bbox_format = target["boxes"].format
        xyxy_boxes = masks_to_boxes(masks)
        # masks_to_boxes produces bboxes with x2y2 inclusive, but x2y2 should be exclusive,
        # so we need to add +1 to x2y2.
        # There is a similar +1 in other reference implementations:
        # https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
        xyxy_boxes[:, 2:] += 1
        boxes = F.convert_bounding_box_format(
            xyxy_boxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True
        )
        out_target["boxes"] = torch.cat([boxes, paste_boxes])

        labels = target["labels"][non_all_zero_masks]
        out_target["labels"] = torch.cat([labels, paste_labels])

        # Check for degenerate boxes and remove them
        boxes = F.convert_bounding_box_format(
            out_target["boxes"], old_format=bbox_format, new_format=tv_tensors.BoundingBoxFormat.XYXY
        )
        degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
        if degenerate_boxes.any():
            valid_targets = ~degenerate_boxes.any(dim=1)

            out_target["boxes"] = boxes[valid_targets]
            out_target["masks"] = out_target["masks"][valid_targets]
            out_target["labels"] = out_target["labels"][valid_targets]

        return image, out_target

    def _extract_image_targets(
        self, flat_sample: List[Any]
    ) -> Tuple[List[Union[torch.Tensor, tv_tensors.Image]], List[Dict[str, Any]]]:
        # fetch all images, bboxes, masks and labels from the unstructured input
        # as List[image], List[BoundingBoxes], List[Mask], List[Label]
        images, bboxes, masks, labels = [], [], [], []
        for obj in flat_sample:
            if isinstance(obj, tv_tensors.Image) or is_pure_tensor(obj):
                images.append(obj)
            elif isinstance(obj, PIL.Image.Image):
                images.append(F.to_image(obj))
            elif isinstance(obj, tv_tensors.BoundingBoxes):
                bboxes.append(obj)
            elif isinstance(obj, tv_tensors.Mask):
                masks.append(obj)
            elif isinstance(obj, (proto_tv_tensors.Label, proto_tv_tensors.OneHotLabel)):
                labels.append(obj)

        if not (len(images) == len(bboxes) == len(masks) == len(labels)):
            raise TypeError(
                f"{type(self).__name__}() requires input sample to contain equal-sized lists of Images, "
                "BoundingBoxes, Masks and Labels or OneHotLabels."
            )

        targets = []
        for bbox, mask, label in zip(bboxes, masks, labels):
            targets.append({"boxes": bbox, "masks": mask, "labels": label})

        return images, targets

    def _insert_outputs(
        self,
        flat_sample: List[Any],
        output_images: List[torch.Tensor],
        output_targets: List[Dict[str, Any]],
    ) -> None:
        c0, c1, c2, c3 = 0, 0, 0, 0
        for i, obj in enumerate(flat_sample):
            if isinstance(obj, tv_tensors.Image):
                flat_sample[i] = tv_tensors.wrap(output_images[c0], like=obj)
                c0 += 1
            elif isinstance(obj, PIL.Image.Image):
                flat_sample[i] = F.to_pil_image(output_images[c0])
                c0 += 1
            elif is_pure_tensor(obj):
                flat_sample[i] = output_images[c0]
                c0 += 1
            elif isinstance(obj, tv_tensors.BoundingBoxes):
                flat_sample[i] = tv_tensors.wrap(output_targets[c1]["boxes"], like=obj)
                c1 += 1
            elif isinstance(obj, tv_tensors.Mask):
                flat_sample[i] = tv_tensors.wrap(output_targets[c2]["masks"], like=obj)
                c2 += 1
            elif isinstance(obj, (proto_tv_tensors.Label, proto_tv_tensors.OneHotLabel)):
                flat_sample[i] = tv_tensors.wrap(output_targets[c3]["labels"], like=obj)
                c3 += 1

    def forward(self, *inputs: Any) -> Any:
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])

        images, targets = self._extract_image_targets(flat_inputs)

        # images = [t1, t2, ..., tN]
        # Let's define paste_images as a shifted list of the input images
        # paste_images = [t2, t3, ..., tN, t1]
        # FYI: in TF they mix data on the dataset level
        images_rolled = images[-1:] + images[:-1]
        targets_rolled = targets[-1:] + targets[:-1]

        output_images, output_targets = [], []

        for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
            # Random paste targets selection:
            num_masks = len(paste_target["masks"])

            if num_masks < 1:
                # Such a degenerate case with num_masks=0 can happen with LSJ
                # Let's just return (image, target)
                output_image, output_target = image, target
            else:
                random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
                random_selection = torch.unique(random_selection)

                output_image, output_target = self._copy_paste(
                    image,
                    target,
                    paste_image,
                    paste_target,
                    random_selection=random_selection,
                    blending=self.blending,
                    resize_interpolation=self.resize_interpolation,
                    antialias=self.antialias,
                )
            output_images.append(output_image)
            output_targets.append(output_target)

        # Insert the updated images and targets into the input flat_sample
        self._insert_outputs(flat_inputs, output_images, output_targets)

        return tree_unflatten(flat_inputs, spec)
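For reference, a minimal usage sketch of the removed SimpleCopyPaste transform, assuming a pre-removal torchvision build; the sample data (image size, mask layout, labels) is made up for illustration, and blending is disabled here to keep the sketch self-contained (by default the pasted region is blended via a Gaussian-blurred alpha mask):

import torch

from torchvision import tv_tensors
from torchvision.ops import masks_to_boxes
from torchvision.prototype import tv_tensors as proto_tv_tensors
from torchvision.prototype.transforms import SimpleCopyPaste


def make_sample(offset: int):
    # One image with two square instance masks; boxes are derived from the masks.
    image = tv_tensors.Image(torch.rand(3, 128, 128))
    masks = torch.zeros(2, 128, 128, dtype=torch.uint8)
    masks[0, offset : offset + 30, offset : offset + 30] = 1
    masks[1, 90:120, 40:70] = 1
    masks = tv_tensors.Mask(masks)
    boxes = tv_tensors.BoundingBoxes(masks_to_boxes(masks), format="XYXY", canvas_size=(128, 128))
    labels = proto_tv_tensors.Label([1, 2])
    return image, {"boxes": boxes, "masks": masks, "labels": labels}


# Each sample receives instances pasted from the next sample in the (rolled) batch.
transform = SimpleCopyPaste(blending=False)
(image0, target0), (image1, target1) = transform([make_sample(10), make_sample(30)])
print(image0.shape, target0["masks"].shape, target0["boxes"].shape)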
torchvision/prototype/transforms/_geometry.py  deleted 100644 → 0
from typing import Any, Dict, List, Optional, Sequence, Type, Union

import PIL.Image
import torch
from torchvision import tv_tensors
from torchvision.prototype.tv_tensors import Label, OneHotLabel
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2._utils import (
    _FillType,
    _get_fill,
    _setup_fill_arg,
    _setup_size,
    get_bounding_boxes,
    has_any,
    is_pure_tensor,
    query_size,
)


class FixedSizeCrop(Transform):
    def __init__(
        self,
        size: Union[int, Sequence[int]],
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
        padding_mode: str = "constant",
    ) -> None:
        super().__init__()
        size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
        self.crop_height = size[0]
        self.crop_width = size[1]

        self.fill = fill
        self._fill = _setup_fill_arg(fill)

        self.padding_mode = padding_mode

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        if not has_any(
            flat_inputs,
            PIL.Image.Image,
            tv_tensors.Image,
            is_pure_tensor,
            tv_tensors.Video,
        ):
            raise TypeError(
                f"{type(self).__name__}() requires input sample to contain a tensor or PIL image or a Video."
            )

        if has_any(flat_inputs, tv_tensors.BoundingBoxes) and not has_any(flat_inputs, Label, OneHotLabel):
            raise TypeError(
                f"If a BoundingBoxes is contained in the input sample, "
                f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
            )

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        height, width = query_size(flat_inputs)
        new_height = min(height, self.crop_height)
        new_width = min(width, self.crop_width)

        needs_crop = new_height != height or new_width != width

        offset_height = max(height - self.crop_height, 0)
        offset_width = max(width - self.crop_width, 0)

        r = torch.rand(1)
        top = int(offset_height * r)
        left = int(offset_width * r)

        bounding_boxes: Optional[torch.Tensor]
        try:
            bounding_boxes = get_bounding_boxes(flat_inputs)
        except ValueError:
            bounding_boxes = None

        if needs_crop and bounding_boxes is not None:
            format = bounding_boxes.format
            bounding_boxes, canvas_size = F.crop_bounding_boxes(
                bounding_boxes.as_subclass(torch.Tensor),
                format=format,
                top=top,
                left=left,
                height=new_height,
                width=new_width,
            )
            bounding_boxes = F.clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size)
            height_and_width = F.convert_bounding_box_format(
                bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYWH
            )[..., 2:]
            is_valid = torch.all(height_and_width > 0, dim=-1)
        else:
            is_valid = None

        pad_bottom = max(self.crop_height - new_height, 0)
        pad_right = max(self.crop_width - new_width, 0)

        needs_pad = pad_bottom != 0 or pad_right != 0

        return dict(
            needs_crop=needs_crop,
            top=top,
            left=left,
            height=new_height,
            width=new_width,
            is_valid=is_valid,
            padding=[0, 0, pad_right, pad_bottom],
            needs_pad=needs_pad,
        )

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if params["needs_crop"]:
            inpt = self._call_kernel(
                F.crop,
                inpt,
                top=params["top"],
                left=params["left"],
                height=params["height"],
                width=params["width"],
            )

        if params["is_valid"] is not None:
            if isinstance(inpt, (Label, OneHotLabel, tv_tensors.Mask)):
                inpt = tv_tensors.wrap(inpt[params["is_valid"]], like=inpt)
            elif isinstance(inpt, tv_tensors.BoundingBoxes):
                inpt = tv_tensors.wrap(
                    F.clamp_bounding_boxes(
                        inpt[params["is_valid"]], format=inpt.format, canvas_size=inpt.canvas_size
                    ),
                    like=inpt,
                )

        if params["needs_pad"]:
            fill = _get_fill(self._fill, type(inpt))
            inpt = self._call_kernel(F.pad, inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)

        return inpt
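For reference, a minimal usage sketch of the removed FixedSizeCrop, assuming a pre-removal torchvision build; the input size is arbitrary:

import torch

from torchvision import tv_tensors
from torchvision.prototype.transforms import FixedSizeCrop

transform = FixedSizeCrop(size=(224, 224), fill=0)
image = tv_tensors.Image(torch.rand(3, 200, 300))

# Width (300 > 224) gets a random crop; height (200 < 224) is padded at the bottom.
out = transform(image)
print(out.shape)  # torch.Size([3, 224, 224])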
torchvision/prototype/transforms/_misc.py  deleted 100644 → 0
import functools
import warnings
from collections import defaultdict
from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2._utils import is_pure_tensor

T = TypeVar("T")


def _default_arg(value: T) -> T:
    return value


def _get_defaultdict(default: T) -> Dict[Any, T]:
    # This weird-looking construct only exists because `lambda`'s cannot be serialized by pickle.
    # If it were possible, we could replace this with `defaultdict(lambda: default)`
    return defaultdict(functools.partial(_default_arg, default))


class PermuteDimensions(Transform):
    _transformed_types = (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)

    def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None:
        super().__init__()
        if not isinstance(dims, dict):
            dims = _get_defaultdict(dims)
        if torch.Tensor in dims and any(cls in dims for cls in [tv_tensors.Image, tv_tensors.Video]):
            warnings.warn(
                "Got `dims` values for `torch.Tensor` and either `tv_tensors.Image` or `tv_tensors.Video`. "
                "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
                "in case a `tv_tensors.Image` or `tv_tensors.Video` is present in the input."
            )
        self.dims = dims

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
        dims = self.dims[type(inpt)]
        if dims is None:
            return inpt.as_subclass(torch.Tensor)
        return inpt.permute(*dims)


class TransposeDimensions(Transform):
    _transformed_types = (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)

    def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None:
        super().__init__()
        if not isinstance(dims, dict):
            dims = _get_defaultdict(dims)
        if torch.Tensor in dims and any(cls in dims for cls in [tv_tensors.Image, tv_tensors.Video]):
            warnings.warn(
                "Got `dims` values for `torch.Tensor` and either `tv_tensors.Image` or `tv_tensors.Video`. "
                "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
                "in case a `tv_tensors.Image` or `tv_tensors.Video` is present in the input."
            )
        self.dims = dims

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
        dims = self.dims[type(inpt)]
        if dims is None:
            return inpt.as_subclass(torch.Tensor)
        return inpt.transpose(*dims)
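For reference, a minimal usage sketch of the two removed transforms, assuming a pre-removal torchvision build; note that both return a plain torch.Tensor:

import torch

from torchvision import tv_tensors
from torchvision.prototype.transforms import PermuteDimensions, TransposeDimensions

image = tv_tensors.Image(torch.rand(3, 32, 48))

# Reorder CHW to HWC.
chw_to_hwc = PermuteDimensions(dims=(1, 2, 0))
print(chw_to_hwc(image).shape)  # torch.Size([32, 48, 3])

# Swap the last two dimensions.
swap_last_two = TransposeDimensions(dims=(-1, -2))
print(swap_last_two(image).shape)  # torch.Size([3, 48, 32])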
torchvision/prototype/transforms/_presets.py  deleted 100644 → 0
"""
This file is part of the private API. Please do not use directly these classes as they will be modified on
future versions without warning. The classes should be accessed only via the transforms argument of Weights.
"""
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
PIL.Image
import
torch
from
torch
import
Tensor
from
torchvision.transforms.v2
import
functional
as
F
,
InterpolationMode
from
torchvision.transforms.v2.functional._geometry
import
_check_interpolation
__all__
=
[
"StereoMatching"
]
class
StereoMatching
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
*
,
use_gray_scale
:
bool
=
False
,
resize_size
:
Optional
[
Tuple
[
int
,
...]],
mean
:
Tuple
[
float
,
...]
=
(
0.5
,
0.5
,
0.5
),
std
:
Tuple
[
float
,
...]
=
(
0.5
,
0.5
,
0.5
),
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
)
->
None
:
super
().
__init__
()
# pacify mypy
self
.
resize_size
:
Union
[
None
,
List
]
if
resize_size
is
not
None
:
self
.
resize_size
=
list
(
resize_size
)
else
:
self
.
resize_size
=
None
self
.
mean
=
list
(
mean
)
self
.
std
=
list
(
std
)
self
.
interpolation
=
_check_interpolation
(
interpolation
)
self
.
use_gray_scale
=
use_gray_scale
def
forward
(
self
,
left_image
:
Tensor
,
right_image
:
Tensor
)
->
Tuple
[
Tensor
,
Tensor
]:
def
_process_image
(
img
:
PIL
.
Image
.
Image
)
->
Tensor
:
if
not
isinstance
(
img
,
Tensor
):
img
=
F
.
pil_to_tensor
(
img
)
if
self
.
resize_size
is
not
None
:
# We hard-code antialias=False to preserve results after we changed
# its default from None to True (see
# https://github.com/pytorch/vision/pull/7160)
# TODO: we could re-train the stereo models with antialias=True?
img
=
F
.
resize
(
img
,
self
.
resize_size
,
interpolation
=
self
.
interpolation
,
antialias
=
False
)
if
self
.
use_gray_scale
is
True
:
img
=
F
.
rgb_to_grayscale
(
img
)
img
=
F
.
convert_image_dtype
(
img
,
torch
.
float
)
img
=
F
.
normalize
(
img
,
mean
=
self
.
mean
,
std
=
self
.
std
)
img
=
img
.
contiguous
()
return
img
left_image
=
_process_image
(
left_image
)
right_image
=
_process_image
(
right_image
)
return
left_image
,
right_image
def
__repr__
(
self
)
->
str
:
format_string
=
self
.
__class__
.
__name__
+
"("
format_string
+=
f
"
\n
resize_size=
{
self
.
resize_size
}
"
format_string
+=
f
"
\n
mean=
{
self
.
mean
}
"
format_string
+=
f
"
\n
std=
{
self
.
std
}
"
format_string
+=
f
"
\n
interpolation=
{
self
.
interpolation
}
"
format_string
+=
"
\n
)"
return
format_string
def
describe
(
self
)
->
str
:
return
(
"Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
f
"The images are resized to ``resize_size=
{
self
.
resize_size
}
`` using ``interpolation=
{
self
.
interpolation
}
``. "
f
"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean=
{
self
.
mean
}
`` and "
f
"``std=
{
self
.
std
}
``."
)
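StereoMatching is the inference preset attached to the RAFT-Stereo weights above. For reference, a minimal usage sketch, assuming a pre-removal torchvision build; the raw image size is arbitrary:

import torch

from torchvision.prototype.transforms import StereoMatching

transform = StereoMatching(resize_size=(224, 224))
left = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)
right = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)

# Both views are resized to 224x224, rescaled to [0, 1], and normalized with mean/std 0.5.
left_t, right_t = transform(left, right)
print(left_t.shape, left_t.dtype)  # torch.Size([3, 224, 224]) torch.float32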
torchvision/prototype/transforms/_type_conversion.py  deleted 100644 → 0
from typing import Any, Dict

import torch
from torch.nn.functional import one_hot
from torchvision.prototype import tv_tensors as proto_tv_tensors
from torchvision.transforms.v2 import Transform


class LabelToOneHot(Transform):
    _transformed_types = (proto_tv_tensors.Label,)

    def __init__(self, num_categories: int = -1):
        super().__init__()
        self.num_categories = num_categories

    def _transform(self, inpt: proto_tv_tensors.Label, params: Dict[str, Any]) -> proto_tv_tensors.OneHotLabel:
        num_categories = self.num_categories
        if num_categories == -1 and inpt.categories is not None:
            num_categories = len(inpt.categories)
        output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories)
        return proto_tv_tensors.OneHotLabel(output, categories=inpt.categories)

    def extra_repr(self) -> str:
        if self.num_categories == -1:
            return ""
        return f"num_categories={self.num_categories}"
torchvision/prototype/tv_tensors/__init__.py  deleted 100644 → 0
from ._label import Label, OneHotLabel
torchvision/prototype/tv_tensors/_label.py  deleted 100644 → 0
from __future__ import annotations

from typing import Any, Optional, Sequence, Type, TypeVar, Union

import torch
from torch.utils._pytree import tree_map
from torchvision.tv_tensors._tv_tensor import TVTensor

L = TypeVar("L", bound="_LabelBase")


class _LabelBase(TVTensor):
    categories: Optional[Sequence[str]]

    @classmethod
    def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L:
        label_base = tensor.as_subclass(cls)
        label_base.categories = categories
        return label_base

    def __new__(
        cls: Type[L],
        data: Any,
        *,
        categories: Optional[Sequence[str]] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> L:
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        return cls._wrap(tensor, categories=categories)

    @classmethod
    def from_category(
        cls: Type[L],
        category: str,
        *,
        categories: Sequence[str],
        **kwargs: Any,
    ) -> L:
        return cls(categories.index(category), categories=categories, **kwargs)


class Label(_LabelBase):
    def to_categories(self) -> Any:
        if self.categories is None:
            raise RuntimeError("Label does not have categories")

        return tree_map(lambda idx: self.categories[idx], self.tolist())


class OneHotLabel(_LabelBase):
    def __new__(
        cls,
        data: Any,
        *,
        categories: Optional[Sequence[str]] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: bool = False,
    ) -> OneHotLabel:
        one_hot_label = super().__new__(
            cls, data, categories=categories, dtype=dtype, device=device, requires_grad=requires_grad
        )

        if categories is not None and len(categories) != one_hot_label.shape[-1]:
            raise ValueError()

        return one_hot_label
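For reference, a minimal usage sketch of the removed Label tv_tensor, assuming a pre-removal torchvision build; the category names are made up:

from torchvision.prototype.tv_tensors import Label

# Build a label from its category name.
label = Label.from_category("dog", categories=["cat", "dog", "bird"])
print(int(label))             # 1
print(label.to_categories())  # 'dog'

# to_categories also maps over batched labels.
batch = Label([0, 2, 1], categories=["cat", "dog", "bird"])
print(batch.to_categories())  # ['cat', 'bird', 'dog']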
torchvision/prototype/utils/__init__.py  deleted 100644 → 0
from . import _internal
torchvision/prototype/utils/_internal.py  deleted 100644 → 0
import collections.abc
import difflib
import io
import mmap
import platform
from typing import BinaryIO, Callable, Collection, Sequence, TypeVar, Union

import numpy as np
import torch
from torchvision._utils import sequence_to_str

__all__ = [
    "add_suggestion",
    "fromfile",
    "ReadOnlyTensorBuffer",
]


def add_suggestion(
    msg: str,
    *,
    word: str,
    possibilities: Collection[str],
    close_match_hint: Callable[[str], str] = lambda close_match: f"Did you mean '{close_match}'?",
    alternative_hint: Callable[
        [Sequence[str]], str
    ] = lambda possibilities: f"Can be {sequence_to_str(possibilities, separate_last='or ')}.",
) -> str:
    if not isinstance(possibilities, collections.abc.Sequence):
        possibilities = sorted(possibilities)
    suggestions = difflib.get_close_matches(word, possibilities, 1)
    hint = close_match_hint(suggestions[0]) if suggestions else alternative_hint(possibilities)
    if not hint:
        return msg

    return f"{msg.strip()} {hint}"


D = TypeVar("D")


def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray:
    # A plain file.read() will give read-only bytes, so we convert it to a bytearray to make it mutable
    return bytearray(file.read(-1 if count == -1 else count * item_size))


def fromfile(
    file: BinaryIO,
    *,
    dtype: torch.dtype,
    byte_order: str,
    count: int = -1,
) -> torch.Tensor:
    """Construct a tensor from a binary file.

    .. note::

        This function is similar to :func:`numpy.fromfile` with two notable differences:

        1. This function only accepts an open binary file, but not a path to it.
        2. This function has an additional ``byte_order`` parameter, since PyTorch's ``dtype``'s do not support that
           concept.

    .. note::

        If the ``file`` was opened in update mode, i.e. "r+b" or "w+b", reading data is much faster. Be aware that as
        long as the file is still open, inplace operations on the returned tensor will reflect back to the file.

    Args:
        file (IO): Open binary file.
        dtype (torch.dtype): Data type of the underlying data as well as of the returned tensor.
        byte_order (str): Byte order of the data. Can be "little" or "big" endian.
        count (int): Number of values of the returned tensor. If ``-1`` (default), will read the complete file.
    """
    byte_order = "<" if byte_order == "little" else ">"
    char = "f" if dtype.is_floating_point else ("i" if dtype.is_signed else "u")
    item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
    np_dtype = byte_order + char + str(item_size)

    buffer: Union[memoryview, bytearray]
    if platform.system() != "Windows":
        # PyTorch does not support tensors with underlying read-only memory. In case
        # - the file has a .fileno(),
        # - the file was opened for updating, i.e. 'r+b' or 'w+b',
        # - the file is seekable
        # we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it
        # to a mutable location afterwards.
        try:
            buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
            # Reading from the memoryview does not advance the file cursor, so we have to do it manually.
            file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR))
        except (AttributeError, PermissionError, io.UnsupportedOperation):
            buffer = _read_mutable_buffer_fallback(file, count, item_size)
    else:
        # On Windows, just trying to call mmap.mmap() on a file that does not support it may corrupt the internal
        # state so that no data can be read afterwards. Thus, we simply ignore the possible speed-up.
        buffer = _read_mutable_buffer_fallback(file, count, item_size)

    # We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we
    # read the data with np.frombuffer() with the correct byte order and convert it to the native one with the
    # successive .astype() call.
    return torch.from_numpy(np.frombuffer(buffer, dtype=np_dtype, count=count).astype(np_dtype[1:], copy=False))


class ReadOnlyTensorBuffer:
    def __init__(self, tensor: torch.Tensor) -> None:
        self._memory = memoryview(tensor.numpy())
        self._cursor: int = 0

    def tell(self) -> int:
        return self._cursor

    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
        if whence == io.SEEK_SET:
            self._cursor = offset
        elif whence == io.SEEK_CUR:
            self._cursor += offset
        elif whence == io.SEEK_END:
            self._cursor = len(self._memory) + offset
        else:
            raise ValueError(
                f"'whence' should be ``{io.SEEK_SET}``, ``{io.SEEK_CUR}``, or ``{io.SEEK_END}``, "
                f"but got {repr(whence)} instead"
            )
        return self.tell()

    def read(self, size: int = -1) -> bytes:
        cursor = self.tell()
        offset, whence = (0, io.SEEK_END) if size == -1 else (size, io.SEEK_CUR)
        return self._memory[slice(cursor, self.seek(offset, whence))].tobytes()
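For reference, a minimal round-trip sketch of the removed helpers, assuming a pre-removal torchvision build; fromfile accepts any open binary file, so an in-memory BytesIO works via the read() fallback path:

import io

import numpy as np
import torch

from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer

# Four float32 values stored big-endian, decoded into a tensor in native byte order.
data = np.arange(4, dtype=">f4").tobytes()
t = fromfile(io.BytesIO(data), dtype=torch.float32, byte_order="big")
print(t)  # tensor([0., 1., 2., 3.])

# ReadOnlyTensorBuffer exposes a tensor through a file-like read/seek/tell interface.
buf = ReadOnlyTensorBuffer(torch.tensor([1, 2, 3], dtype=torch.uint8))
print(buf.read(2), buf.read())  # b'\x01\x02' b'\x03'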