OpenDAS / vision
Commit cc26cd81, authored Nov 27, 2023 by panning
merge v0.16.0
Parents: f78f29f5, fbb4cc54

Showing 20 changed files with 2327 additions and 145 deletions (+2327, -145). The commit changes 370+ files in total; only these 20 are listed and diffed below.
torchvision/csrc/ops/cpu/roi_align_kernel.cpp              +1     -1
torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu          +14    -18
torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu           +33    -9
torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu            +11    -6
torchvision/csrc/ops/cuda/roi_align_kernel.cu              +34    -15
torchvision/csrc/ops/cuda/roi_pool_kernel.cu               +16    -10
torchvision/csrc/ops/mps/mps_helpers.h                     +6     -0
torchvision/csrc/ops/mps/mps_kernels.h                     +1102  -0
torchvision/csrc/ops/mps/nms_kernel.mm                     +109   -0
torchvision/csrc/ops/mps/ps_roi_align_kernel.mm            +205   -0
torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm             +200   -0
torchvision/csrc/ops/mps/roi_align_kernel.mm               +197   -0
torchvision/csrc/ops/mps/roi_pool_kernel.mm                +196   -0
torchvision/csrc/ops/quantized/cpu/qroi_align_kernel.cpp   +1     -1
torchvision/csrc/ops/roi_align.cpp                         +57    -2
torchvision/csrc/ops/roi_align.h                           +22    -0
torchvision/datasets/__init__.py                           +15    -0
torchvision/datasets/_optical_flow.py                      +43    -25
torchvision/datasets/_stereo_matching.py                   +61    -54
torchvision/datasets/celeba.py                             +4     -4
torchvision/csrc/ops/cpu/roi_align_kernel.cpp

@@ -60,7 +60,7 @@ void roi_align_forward_kernel_impl(
       // When the grid is empty, output zeros.
       const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4

-      // we want to precalculate indices and weights shared by all chanels,
+      // we want to precalculate indices and weights shared by all channels,
       // this is the key point of optimization
       std::vector<detail::PreCalc<T>> pre_calc(
           roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu

@@ -70,7 +70,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/library.h>
-#include <ATen/cuda/Atomic.cuh>
+#include <ATen/native/cuda/KernelUtils.cuh>
 #include "cuda_helpers.h"

@@ -300,11 +300,7 @@ void deformable_im2col(
             data_col.data_ptr<scalar_t>());
       }));
   }

-  cudaError_t err = cudaGetLastError();
-  if (err != cudaSuccess) {
-    printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
-  }
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }

 int get_greatest_divisor_below_bound(int n, int bound) {

@@ -339,6 +335,8 @@ __global__ void deformable_col2im_kernel(
     index_t out_w,
     bool use_mask,
     scalar_t* grad_im) {
+  const index_t grad_im_numel = width * height * channels * batch_sz;
+
   CUDA_1D_KERNEL_LOOP_T(index, n, int64_t) {
     const index_t out_x = index % out_w;
     const index_t out_y = (index / out_w) % out_h;

@@ -385,7 +383,12 @@ __global__ void deformable_col2im_kernel(
             std::abs(y - yp) < 1 && std::abs(x - xp) < 1) {
           index_t grad_pos = ((b * channels + c) * height + yp) * width + xp;
           scalar_t weight =
               (1 - std::abs(y - yp)) * (1 - std::abs(x - xp));
-          gpuAtomicAdd(grad_im + grad_pos, mask_value * weight * col[index]);
+          at::native::fastAtomicAdd(
+              grad_im,
+              grad_pos,
+              grad_im_numel,
+              mask_value * weight * col[index],
+              true);
         }
       }
     }

@@ -430,6 +433,8 @@ void compute_grad_input(
   // Checks if num_kernels or columns numel larger than 2 ** 31
   use_64bits_indexing |= num_kernels > (1 << 31);

+  at::globalContext().alertNotDeterministic("compute_grad_input");
+
   if (use_64bits_indexing) {
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(
         columns.scalar_type(), "compute_grad_input", ([&] {

@@ -483,11 +488,7 @@ void compute_grad_input(
             grad_im.data_ptr<scalar_t>());
       }));
   }

-  cudaError_t err = cudaGetLastError();
-  if (err != cudaSuccess) {
-    printf("error in compute_grad_input: %s\n", cudaGetErrorString(err));
-  }
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }

 template <typename scalar_t, typename index_t>

@@ -736,12 +737,7 @@ void compute_grad_offset_and_mask(
             grad_mask.data_ptr<scalar_t>());
       }));
   }

-  cudaError_t err = cudaGetLastError();
-  if (err != cudaSuccess) {
-    printf("error in compute_grad_offset_and_mask: %s\n", cudaGetErrorString(err));
-  }
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }

 std::tuple<at::Tensor, at::Tensor, at::Tensor> backward_gradient_inputs(
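Note on the pattern above: every CUDA launcher touched by this merge drops the manual cudaGetLastError()/printf check in favour of C10_CUDA_KERNEL_LAUNCH_CHECK(), which surfaces launch failures as c10 errors instead of printing and continuing. A minimal sketch of that pattern, with a hypothetical kernel and launcher (my_kernel / launch_my_kernel are illustrative names, not from this diff):

// Sketch only; not part of the commit.
#include <cuda_runtime.h>
#include <c10/cuda/CUDAException.h>

__global__ void my_kernel(float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    out[i] = 0.f;
}

void launch_my_kernel(float* out, int n, cudaStream_t stream) {
  my_kernel<<<(n + 511) / 512, 512, 0, stream>>>(out, n);
  // Replaces the old cudaGetLastError()/printf pair: if the launch failed,
  // this macro throws through c10 rather than writing to stdout.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}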
torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu

@@ -2,7 +2,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/library.h>
-#include <ATen/cuda/Atomic.cuh>
+#include <ATen/native/cuda/KernelUtils.cuh>
 #include "cuda_helpers.h"

@@ -212,7 +212,8 @@ __global__ void ps_roi_align_backward_kernel_impl(
     int sampling_ratio,
     int channels_out,
     T* grad_input,
-    const T* rois) {
+    const T* rois,
+    const int memory_span) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, *, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;

@@ -235,8 +236,6 @@ __global__ void ps_roi_align_backward_kernel_impl(
     T bin_size_w = roi_width / static_cast<T>(pooled_width);

     int c_in = channel_mapping[index];
-    T* grad_input_offset =
-        grad_input + (roi_batch_ind * channels + c_in) * height * width;

     // Do not using floor/ceil; this implementation detail is critical
     T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;

@@ -252,6 +251,8 @@ __global__ void ps_roi_align_backward_kernel_impl(
         (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
     const T count = roi_bin_grid_h * roi_bin_grid_w;

+    const int offset = (roi_batch_ind * channels + c_in) * height * width;
+
     for (int iy = 0; iy < roi_bin_grid_h; iy++) {
       const T y = hstart + static_cast<T>(iy + .5f) * bin_size_h /

@@ -285,10 +286,30 @@ __global__ void ps_roi_align_backward_kernel_impl(
         T g4 = grad_output_this_bin * w4 / count;

         if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
-          gpuAtomicAdd(grad_input_offset + y_low * width + x_low, g1);
-          gpuAtomicAdd(grad_input_offset + y_low * width + x_high, g2);
-          gpuAtomicAdd(grad_input_offset + y_high * width + x_low, g3);
-          gpuAtomicAdd(grad_input_offset + y_high * width + x_high, g4);
+          at::native::fastAtomicAdd(
+              grad_input,
+              offset + y_low * width + x_low,
+              memory_span,
+              static_cast<T>(g1),
+              true);
+          at::native::fastAtomicAdd(
+              grad_input,
+              offset + y_low * width + x_high,
+              memory_span,
+              static_cast<T>(g2),
+              true);
+          at::native::fastAtomicAdd(
+              grad_input,
+              offset + y_high * width + x_low,
+              memory_span,
+              static_cast<T>(g3),
+              true);
+          at::native::fastAtomicAdd(
+              grad_input,
+              offset + y_high * width + x_high,
+              memory_span,
+              static_cast<T>(g4),
+              true);
         } // if
       } // ix
     } // iy

@@ -412,6 +433,8 @@ at::Tensor ps_roi_align_backward_kernel(
   int channels_out = channels / (pooled_height * pooled_width);

+  at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel");
+
   auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       grad.scalar_type(), "ps_roi_align_backward_kernel", [&] {

@@ -428,7 +451,8 @@ at::Tensor ps_roi_align_backward_kernel(
             sampling_ratio,
             channels_out,
             grad_input.data_ptr<scalar_t>(),
-            rois_.data_ptr<scalar_t>());
+            rois_.data_ptr<scalar_t>(),
+            grad_input.numel());
       });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
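The gpuAtomicAdd to at::native::fastAtomicAdd rewrite above is shared by all of the ROI backward kernels in this merge: rather than a pre-offset pointer, the call takes the gradient tensor's base pointer, a flat element offset, the tensor's total element count (memory_span, supplied as grad_input.numel() at the call site), the value to add, and a fast-path flag. A minimal device-side sketch of that call shape under those assumptions (scatter_add_bin is a hypothetical helper, not from this diff):

// Sketch only; mirrors the call shape used in the kernels above.
#include <ATen/native/cuda/KernelUtils.cuh>

template <typename T>
__device__ void scatter_add_bin(
    T* grad_input,   // base pointer of the gradient tensor
    int offset,      // flat offset of the (batch, channel) plane
    int memory_span, // grad_input.numel(), passed down from the host
    int y,
    int x,
    int width,
    T g) {
  at::native::fastAtomicAdd(
      grad_input, offset + y * width + x, memory_span, g, true);
}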
torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu

@@ -2,7 +2,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/library.h>
-#include <ATen/cuda/Atomic.cuh>
+#include <ATen/native/cuda/KernelUtils.cuh>
 #include "cuda_helpers.h"

@@ -91,7 +91,8 @@ __global__ void ps_roi_pool_backward_kernel_impl(
     int pooled_width,
     int channels_out,
     T* grad_input,
-    const T* rois) {
+    const T* rois,
+    const int memory_span) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, *, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;

@@ -124,14 +125,15 @@ __global__ void ps_roi_pool_backward_kernel_impl(
     bool is_empty = (hend <= hstart) || (wend <= wstart);

     int c_in = channel_mapping[index];
-    T* grad_input_offset =
-        grad_input + (roi_batch_ind * channels + c_in) * height * width;
     T bin_area = (hend - hstart) * (wend - wstart);
     T diff_val = is_empty ? static_cast<T>(0) : grad_output[index] / bin_area;

+    const int offset = (roi_batch_ind * channels + c_in) * height * width;
+
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
         int grad_input_index = h * width + w;
-        gpuAtomicAdd(grad_input_offset + grad_input_index, diff_val);
+        at::native::fastAtomicAdd(
+            grad_input, offset + grad_input_index, memory_span, diff_val, true);
       }
     }
   }

@@ -251,6 +253,8 @@ at::Tensor ps_roi_pool_backward_kernel(
   int channels_out = channels / (pooled_height * pooled_width);

+  at::globalContext().alertNotDeterministic("ps_roi_pool_backward_kernel");
+
   auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       grad.scalar_type(), "ps_roi_pool_backward_kernel", [&] {

@@ -267,7 +271,8 @@ at::Tensor ps_roi_pool_backward_kernel(
             pooled_width,
             channels_out,
             grad_input.data_ptr<scalar_t>(),
-            rois_.data_ptr<scalar_t>());
+            rois_.data_ptr<scalar_t>(),
+            grad_input.numel());
       });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
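Each rewritten CUDA backward launcher also gains an at::globalContext().alertNotDeterministic(...) call before dispatch. As far as I understand this hook, it is a no-op by default and only errors (or warns, in warn-only mode) when the user has opted into deterministic algorithms, flagging that these atomic-add based backwards are nondeterministic. A host-side sketch with a hypothetical op name:

// Sketch only; `my_roi_backward` is an illustrative name, not from this diff.
#include <ATen/Context.h>

void my_roi_backward(/* tensors elided */) {
  // Inert unless torch.use_deterministic_algorithms(True) is active; then it
  // throws (or warns) because the kernel below accumulates with atomics.
  at::globalContext().alertNotDeterministic("my_roi_backward");
  // ... dispatch and launch the backward kernel ...
}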
torchvision/csrc/ops/cuda/roi_align_kernel.cu

@@ -2,7 +2,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/library.h>
-#include <ATen/cuda/Atomic.cuh>
+#include <ATen/native/cuda/KernelUtils.cuh>
 #include "cuda_helpers.h"

@@ -218,7 +218,8 @@ __global__ void roi_align_backward_kernel_impl(
     int n_stride,
     int c_stride,
     int h_stride,
-    int w_stride) {
+    int w_stride,
+    const int memory_span) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;

@@ -247,12 +248,9 @@ __global__ void roi_align_backward_kernel_impl(
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

-    T* offset_grad_input =
-        grad_input + ((roi_batch_ind * channels + c) * height * width);
-
     // We need to index the gradient using the tensor strides to access the
     // correct values.
-    int output_offset = n * n_stride + c * c_stride;
+    const int output_offset = n * n_stride + c * c_stride;
     const T* offset_grad_output = grad_output + output_offset;
     const T grad_output_this_bin =
         offset_grad_output[ph * h_stride + pw * w_stride];

@@ -267,6 +265,8 @@ __global__ void roi_align_backward_kernel_impl(
     // We do average (integral) pooling inside a bin
     const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4

+    const int input_offset = (roi_batch_ind * channels + c) * height * width;
+
     for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
     {
       const T y = roi_start_h + ph * bin_size_h +

@@ -301,14 +301,30 @@ __global__ void roi_align_backward_kernel_impl(
         T g4 = grad_output_this_bin * w4 / count;

         if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
-          gpuAtomicAdd(
-              offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
-          gpuAtomicAdd(
-              offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
-          gpuAtomicAdd(
-              offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
-          gpuAtomicAdd(
-              offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
+          at::native::fastAtomicAdd(
+              grad_input,
+              input_offset + y_low * width + x_low,
+              memory_span,
+              static_cast<T>(g1),
+              true);
+          at::native::fastAtomicAdd(
+              grad_input,
+              input_offset + y_low * width + x_high,
+              memory_span,
+              static_cast<T>(g2),
+              true);
+          at::native::fastAtomicAdd(
+              grad_input,
+              input_offset + y_high * width + x_low,
+              memory_span,
+              static_cast<T>(g3),
+              true);
+          at::native::fastAtomicAdd(
+              grad_input,
+              input_offset + y_high * width + x_high,
+              memory_span,
+              static_cast<T>(g4),
+              true);
         } // if
       } // ix
     } // iy

@@ -421,6 +437,8 @@ at::Tensor roi_align_backward_kernel(
   int h_stride = grad.stride(2);
   int w_stride = grad.stride(3);

+  at::globalContext().alertNotDeterministic("roi_align_backward_kernel");
+
   auto rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       grad.scalar_type(), "roi_align_backward_kernel", [&] {

@@ -440,7 +458,8 @@ at::Tensor roi_align_backward_kernel(
             n_stride,
             c_stride,
             h_stride,
-            w_stride);
+            w_stride,
+            grad_input.numel());
       });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
torchvision/csrc/ops/cuda/roi_pool_kernel.cu

@@ -3,7 +3,7 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <float.h>
 #include <torch/library.h>
-#include <ATen/cuda/Atomic.cuh>
+#include <ATen/native/cuda/KernelUtils.cuh>
 #include "cuda_helpers.h"

@@ -94,7 +94,8 @@ __global__ void roi_pool_backward_kernel_impl(
     int n_stride,
     int c_stride,
     int h_stride,
-    int w_stride) {
+    int w_stride,
+    const int memory_span) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;

@@ -104,19 +105,21 @@ __global__ void roi_pool_backward_kernel_impl(
     const T* offset_rois = rois + n * 5;
     int roi_batch_ind = offset_rois[0];

-    T* grad_input_offset =
-        grad_input + ((roi_batch_ind * channels + c) * height * width);
-
-    int output_offset = n * n_stride + c * c_stride;
+    const int output_offset = n * n_stride + c * c_stride;
     const int* argmax_data_offset =
         argmax_data + (n * channels + c) * pooled_height * pooled_width;
-    int argmax = argmax_data_offset[ph * pooled_width + pw];
+    const int argmax = argmax_data_offset[ph * pooled_width + pw];
+    const int offset = (roi_batch_ind * channels + c) * height * width;

     if (argmax != -1) {
-      gpuAtomicAdd(
-          grad_input_offset + argmax,
+      at::native::fastAtomicAdd(
+          grad_input,
+          offset + argmax,
+          memory_span,
           static_cast<T>(
-              grad_output[output_offset + ph * h_stride + pw * w_stride]));
+              grad_output[output_offset + ph * h_stride + pw * w_stride]),
+          true);
     }
   }
 }

@@ -232,6 +235,8 @@ at::Tensor roi_pool_backward_kernel(
   int h_stride = grad.stride(2);
   int w_stride = grad.stride(3);

+  at::globalContext().alertNotDeterministic("roi_pool_backward_kernel");
+
   auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       grad.scalar_type(), "roi_pool_backward_kernel", [&] {

@@ -251,7 +256,8 @@ at::Tensor roi_pool_backward_kernel(
             n_stride,
             c_stride,
             h_stride,
-            w_stride);
+            w_stride,
+            grad_input.numel());
       });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
torchvision/csrc/ops/mps/mps_helpers.h (new file, mode 100644)

constexpr int threadsPerBlock = 512;

template <typename T>
constexpr inline T ceil_div(T n, T m) {
  return (n + m - 1) / m;
}
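A short host-side sketch of how the helper above is presumably meant to be used: rounding the threadgroup count up so that groups of threadsPerBlock threads cover every output element (the function name is illustrative, not from this diff):

// Sketch only, assuming mps_helpers.h is in scope.
#include <cstdint>

inline int64_t threadgroups_for(int64_t output_size) {
  // e.g. ceil_div(513, 512) == 2, so the tail elements still get a group.
  return ceil_div(output_size, static_cast<int64_t>(threadsPerBlock));
}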
torchvision/csrc/ops/mps/mps_kernels.h (new file, mode 100644)

#include <ATen/native/mps/OperationUtils.h>

namespace vision {
namespace ops {

namespace mps {

static const char* METAL_VISION = R"VISION_METAL(
#include <metal_atomic>
#include <metal_stdlib>
using namespace metal;
/*----------Macros----------*/
#define MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, index_t) \
for (index_t i = (tgid.x * tptg.x) + tid2.x; i < (n); \
i += (tptg.x * n_tgs))
#define MPS_1D_KERNEL_LOOP(i, n, n_tgs) MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, uint)
/*----------Helpers--------*/
template <typename T>
inline T ceil_div(T n, T m) {
return (n + m - 1) / m;
}
template <typename T>
inline void atomic_add_float( device T* data_ptr, const T val)
{
#if __METAL_VERSION__ >= 300
// atomic_float is supported in Metal 3 (macOS Ventura) onward.
device atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed);
#else
// Custom atomic addition implementation
// https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472
// https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639
// https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide)
// Create an atomic uint pointer for atomic transaction.
device atomic_uint* atom_var = (device atomic_uint*)data_ptr;
// Create necessary storage.
uint fetched_uint, assigning_uint;
T fetched_float, assigning_float;
// Replace the value in atom_var with 0 and return the previous value in atom_var.
fetched_uint = atomic_exchange_explicit( atom_var, 0 /*desired*/, memory_order_relaxed);
// Read out the previous value as float.
fetched_float = *( (thread T*) &fetched_uint );
// Do addition and represent the addition result in uint for atomic transaction.
assigning_float = fetched_float + val;
assigning_uint = *((thread uint*) &assigning_float);
// atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr).
while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) {
// If atom_var was not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads.
// Try to assign 0 and get the previously assigned addition result.
uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed);
T fetched_float_again = *( (thread T*) &fetched_uint_again );
// Re-add again
fetched_float = *((thread T*) &(fetched_uint));
// Previously assigned addition result + addition result from other threads.
assigning_float = fetched_float_again + fetched_float;
assigning_uint = *( (thread uint*) &assigning_float);
}
#endif
}
template <typename T, typename integer_t>
inline T bilinear_interpolate(
constant T* input,
integer_t height,
integer_t width,
T y,
T x,
uint index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
return 0;
}
if (y <= 0)
y = 0;
if (x <= 0)
x = 0;
integer_t y_low = (integer_t)y;
integer_t x_low = (integer_t)x;
integer_t y_high;
integer_t x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// do bilinear interpolation
T v1 = input[y_low * width + x_low];
T v2 = input[y_low * width + x_high];
T v3 = input[y_high * width + x_low];
T v4 = input[y_high * width + x_high];
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename T, typename integer_t>
inline void bilinear_interpolate_gradient(
integer_t height,
integer_t width,
T y,
T x,
thread T& w1,
thread T& w2,
thread T& w3,
thread T& w4,
thread integer_t& x_low,
thread integer_t& x_high,
thread integer_t& y_low,
thread integer_t& y_high,
uint index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
w1 = w2 = w3 = w4 = 0.;
x_low = x_high = y_low = y_high = -1;
return;
}
if (y <= 0)
y = 0;
if (x <= 0)
x = 0;
y_low = (integer_t)y;
x_low = (integer_t)x;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// reference in forward
// T v1 = input[y_low * width + x_low];
// T v2 = input[y_low * width + x_high];
// T v3 = input[y_high * width + x_low];
// T v4 = input[y_high * width + x_high];
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
}
template <typename T, typename scalar_t>
inline bool IoU(
constant T & a,
threadgroup T & b,
const float threshold) {
auto xx1 = max(a.x, b.x);
auto yy1 = max(a.y, b.y);
auto xx2 = min(a.z, b.z);
auto yy2 = min(a.w, b.w);
auto w = max(static_cast<scalar_t>(0), xx2 - xx1);
auto h = max(static_cast<scalar_t>(0), yy2 - yy1);
// Upcast to float before multiplications to circumvent precision issues in half.
auto inter = static_cast<float>(w) * static_cast<float>(h);
auto area_b = static_cast<float>(b.z - b.x) * static_cast<float>(b.w - b.y);
auto area_a = static_cast<float>(a.z - a.x) * static_cast<float>(a.w - a.y);
return (inter / (area_a + area_b - inter)) > threshold;
}
/*----------Kernels----------*/
// This should be in sync with the one in nms_kernel.mm.
// Since metal does not support dynamic array,
// we need to make it static instead of deriving it from [[threads_per_threadgroup]].
constant int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8;
template<typename T, typename scalar_t>
kernel void nms(constant T * dev_boxes [[buffer(0)]],
device uint64_t * mask [[buffer(1)]],
constant int64_t & n_boxes [[buffer(2)]],
constant float & iou_threshold [[buffer(3)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tid2 [[thread_position_in_threadgroup]]) {
const uint row_start = tgid.y;
const uint col_start = tgid.x;
const uint tid = tid2.x;
const uint row_size =
min(n_boxes - row_start * nmsThreadsPerBlock, nmsThreadsPerBlock);
const uint col_size =
min(n_boxes - col_start * nmsThreadsPerBlock, nmsThreadsPerBlock);
threadgroup T block_boxes[nmsThreadsPerBlock];
block_boxes[tid] = dev_boxes[nmsThreadsPerBlock * col_start + tid];
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid < row_size) {
const uint cur_box_idx = nmsThreadsPerBlock * row_start + tid;
uint64_t t = 0;
uint start = 0;
if (row_start == col_start) {
start = tid + 1;
}
for (uint i = start; i < col_size; i++){
if (IoU<T, scalar_t>(dev_boxes[cur_box_idx], block_boxes[i], iou_threshold)){
t |= static_cast<uint64_t>(1) << i; // discard 1 keep 0
}
}
const uint col_blocks = ceil_div(n_boxes, nmsThreadsPerBlock);
mask[cur_box_idx * col_blocks + col_start] = t;
}
}
#define REGISTER_NMS_OP(DTYPE) \
template \
[[host_name("nms_" #DTYPE)]] \
kernel void nms<DTYPE ## 4, DTYPE>( \
constant DTYPE ## 4 * dev_boxes [[buffer(0)]], \
device uint64_t * mask [[buffer(1)]], \
constant int64_t & n_boxes [[buffer(2)]], \
constant float & iou_threshold [[buffer(3)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void roi_align(
constant T * input [[buffer(0)]],
constant T * rois [[buffer(1)]],
device T * output [[buffer(2)]],
constant int64_t & output_size [[buffer(3)]],
constant int64_t & channels [[buffer(4)]],
constant int64_t & height [[buffer(5)]],
constant int64_t & width [[buffer(6)]],
constant int64_t & pooled_height [[buffer(7)]],
constant int64_t & pooled_width [[buffer(8)]],
constant int64_t & sampling_ratio [[buffer(9)]],
constant bool & aligned [[buffer(10)]],
constant float & spatial_scale [[buffer(11)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c = (index / pooled_width / pooled_height) % channels;
integer_t n = index / pooled_width / pooled_height / channels;
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
// Do not using rounding; this implementation detail is critical
T offset = aligned ? (T)0.5 : (T)0.0;
T roi_start_w = offset_rois[1] * spatial_scale - offset;
T roi_start_h = offset_rois[2] * spatial_scale - offset;
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!aligned) {
// Force malformed ROIs to be 1x1
roi_width = max(roi_width, (T)1.);
roi_height = max(roi_height, (T)1.);
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
constant T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;
// We use roi_bin_grid to sample the grid and mimic integral
integer_t roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
integer_t roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
// We do average (integral) pooling inside a bin
// When the grid is empty, output zeros.
const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast<integer_t>(1)); // e.g. = 4
T output_val = 0.;
for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
{
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T val = bilinear_interpolate(offset_input, height, width, y, x, index);
output_val += val;
}
}
output_val /= count;
output[index] = output_val;
}
}
#define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("roi_align_" #DTYPE)]] \
kernel void roi_align<DTYPE, INT_DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * output [[buffer(2)]], \
constant int64_t & output_size [[buffer(3)]], \
constant int64_t & channels [[buffer(4)]], \
constant int64_t & height [[buffer(5)]], \
constant int64_t & width [[buffer(6)]], \
constant int64_t & pooled_height [[buffer(7)]], \
constant int64_t & pooled_width [[buffer(8)]], \
constant int64_t & sampling_ratio [[buffer(9)]], \
constant bool & aligned [[buffer(10)]], \
constant float & spatial_scale [[buffer(11)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void roi_align_backward(
constant T * grad_output [[buffer(0)]],
constant T * rois [[buffer(1)]],
device T * grad_input [[buffer(2)]],
constant int64_t & output_size [[buffer(3)]],
constant int64_t & channels [[buffer(4)]],
constant int64_t & height [[buffer(5)]],
constant int64_t & width [[buffer(6)]],
constant int64_t & pooled_height [[buffer(7)]],
constant int64_t & pooled_width [[buffer(8)]],
constant int64_t & sampling_ratio [[buffer(9)]],
constant bool & aligned [[buffer(10)]],
constant float & spatial_scale [[buffer(11)]],
constant int64_t & n_stride [[buffer(12)]],
constant int64_t & c_stride [[buffer(13)]],
constant int64_t & h_stride [[buffer(14)]],
constant int64_t & w_stride [[buffer(15)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c = (index / pooled_width / pooled_height) % channels;
integer_t n = index / pooled_width / pooled_height / channels;
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
// Do not using rounding; this implementation detail is critical
T offset = aligned ? (T)0.5 : (T)0.0;
T roi_start_w = offset_rois[1] * spatial_scale - offset;
T roi_start_h = offset_rois[2] * spatial_scale - offset;
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!aligned) {
// Force malformed ROIs to be 1x1
roi_width = max(roi_width, (T)1.);
roi_height = max(roi_height, (T)1.);
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
// We need to index the gradient using the tensor strides to access the
// correct values.
const integer_t output_offset = n * n_stride + c * c_stride;
constant T* offset_grad_output = grad_output + output_offset;
const T grad_output_this_bin =
offset_grad_output[ph * h_stride + pw * w_stride];
// We use roi_bin_grid to sample the grid and mimic integral
integer_t roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
integer_t roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
const integer_t input_offset = (roi_batch_ind * channels + c) * height * width;
for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
{
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T w1, w2, w3, w4;
integer_t x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient(
height,
width,
y,
x,
w1,
w2,
w3,
w4,
x_low,
x_high,
y_low,
y_high,
index);
T g1 = grad_output_this_bin * w1 / count;
T g2 = grad_output_this_bin * w2 / count;
T g3 = grad_output_this_bin * w3 / count;
T g4 = grad_output_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomic_add_float(grad_input + input_offset + y_low * width + x_low, static_cast<T>(g1));
atomic_add_float(grad_input + input_offset + y_low * width + x_high, static_cast<T>(g2));
atomic_add_float(grad_input + input_offset + y_high * width + x_low, static_cast<T>(g3));
atomic_add_float(grad_input + input_offset + y_high * width + x_high, static_cast<T>(g4));
} // if
} // ix
} // iy
} // MPS_1D_KERNEL_LOOP
}
#define REGISTER_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("roi_align_backward_" #DTYPE)]] \
kernel void roi_align_backward<DTYPE, INT_DTYPE>( \
constant DTYPE * grad_output [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * grad_input [[buffer(2)]], \
constant int64_t & output_size [[buffer(3)]], \
constant int64_t & channels [[buffer(4)]], \
constant int64_t & height [[buffer(5)]], \
constant int64_t & width [[buffer(6)]], \
constant int64_t & pooled_height [[buffer(7)]], \
constant int64_t & pooled_width [[buffer(8)]], \
constant int64_t & sampling_ratio [[buffer(9)]], \
constant bool & aligned [[buffer(10)]], \
constant float & spatial_scale [[buffer(11)]], \
constant int64_t & n_stride [[buffer(12)]], \
constant int64_t & c_stride [[buffer(13)]], \
constant int64_t & h_stride [[buffer(14)]], \
constant int64_t & w_stride [[buffer(15)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void roi_pool(
constant T * input [[buffer(0)]],
constant T * rois [[buffer(1)]],
device T * output [[buffer(2)]],
device int64_t * argmax [[buffer(3)]],
constant int64_t & output_size [[buffer(4)]],
constant int64_t & channels [[buffer(5)]],
constant int64_t & height [[buffer(6)]],
constant int64_t & width [[buffer(7)]],
constant int64_t & pooled_height [[buffer(8)]],
constant int64_t & pooled_width [[buffer(9)]],
constant float & spatial_scale [[buffer(10)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c = (index / pooled_width / pooled_height) % channels;
integer_t n = index / pooled_width / pooled_height / channels;
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
integer_t roi_start_w = round(offset_rois[1] * spatial_scale);
integer_t roi_start_h = round(offset_rois[2] * spatial_scale);
integer_t roi_end_w = round(offset_rois[3] * spatial_scale);
integer_t roi_end_h = round(offset_rois[4] * spatial_scale);
// Force malformed ROIs to be 1x1
integer_t roi_width = max(roi_end_w - roi_start_w + 1, static_cast<integer_t>(1));
integer_t roi_height = max(roi_end_h - roi_start_h + 1, static_cast<integer_t>(1));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
integer_t hstart = static_cast<integer_t>(floor(static_cast<T>(ph) * bin_size_h));
integer_t wstart = static_cast<integer_t>(floor(static_cast<T>(pw) * bin_size_w));
integer_t hend = static_cast<integer_t>(ceil(static_cast<T>(ph + 1) * bin_size_h));
integer_t wend = static_cast<integer_t>(ceil(static_cast<T>(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height));
hend = min(max(hend + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height));
wstart = min(max(wstart + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width));
wend = min(max(wend + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width));
bool is_empty = (hend <= hstart) || (wend <= wstart);
// Define an empty pooling region to be zero
T maxval = is_empty ? 0 : -FLT_MAX;
// If nothing is pooled, argmax = -1 causes nothing to be backprop'd
integer_t maxidx = -1;
constant T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;
for (integer_t h = hstart; h < hend; ++h) {
for (integer_t w = wstart; w < wend; ++w) {
integer_t input_index = h * width + w;
if (offset_input[input_index] > maxval) {
maxval = offset_input[input_index];
maxidx = input_index;
}
}
}
output[index] = maxval;
argmax[index] = maxidx;
}
}
#define REGISTER_ROI_POOL_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("roi_pool_" #DTYPE)]] \
kernel void roi_pool<DTYPE, INT_DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * output [[buffer(2)]], \
device int64_t * argmax_data [[buffer(3)]], \
constant int64_t & output_size [[buffer(4)]], \
constant int64_t & channels [[buffer(5)]], \
constant int64_t & height [[buffer(6)]], \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant float & spatial_scale [[buffer(10)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void roi_pool_backward(
constant T * grad_output [[buffer(0)]],
constant T * rois [[buffer(1)]],
constant int64_t * argmax_data [[buffer(2)]],
device T * grad_input [[buffer(3)]],
constant int64_t & output_size [[buffer(4)]],
constant int64_t & channels [[buffer(5)]],
constant int64_t & height [[buffer(6)]],
constant int64_t & width [[buffer(7)]],
constant int64_t & pooled_height [[buffer(8)]],
constant int64_t & pooled_width [[buffer(9)]],
constant float & spatial_scale [[buffer(10)]],
constant int64_t & n_stride [[buffer(11)]],
constant int64_t & c_stride [[buffer(12)]],
constant int64_t & h_stride [[buffer(13)]],
constant int64_t & w_stride [[buffer(14)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c = (index / pooled_width / pooled_height) % channels;
integer_t n = index / pooled_width / pooled_height / channels;
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
const integer_t output_offset = n * n_stride + c * c_stride;
constant integer_t * argmax_data_offset =
argmax_data + (n * channels + c) * pooled_height * pooled_width;
const integer_t argmax = argmax_data_offset[ph * pooled_width + pw];
const integer_t offset = (roi_batch_ind * channels + c) * height * width;
if (argmax != -1) {
atomic_add_float(grad_input + offset + argmax, static_cast<T>(grad_output[output_offset + ph * h_stride + pw * w_stride]));
}
} // MPS_1D_KERNEL_LOOP
}
#define REGISTER_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("roi_pool_backward_" #DTYPE)]] \
kernel void roi_pool_backward<DTYPE, INT_DTYPE>( \
constant DTYPE * grad_output [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
constant int64_t * argmax_data [[buffer(2)]], \
device DTYPE * grad_input [[buffer(3)]], \
constant int64_t & output_size [[buffer(4)]], \
constant int64_t & channels [[buffer(5)]], \
constant int64_t & height [[buffer(6)]], \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant float & spatial_scale [[buffer(10)]], \
constant int64_t & n_stride [[buffer(11)]], \
constant int64_t & c_stride [[buffer(12)]], \
constant int64_t & h_stride [[buffer(13)]], \
constant int64_t & w_stride [[buffer(14)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void ps_roi_align(
constant T * input [[buffer(0)]],
constant T * rois [[buffer(1)]],
device T * output [[buffer(2)]],
device int64_t * channel_mapping [[buffer(3)]],
constant int64_t & output_size [[buffer(4)]],
constant int64_t & channels [[buffer(5)]],
constant int64_t & height [[buffer(6)]],
constant int64_t & width [[buffer(7)]],
constant int64_t & pooled_height [[buffer(8)]],
constant int64_t & pooled_width [[buffer(9)]],
constant int64_t & sampling_ratio [[buffer(10)]],
constant int64_t & channels_out [[buffer(11)]],
constant float & spatial_scale [[buffer(12)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c_out, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c_out = (index / pooled_width / pooled_height) % channels_out;
integer_t n = index / pooled_width / pooled_height / channels_out;
// (n, c_in, ph, pw) is the associated element in the input
integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw;
// [start, end) interval for spatial sampling
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
// Do not using rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
// Do not using floor/ceil; this implementation detail is critical
T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
// We use roi_bin_grid to sample the grid and mimic integral
integer_t roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height);
integer_t roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w;
constant T* offset_input =
input + (roi_batch_ind * channels + c_in) * height * width;
T out_sum = 0;
for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = hstart +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = wstart +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T val = bilinear_interpolate(offset_input, height, width, y, x, index);
out_sum += val;
}
}
out_sum /= count;
output[index] = out_sum;
channel_mapping[index] = c_in;
}
}
#define REGISTER_PS_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("ps_roi_align_" #DTYPE)]] \
kernel void ps_roi_align<DTYPE, INT_DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * output [[buffer(2)]], \
device int64_t * channel_mapping [[buffer(3)]], \
constant int64_t & output_size [[buffer(4)]], \
constant int64_t & channels [[buffer(5)]], \
constant int64_t & height [[buffer(6)]], \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant int64_t & sampling_ratio [[buffer(10)]], \
constant int64_t & channels_out [[buffer(11)]], \
constant float & spatial_scale [[buffer(12)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void ps_roi_align_backward(
constant T * grad_output [[buffer(0)]],
constant T * rois [[buffer(1)]],
constant int64_t * channel_mapping [[buffer(2)]],
device T * grad_input [[buffer(3)]],
constant int64_t & output_size [[buffer(4)]],
constant int64_t & channels [[buffer(5)]],
constant int64_t & height [[buffer(6)]],
constant int64_t & width [[buffer(7)]],
constant int64_t & pooled_height [[buffer(8)]],
constant int64_t & pooled_width [[buffer(9)]],
constant int64_t & sampling_ratio [[buffer(10)]],
constant int64_t & channels_out [[buffer(11)]],
constant float & spatial_scale [[buffer(12)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, *, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t n = index / pooled_width / pooled_height / channels_out;
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
// Do not using rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);
// Force too small ROIs to be 1x1
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
integer_t c_in = channel_mapping[index];
// Do not using floor/ceil; this implementation detail is critical
T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
const T grad_output_this_bin = grad_output[index];
// We use roi_bin_grid to sample the grid and mimic integral
integer_t roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
integer_t roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w;
const integer_t offset = (roi_batch_ind * channels + c_in) * height * width;
for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = hstart +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = wstart +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T w1, w2, w3, w4;
integer_t x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient(
height,
width,
y,
x,
w1,
w2,
w3,
w4,
x_low,
x_high,
y_low,
y_high,
index);
T g1 = grad_output_this_bin * w1 / count;
T g2 = grad_output_this_bin * w2 / count;
T g3 = grad_output_this_bin * w3 / count;
T g4 = grad_output_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomic_add_float(grad_input + offset + y_low * width + x_low, static_cast<T>(g1));
atomic_add_float(grad_input + offset + y_low * width + x_high, static_cast<T>(g2));
atomic_add_float(grad_input + offset + y_high * width + x_low, static_cast<T>(g3));
atomic_add_float(grad_input + offset + y_high * width + x_high, static_cast<T>(g4));
} // if
} // ix
} // iy
}
}
#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("ps_roi_align_backward_" #DTYPE)]] \
kernel void ps_roi_align_backward<DTYPE, INT_DTYPE>( \
constant DTYPE * grad_output [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
constant int64_t * channel_mapping [[buffer(2)]], \
device DTYPE * grad_input [[buffer(3)]], \
constant int64_t & output_size [[buffer(4)]], \
constant int64_t & channels [[buffer(5)]], \
constant int64_t & height [[buffer(6)]], \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant int64_t & sampling_ratio [[buffer(10)]], \
constant int64_t & channels_out [[buffer(11)]], \
constant float & spatial_scale [[buffer(12)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void ps_roi_pool(
constant T * input [[buffer(0)]],
constant T * rois [[buffer(1)]],
device T * output [[buffer(2)]],
device int64_t * channel_mapping [[buffer(3)]],
constant int64_t & output_size [[buffer(4)]],
constant int64_t & channels [[buffer(5)]],
constant int64_t & height [[buffer(6)]],
constant int64_t & width [[buffer(7)]],
constant int64_t & pooled_height [[buffer(8)]],
constant int64_t & pooled_width [[buffer(9)]],
constant int64_t & channels_out [[buffer(10)]],
constant float & spatial_scale [[buffer(11)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c_out, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c_out = (index / (pooled_width * pooled_height)) % channels_out;
integer_t n = index / pooled_width / pooled_height / channels_out;
// (n, c_in, ph, pw) is the associated element in the input
integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw;
// [start, end) interval for spatial sampling
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
integer_t roi_start_w = round(offset_rois[1] * spatial_scale);
integer_t roi_start_h = round(offset_rois[2] * spatial_scale);
integer_t roi_end_w = round(offset_rois[3] * spatial_scale);
integer_t roi_end_h = round(offset_rois[4] * spatial_scale);
// Force too small ROIs to be 1x1
integer_t roi_width = max(roi_end_w - roi_start_w, static_cast<integer_t>(1));
integer_t roi_height = max(roi_end_h - roi_start_h, static_cast<integer_t>(1));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
integer_t hstart = static_cast<integer_t>(floor(static_cast<T>(ph) * bin_size_h));
integer_t wstart = static_cast<integer_t>(floor(static_cast<T>(pw) * bin_size_w));
integer_t hend = static_cast<integer_t>(ceil(static_cast<T>(ph + 1) * bin_size_h));
integer_t wend = static_cast<integer_t>(ceil(static_cast<T>(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height - 1));
hend = min(max(hend + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height - 1));
wstart = min(max(wstart + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width - 1));
wend = min(max(wend + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width - 1));
bool is_empty = (hend <= hstart) || (wend <= wstart);
constant T* offset_input =
input + (roi_batch_ind * channels + c_in) * height * width;
T out_sum = 0;
for (integer_t h = hstart; h < hend; ++h) {
for (integer_t w = wstart; w < wend; ++w) {
integer_t input_index = h * width + w;
out_sum += offset_input[input_index];
}
}
T bin_area = (hend - hstart) * (wend - wstart);
output[index] = is_empty ? static_cast<T>(0) : out_sum / bin_area;
channel_mapping[index] = c_in;
}
}
#define REGISTER_PS_ROI_POOL_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("ps_roi_pool_" #DTYPE)]] \
kernel void ps_roi_pool<DTYPE, INT_DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * output [[buffer(2)]], \
device int64_t * channel_mapping [[buffer(3)]], \
constant int64_t & output_size [[buffer(4)]], \
constant int64_t & channels [[buffer(5)]], \
constant int64_t & height [[buffer(6)]], \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant int64_t & channels_out [[buffer(10)]], \
constant float & spatial_scale [[buffer(11)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void ps_roi_pool_backward(
constant T * grad_output [[buffer(0)]],
constant T * rois [[buffer(1)]],
constant int64_t * channel_mapping [[buffer(2)]],
device T * grad_input [[buffer(3)]],
constant int64_t & output_size [[buffer(4)]],
constant int64_t & channels [[buffer(5)]],
constant int64_t & height [[buffer(6)]],
constant int64_t & width [[buffer(7)]],
constant int64_t & pooled_height [[buffer(8)]],
constant int64_t & pooled_width [[buffer(9)]],
constant int64_t & channels_out [[buffer(10)]],
constant float & spatial_scale [[buffer(11)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, *, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t n = index / pooled_width / pooled_height / channels_out;
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
integer_t roi_start_w = round(offset_rois[1] * spatial_scale);
integer_t roi_start_h = round(offset_rois[2] * spatial_scale);
integer_t roi_end_w = round(offset_rois[3] * spatial_scale);
integer_t roi_end_h = round(offset_rois[4] * spatial_scale);
// Force too small ROIs to be 1x1
integer_t roi_width = max(roi_end_w - roi_start_w, static_cast<integer_t>(1));
integer_t roi_height = max(roi_end_h - roi_start_h, static_cast<integer_t>(1));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
integer_t hstart = static_cast<integer_t>(floor(static_cast<T>(ph) * bin_size_h));
integer_t wstart = static_cast<integer_t>(floor(static_cast<T>(pw) * bin_size_w));
integer_t hend = static_cast<integer_t>(ceil(static_cast<T>(ph + 1) * bin_size_h));
integer_t wend = static_cast<integer_t>(ceil(static_cast<T>(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height));
hend = min(max(hend + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height));
wstart = min(max(wstart + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width));
wend = min(max(wend + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width));
bool is_empty = (hend <= hstart) || (wend <= wstart);
integer_t c_in = channel_mapping[index];
T bin_area = (hend - hstart) * (wend - wstart);
T diff_val = is_empty ? static_cast<T>(0) : grad_output[index] / bin_area;
const integer_t offset = (roi_batch_ind * channels + c_in) * height * width;
for (integer_t h = hstart; h < hend; ++h) {
for (integer_t w = wstart; w < wend; ++w) {
integer_t grad_input_index = h * width + w;
atomic_add_float(grad_input + offset + grad_input_index, diff_val);
}
}
} // MPS_1D_KERNEL_LOOP
}
#define REGISTER_PS_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("ps_roi_pool_backward_" #DTYPE)]] \
kernel void ps_roi_pool_backward<DTYPE, INT_DTYPE>( \
constant DTYPE * grad_output [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
constant int64_t * channel_mapping [[buffer(2)]], \
device DTYPE * grad_input [[buffer(3)]], \
constant int64_t & output_size [[buffer(4)]], \
constant int64_t & channels [[buffer(5)]], \
constant int64_t & height [[buffer(6)]], \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant int64_t & channels_out [[buffer(10)]], \
constant float & spatial_scale [[buffer(11)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
REGISTER_NMS_OP(float);
REGISTER_NMS_OP(half);
REGISTER_ROI_ALIGN_OP(float, int64_t);
REGISTER_ROI_ALIGN_OP(half, int64_t);
REGISTER_ROI_ALIGN_BACKWARD_OP(float, int64_t);
REGISTER_ROI_ALIGN_BACKWARD_OP(half, int64_t);
REGISTER_ROI_POOL_OP(float, int64_t);
REGISTER_ROI_POOL_OP(half, int64_t);
REGISTER_ROI_POOL_BACKWARD_OP(float, int64_t);
REGISTER_ROI_POOL_BACKWARD_OP(half, int64_t);
REGISTER_PS_ROI_ALIGN_OP(float, int64_t);
REGISTER_PS_ROI_ALIGN_OP(half, int64_t);
REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float, int64_t);
REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half, int64_t);
REGISTER_PS_ROI_POOL_OP(float, int64_t);
REGISTER_PS_ROI_POOL_OP(half, int64_t);
REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t);
REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t);
)VISION_METAL";

static id<MTLLibrary> compileVisionOpsLibrary(id<MTLDevice> device) {
  static id<MTLLibrary> visionLibrary = nil;
  if (visionLibrary) {
    return visionLibrary;
  }

  NSError* error = nil;
  MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
  [options setLanguageVersion:MTLLanguageVersion2_3];
  visionLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION
                                                                  encoding:NSASCIIStringEncoding]
                                       options:options
                                         error:&error];
  TORCH_CHECK(
      visionLibrary, "Failed to create metal vision library, error: ", [[error description] UTF8String]);
  return visionLibrary;
}

static id<MTLComputePipelineState> visionPipelineState(
    id<MTLDevice> device,
    const std::string& kernel) {
  static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
  id<MTLComputePipelineState> pso = psoCache[kernel];
  if (pso) {
    return pso;
  }

  NSError* error = nil;
  id<MTLLibrary> visionLib = compileVisionOpsLibrary(device);
  id<MTLFunction> visionFunc =
      [visionLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
  TORCH_CHECK(visionFunc, "Failed to create function state object for: ", kernel);

  pso = [device newComputePipelineStateWithFunction:visionFunc error:&error];
  TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);

  psoCache[kernel] = pso;
  return pso;
}

} // namespace mps
} // namespace ops
} // namespace vision
torchvision/csrc/ops/mps/nms_kernel.mm
0 → 100644
View file @
cc26cd81
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_kernels.h"

namespace vision {
namespace ops {

namespace {

// This should be in sync with `nmsThreadsPerBlock` in the metal kernel.
constexpr int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8;

at::Tensor nms_kernel(const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) {
  using namespace at::native::mps;
  TORCH_CHECK(dets.is_mps(), "dets must be a MPS tensor");
  TORCH_CHECK(scores.is_mps(), "scores must be a MPS tensor");

  TORCH_CHECK(dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
  TORCH_CHECK(dets.size(1) == 4, "boxes should have 4 elements in dimension 1, got ", dets.size(1));
  TORCH_CHECK(scores.dim() == 1, "scores should be a 1d tensor, got ", scores.dim(), "D");
  TORCH_CHECK(
      dets.size(0) == scores.size(0),
      "boxes and scores should have same number of elements in ",
      "dimension 0, got ",
      dets.size(0),
      " and ",
      scores.size(0));

  if (dets.numel() == 0) {
    return at::empty({0}, dets.options().dtype(at::kLong));
  }

  auto order_t = std::get<1>(scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
  auto dets_sorted = dets.index_select(0, order_t).contiguous();
  int64_t dets_num = dets.size(0);
  float iou_threshold_f = static_cast<float>(iou_threshold);

  const int col_blocks = (dets_num + nmsThreadsPerBlock - 1) / nmsThreadsPerBlock;
  at::Tensor mask = at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));

  id<MTLBuffer> inputBuffer = getMTLBufferStorage(dets_sorted);
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(mask);
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
  dispatch_sync(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      MTLSize threadgroupsPerGrid = MTLSizeMake(col_blocks, col_blocks, 1);

      const std::string kernel = "nms_" + scalarToMetalTypeString(dets_sorted.scalar_type());
      id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);

      // this function call is a no-op if MPS Profiler is not enabled
      getMPSProfiler().beginProfileKernel(visionPSO, kernel, {dets, scores});

      [computeEncoder setComputePipelineState:visionPSO];
      [computeEncoder setBuffer:inputBuffer offset:dets_sorted.storage_offset() * dets_sorted.element_size() atIndex:0];
      [computeEncoder setBuffer:outputBuffer offset:mask.storage_offset() * mask.element_size() atIndex:1];
      [computeEncoder setBytes:&dets_num length:sizeof(int64_t) atIndex:2];
      [computeEncoder setBytes:&iou_threshold_f length:sizeof(float) atIndex:3];

      // A threadGroup is equivalent to a cuda's block.
      NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
      if (tgSize > nmsThreadsPerBlock) {
        tgSize = nmsThreadsPerBlock;
      }

      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];

      getMPSProfiler().endProfileKernel(visionPSO);
    }
  });

  int64_t num_to_keep = 0;

  at::Tensor mask_cpu = mask.to(at::kCPU);
  unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr<int64_t>();

  std::vector<unsigned long long> remv(col_blocks);
  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);

  at::Tensor keep = at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
  int64_t* keep_out = keep.data_ptr<int64_t>();

  for (int64_t i = 0; i < dets_num; i++) {
    int64_t nblock = i / nmsThreadsPerBlock;
    int64_t inblock = i % nmsThreadsPerBlock;

    if (!(remv[nblock] & (1ULL << inblock))) {
      keep_out[num_to_keep++] = i;
      unsigned long long* p = mask_host + i * col_blocks;
      for (int64_t j = nblock; j < col_blocks; j++) {
        remv[j] |= p[j];
      }
    }
  }

  return order_t.index(
      {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(order_t.device(), keep.scalar_type())});
}

} // namespace

TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
  m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}

} // namespace ops
} // namespace vision
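For reference, this MPS kernel is reached through the same Python entry point as the CPU and CUDA backends; a minimal usage sketch, assuming a machine where torch.backends.mps.is_available() returns True and using made-up boxes:

import torch
from torchvision.ops import nms

# hypothetical data: boxes are [N, 4] in (x1, y1, x2, y2) format, scores are [N]
boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.], [50., 50., 60., 60.]], device="mps")
scores = torch.tensor([0.9, 0.8, 0.7], device="mps")
keep = nms(boxes, scores, iou_threshold=0.5)  # indices of the kept boxes, ordered by descending score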
torchvision/csrc/ops/mps/ps_roi_align_kernel.mm
0 → 100644
View file @
cc26cd81
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_helpers.h"
#include "mps_kernels.h"

namespace vision {
namespace ops {

namespace {

std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_kernel(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio) {
  using namespace at::native::mps;
  TORCH_CHECK(input.is_mps(), "input must be a MPS tensor");
  TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
  TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");

  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "ps_roi_align_forward_kernel";
  at::checkAllSameGPU(c, {input_t, rois_t});
  at::checkAllSameType(c, {input_t, rois_t});

  int64_t num_rois = rois.size(0);
  int64_t channels = input.size(1);
  int64_t height = input.size(2);
  int64_t width = input.size(3);
  float spatial_scale_f = static_cast<float>(spatial_scale);

  TORCH_CHECK(
      channels % (pooled_height * pooled_width) == 0,
      "input channels must be a multiple of pooling height * pooling width");
  int64_t channels_out = channels / (pooled_height * pooled_width);

  auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options());
  auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong));

  int64_t output_size = output.numel();
  if (output_size == 0) {
    return std::make_tuple(output, channel_mapping);
  }

  auto input_ = input.contiguous();
  auto rois_ = rois.contiguous();

  id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
  id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
  id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
  dispatch_sync(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      MTLSize threadgroupsPerGrid = MTLSizeMake(
          std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
          1,
          1);

      const std::string kernel = "ps_roi_align_" + scalarToMetalTypeString(input.scalar_type());
      id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);

      // this function call is a no-op if MPS Profiler is not enabled
      getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});

      [computeEncoder setComputePipelineState:visionPSO];
      // [N, C, H, W]
      [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
      [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
      [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
      [computeEncoder setBuffer:channelMappingBuffer
                         offset:channel_mapping.storage_offset() * channel_mapping.element_size()
                        atIndex:3];

      [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
      [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
      [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
      [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
      [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
      [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
      [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10];
      [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11];
      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12];

      // A threadGroup is equivalent to a cuda's block.
      NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
      if (tgSize > threadsPerBlock) {
        tgSize = threadsPerBlock;
      }

      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];

      getMPSProfiler().endProfileKernel(visionPSO);
    }
  });
  return std::make_tuple(output, channel_mapping);
}

at::Tensor ps_roi_align_backward_kernel(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& channel_mapping,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width) {
  using namespace at::native::mps;
  TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor");
  TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
  TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_align backward with float16 inputs.");
  TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor");

  at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2},
      channel_mapping_t{channel_mapping, "channel_mapping", 3};

  at::CheckedFrom c = "ps_roi_align_backward_kernel";
  at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
  at::checkAllSameType(c, {grad_t, rois_t});

  float spatial_scale_f = static_cast<float>(spatial_scale);

  auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());

  if (grad.numel() == 0) {
    return grad_input;
  }

  int64_t output_size = grad.numel();
  int64_t channels_out = channels / (pooled_height * pooled_width);

  at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel");
  auto grad_ = grad.contiguous(), rois_ = rois.contiguous();

  id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad_);
  id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
  id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
  dispatch_sync(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      MTLSize threadgroupsPerGrid = MTLSizeMake(
          std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
          1,
          1);

      const std::string kernel = "ps_roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type());
      id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);

      // this function call is a no-op if MPS Profiler is not enabled
      getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_});

      [computeEncoder setComputePipelineState:visionPSO];
      // [N, C, H, W]
      [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0];
      [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
      [computeEncoder setBuffer:channelMappingBuffer
                         offset:channel_mapping.storage_offset() * channel_mapping.element_size()
                        atIndex:2];
      [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3];

      [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
      [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
      [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
      [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
      [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
      [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
      [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10];
      [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11];
      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12];

      // A threadGroup is equivalent to a cuda's block.
      NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
      if (tgSize > threadsPerBlock) {
        tgSize = threadsPerBlock;
      }

      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];

      getMPSProfiler().endProfileKernel(visionPSO);
    }
  });
  return grad_input;
}

} // namespace

TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
  m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), TORCH_FN(ps_roi_align_forward_kernel));
  m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), TORCH_FN(ps_roi_align_backward_kernel));
}

} // namespace ops
} // namespace vision
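As the TORCH_CHECK above enforces, ps_roi_align requires the input channel count to be a multiple of pooled_height * pooled_width. A minimal sketch of a call that satisfies that constraint, with made-up shapes and assuming an MPS device is available:

import torch
from torchvision.ops import ps_roi_align

x = torch.rand(1, 9 * 3 * 3, 32, 32, device="mps")  # channels = 9 * (3 * 3), so channels_out = 9
rois = torch.tensor([[0., 0., 0., 16., 16.]], device="mps")  # (batch_idx, x1, y1, x2, y2)
out = ps_roi_align(x, rois, output_size=(3, 3), spatial_scale=1.0, sampling_ratio=2)
# out has shape [K, channels_out, 3, 3] == [1, 9, 3, 3]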
torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm
0 → 100644
View file @
cc26cd81
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_helpers.h"
#include "mps_kernels.h"

namespace vision {
namespace ops {

namespace {

std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_kernel(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width) {
  using namespace at::native::mps;
  TORCH_CHECK(input.is_mps(), "input must be a MPS tensor");
  TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
  TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");

  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "ps_roi_pool_forward_kernel";
  at::checkAllSameGPU(c, {input_t, rois_t});
  at::checkAllSameType(c, {input_t, rois_t});

  int64_t num_rois = rois.size(0);
  int64_t channels = input.size(1);
  int64_t height = input.size(2);
  int64_t width = input.size(3);
  float spatial_scale_f = static_cast<float>(spatial_scale);

  TORCH_CHECK(
      channels % (pooled_height * pooled_width) == 0,
      "input channels must be a multiple of pooling height * pooling width");
  int64_t channels_out = channels / (pooled_height * pooled_width);

  auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options());
  auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong));
  auto output_size = output.numel();

  if (output_size == 0) {
    return std::make_tuple(output, channel_mapping);
  }

  auto input_ = input.contiguous();
  auto rois_ = rois.contiguous();

  id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
  id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
  id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
  dispatch_sync(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      MTLSize threadgroupsPerGrid = MTLSizeMake(
          std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
          1,
          1);

      const std::string kernel = "ps_roi_pool_" + scalarToMetalTypeString(input.scalar_type());
      id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);

      // this function call is a no-op if MPS Profiler is not enabled
      getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});

      [computeEncoder setComputePipelineState:visionPSO];
      // [N, C, H, W]
      [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
      [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
      [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
      [computeEncoder setBuffer:channelMappingBuffer
                         offset:channel_mapping.storage_offset() * channel_mapping.element_size()
                        atIndex:3];

      [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
      [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
      [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
      [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
      [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
      [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
      [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10];
      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];

      // A threadGroup is equivalent to a cuda's block.
      NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
      if (tgSize > threadsPerBlock) {
        tgSize = threadsPerBlock;
      }

      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];

      getMPSProfiler().endProfileKernel(visionPSO);
    }
  });
  return std::make_tuple(output, channel_mapping);
}

at::Tensor ps_roi_pool_backward_kernel(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& channel_mapping,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width) {
  using namespace at::native::mps;
  TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor");
  TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
  TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_pool backward with float16 inputs.");
  TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor");

  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
      channel_mapping_t{channel_mapping, "channel_mapping", 3};

  at::CheckedFrom c = "ps_roi_pool_backward_kernel";
  at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
  at::checkAllSameType(c, {grad_t, rois_t});

  float spatial_scale_f = static_cast<float>(spatial_scale);

  auto num_rois = rois.size(0);
  auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());

  if (grad.numel() == 0) {
    return grad_input;
  }

  int64_t channels_out = channels / (pooled_height * pooled_width);
  int64_t output_size = grad.numel();

  at::globalContext().alertNotDeterministic("ps_roi_pool_backward_kernel");
  auto grad_ = grad.contiguous(), rois_ = rois.contiguous();

  id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad_);
  id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
  id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
  dispatch_sync(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      MTLSize threadgroupsPerGrid = MTLSizeMake(
          std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
          1,
          1);

      const std::string kernel = "ps_roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type());
      id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);

      // this function call is a no-op if MPS Profiler is not enabled
      getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad_, rois_, channel_mapping});

      [computeEncoder setComputePipelineState:visionPSO];
      // [N, C, H, W]
      [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0];
      [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
      [computeEncoder setBuffer:channelMappingBuffer
                         offset:channel_mapping.storage_offset() * channel_mapping.element_size()
                        atIndex:2];
      [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3];

      [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
      [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
      [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
      [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
      [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
      [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
      [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10];
      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];

      // A threadGroup is equivalent to a cuda's block.
      NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
      if (tgSize > threadsPerBlock) {
        tgSize = threadsPerBlock;
      }

      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];

      getMPSProfiler().endProfileKernel(visionPSO);
    }
  });
  return grad_input;
}

} // namespace

TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
  m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), TORCH_FN(ps_roi_pool_forward_kernel));
  m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), TORCH_FN(ps_roi_pool_backward_kernel));
}

} // namespace ops
} // namespace vision
torchvision/csrc/ops/mps/roi_align_kernel.mm
0 → 100644
View file @
cc26cd81
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_helpers.h"
#include "mps_kernels.h"

namespace vision {
namespace ops {

namespace {

at::Tensor roi_align_forward_kernel(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio,
    bool aligned) {
  using namespace at::native::mps;
  TORCH_CHECK(input.is_mps(), "input must be a MPS tensor");
  TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
  TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");

  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "roi_align_forward_kernel";
  at::checkAllSameGPU(c, {input_t, rois_t});
  at::checkAllSameType(c, {input_t, rois_t});

  int64_t num_rois = rois.size(0);
  int64_t channels = input.size(1);
  int64_t height = input.size(2);
  int64_t width = input.size(3);
  float spatial_scale_f = static_cast<float>(spatial_scale);

  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());

  int64_t output_size = num_rois * pooled_height * pooled_width * channels;

  if (output.numel() == 0) {
    return output;
  }

  auto input_ = input.contiguous();
  auto rois_ = rois.contiguous();

  id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
  id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
  dispatch_sync(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      MTLSize threadgroupsPerGrid = MTLSizeMake(
          std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
          1,
          1);

      const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type());
      id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);

      // this function call is a no-op if MPS Profiler is not enabled
      getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});

      [computeEncoder setComputePipelineState:visionPSO];
      // [N, C, H, W]
      [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
      [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
      [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];

      [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3];
      [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4];
      [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5];
      [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6];
      [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7];
      [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8];
      [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9];
      [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10];
      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];

      // A threadGroup is equivalent to a cuda's block.
      NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
      if (tgSize > threadsPerBlock) {
        tgSize = threadsPerBlock;
      }

      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];

      getMPSProfiler().endProfileKernel(visionPSO);
    }
  });
  return output;
}

at::Tensor roi_align_backward_kernel(
    const at::Tensor& grad,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width,
    int64_t sampling_ratio,
    bool aligned) {
  using namespace at::native::mps;
  TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor");
  TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
  TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_align backward with float16 inputs.");

  at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "roi_align_backward_kernel";
  at::checkAllSameGPU(c, {grad_t, rois_t});
  at::checkAllSameType(c, {grad_t, rois_t});

  float spatial_scale_f = static_cast<float>(spatial_scale);

  at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());

  if (grad.numel() == 0) {
    return grad_input;
  }

  int64_t n_stride = grad.stride(0);
  int64_t c_stride = grad.stride(1);
  int64_t h_stride = grad.stride(2);
  int64_t w_stride = grad.stride(3);
  int64_t output_size = grad.numel();

  at::globalContext().alertNotDeterministic("roi_align_backward_kernel");
  auto rois_ = rois.contiguous();

  id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad);
  id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
  dispatch_sync(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      MTLSize threadgroupsPerGrid = MTLSizeMake(
          std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
          1,
          1);

      const std::string kernel = "roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type());
      id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);

      // this function call is a no-op if MPS Profiler is not enabled
      getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_});

      [computeEncoder setComputePipelineState:visionPSO];
      // [N, C, H, W]
      [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0];
      [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
      [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2];

      [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3];
      [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4];
      [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5];
      [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6];
      [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7];
      [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8];
      [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9];
      [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10];
      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
      [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:12];
      [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13];
      [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14];
      [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15];

      // A threadGroup is equivalent to a cuda's block.
      NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
      if (tgSize > threadsPerBlock) {
        tgSize = threadsPerBlock;
      }

      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];

      getMPSProfiler().endProfileKernel(visionPSO);
    }
  });
  return grad_input;
}

} // namespace

TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
  m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_align"), TORCH_FN(roi_align_forward_kernel));
  m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), TORCH_FN(roi_align_backward_kernel));
}

} // namespace ops
} // namespace vision
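The forward kernel above is what a regular torchvision.ops.roi_align call dispatches to when its inputs live on the MPS device. A minimal sketch with assumed shapes; the aligned flag toggles the half-pixel shift passed to the kernel at buffer index 10:

import torch
from torchvision.ops import roi_align

feat = torch.rand(2, 256, 50, 50, device="mps")  # [N, C, H, W] feature map
rois = torch.tensor([[0., 4., 4., 40., 40.],
                     [1., 0., 0., 25., 25.]], device="mps")  # (batch_idx, x1, y1, x2, y2)
out = roi_align(feat, rois, output_size=(7, 7), spatial_scale=0.25, sampling_ratio=2, aligned=True)
# out: [2, 256, 7, 7]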
torchvision/csrc/ops/mps/roi_pool_kernel.mm
0 → 100644
View file @
cc26cd81
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_helpers.h"
#include "mps_kernels.h"

namespace vision {
namespace ops {

namespace {

std::tuple<at::Tensor, at::Tensor> roi_pool_forward_kernel(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width) {
  using namespace at::native::mps;
  TORCH_CHECK(input.is_mps(), "input must be a MPS tensor");
  TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
  TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");

  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "roi_pool_forward_kernel";
  at::checkAllSameGPU(c, {input_t, rois_t});
  at::checkAllSameType(c, {input_t, rois_t});

  int64_t num_rois = rois.size(0);
  int64_t channels = input.size(1);
  int64_t height = input.size(2);
  int64_t width = input.size(3);
  float spatial_scale_f = static_cast<float>(spatial_scale);

  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
  at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kLong));

  int64_t output_size = num_rois * pooled_height * pooled_width * channels;

  if (output.numel() == 0) {
    return std::make_tuple(output, argmax);
  }

  auto input_ = input.contiguous();
  auto rois_ = rois.contiguous();

  id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
  id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
  id<MTLBuffer> argmaxBuffer = getMTLBufferStorage(argmax);
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
  dispatch_sync(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      MTLSize threadgroupsPerGrid = MTLSizeMake(
          std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
          1,
          1);

      const std::string kernel = "roi_pool_" + scalarToMetalTypeString(input.scalar_type());
      id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);

      // this function call is a no-op if MPS Profiler is not enabled
      getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});

      [computeEncoder setComputePipelineState:visionPSO];
      // [N, C, H, W]
      [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
      [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
      [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
      [computeEncoder setBuffer:argmaxBuffer offset:argmax.storage_offset() * argmax.element_size() atIndex:3];

      [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
      [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
      [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
      [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
      [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
      [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10];

      // A threadGroup is equivalent to a cuda's block.
      NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
      if (tgSize > threadsPerBlock) {
        tgSize = threadsPerBlock;
      }

      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];

      getMPSProfiler().endProfileKernel(visionPSO);
    }
  });
  return std::make_tuple(output, argmax);
}

at::Tensor roi_pool_backward_kernel(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& argmax,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width) {
  using namespace at::native::mps;
  TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor");
  TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
  TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_pool backward with float16 inputs.");
  TORCH_CHECK(argmax.is_mps(), "argmax must be a MPS tensor");

  at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, argmax_t{argmax, "argmax", 3};

  at::CheckedFrom c = "roi_pool_backward_kernel";
  at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t});
  at::checkAllSameType(c, {grad_t, rois_t});

  float spatial_scale_f = static_cast<float>(spatial_scale);

  at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());

  if (grad.numel() == 0) {
    return grad_input;
  }

  int64_t n_stride = grad.stride(0);
  int64_t c_stride = grad.stride(1);
  int64_t h_stride = grad.stride(2);
  int64_t w_stride = grad.stride(3);
  int64_t output_size = grad.numel();

  at::globalContext().alertNotDeterministic("roi_pool_backward_kernel");
  auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous();

  id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad);
  id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
  id<MTLBuffer> argmaxBuffer = getMTLBufferStorage(argmax_);
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
  dispatch_sync(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      MTLSize threadgroupsPerGrid = MTLSizeMake(
          std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
          1,
          1);

      const std::string kernel = "roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type());
      id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);

      // this function call is a no-op if MPS Profiler is not enabled
      getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_, argmax_});

      [computeEncoder setComputePipelineState:visionPSO];
      // [N, C, H, W]
      [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0];
      [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
      [computeEncoder setBuffer:argmaxBuffer offset:argmax_.storage_offset() * argmax_.element_size() atIndex:2];
      [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3];

      [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
      [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
      [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
      [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
      [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
      [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10];
      [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:11];
      [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:12];
      [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:13];
      [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:14];

      // A threadGroup is equivalent to a cuda's block.
      NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
      if (tgSize > threadsPerBlock) {
        tgSize = threadsPerBlock;
      }

      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];

      getMPSProfiler().endProfileKernel(visionPSO);
    }
  });
  return grad_input;
}

} // namespace

TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
  m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_pool"), TORCH_FN(roi_pool_forward_kernel));
  m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), TORCH_FN(roi_pool_backward_kernel));
}

} // namespace ops
} // namespace vision
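Note that the registered op returns both the pooled output and the argmax indices that the backward kernel consumes; the Python-level wrapper only exposes the pooled output. A minimal sketch, assuming an MPS device and made-up shapes:

import torch
from torchvision.ops import roi_pool

feat = torch.rand(1, 64, 32, 32, device="mps")
rois = torch.tensor([[0., 0., 0., 20., 20.]], device="mps")  # (batch_idx, x1, y1, x2, y2)
out = roi_pool(feat, rois, output_size=(7, 7), spatial_scale=1.0)  # [1, 64, 7, 7]
# The argmax tensor produced by the forward kernel stays internal; it is what
# roi_pool_backward_kernel reads to route gradients back to grad_input.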
torchvision/csrc/ops/quantized/cpu/qroi_align_kernel.cpp
View file @
cc26cd81
...
...
@@ -164,7 +164,7 @@ void qroi_align_forward_kernel_impl(
    const float count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4

    // we want to precalculate indices and weights shared by all chanels,
    // we want to precalculate indices and weights shared by all channels,
    // this is the key point of optimization
    std::vector<detail::PreCalc<float>> pre_calc(
        roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
...
...
torchvision/csrc/ops/roi_align.cpp
View file @
cc26cd81
...
...
@@ -32,6 +32,31 @@ at::Tensor roi_align(
      aligned);
}

at::Tensor roi_align_symint(
    const at::Tensor& input, // Input feature map.
    const at::Tensor& rois, // List of ROIs to pool over.
    double spatial_scale, // The scale of the image features. ROIs will be
    // scaled to this.
    c10::SymInt pooled_height, // The height of the pooled feature map.
    c10::SymInt pooled_width, // The width of the pooled feature
    int64_t sampling_ratio, // The number of points to sample in each bin
    bool aligned) // The flag for pixel shift
// along each axis.
{
  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align");
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("torchvision::roi_align", "")
                       .typed<decltype(roi_align_symint)>();
  return op.call(
      input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned);
}

namespace detail {

at::Tensor _roi_align_backward(
...
...
@@ -64,13 +89,43 @@ at::Tensor _roi_align_backward(
      aligned);
}

at::Tensor _roi_align_backward_symint(
    const at::Tensor& grad,
    const at::Tensor& rois,
    double spatial_scale,
    c10::SymInt pooled_height,
    c10::SymInt pooled_width,
    c10::SymInt batch_size,
    c10::SymInt channels,
    c10::SymInt height,
    c10::SymInt width,
    int64_t sampling_ratio,
    bool aligned) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_roi_align_backward", "")
          .typed<decltype(_roi_align_backward_symint)>();
  return op.call(
      grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned);
}

} // namespace detail

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor"));
      "torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, bool aligned) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor"));
      "torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width, int sampling_ratio, bool aligned) -> Tensor"));
}

} // namespace ops
...
...
torchvision/csrc/ops/roi_align.h
View file @
cc26cd81
...
...
@@ -15,6 +15,15 @@ VISION_API at::Tensor roi_align(
    int64_t sampling_ratio,
    bool aligned);

VISION_API at::Tensor roi_align_symint(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    c10::SymInt pooled_height,
    c10::SymInt pooled_width,
    int64_t sampling_ratio,
    bool aligned);

namespace detail {

at::Tensor _roi_align_backward(
...
...
@@ -30,6 +39,19 @@ at::Tensor _roi_align_backward(
    int64_t sampling_ratio,
    bool aligned);

at::Tensor _roi_align_backward_symint(
    const at::Tensor& grad,
    const at::Tensor& rois,
    double spatial_scale,
    c10::SymInt pooled_height,
    c10::SymInt pooled_width,
    c10::SymInt batch_size,
    c10::SymInt channels,
    c10::SymInt height,
    c10::SymInt width,
    int64_t sampling_ratio,
    bool aligned);

} // namespace detail

} // namespace ops
...
...
torchvision/datasets/__init__.py
View file @
cc26cd81
...
...
@@ -36,6 +36,7 @@ from .kitti import Kitti
from .lfw import LFWPairs, LFWPeople
from .lsun import LSUN, LSUNClass
from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
from .moving_mnist import MovingMNIST
from .omniglot import Omniglot
from .oxford_iiit_pet import OxfordIIITPet
from .pcam import PCAM
...
...
@@ -126,4 +127,18 @@ __all__ = (
"SintelStereo"
,
"InStereo2k"
,
"ETH3DStereo"
,
"wrap_dataset_for_transforms_v2"
,
)
# We override current module's attributes to handle the import:
# from torchvision.datasets import wrap_dataset_for_transforms_v2
# without a cyclic error.
# Ref: https://peps.python.org/pep-0562/
def
__getattr__
(
name
):
if
name
in
(
"wrap_dataset_for_transforms_v2"
,):
from
torchvision.tv_tensors._dataset_wrapper
import
wrap_dataset_for_transforms_v2
return
wrap_dataset_for_transforms_v2
raise
AttributeError
(
f
"module
{
__name__
!
r
}
has no attribute
{
name
!
r
}
"
)
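The module-level __getattr__ above follows PEP 562, so the wrapper is only imported on first access; a minimal sketch of how this resolves, assuming an installed torchvision that ships the tv_tensors wrapper module:

import torchvision.datasets as datasets

# Triggers datasets.__getattr__("wrap_dataset_for_transforms_v2") lazily,
# avoiding a circular import of torchvision.tv_tensors at package import time.
wrap = datasets.wrap_dataset_for_transforms_v2
# Any other missing name still raises AttributeError, e.g. datasets.does_not_exist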
torchvision/datasets/_optical_flow.py
View file @
cc26cd81
...
...
@@ -3,6 +3,7 @@ import os
from abc import ABC, abstractmethod
from glob import glob
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
...
@@ -13,6 +14,10 @@ from .utils import _read_pfm, verify_str_arg
from .vision import VisionDataset

T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]


__all__ = (
    "KittiFlow",
    "Sintel",
...
@@ -28,26 +33,26 @@ class FlowDataset(ABC, VisionDataset):
    # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
    _has_builtin_flow_mask = False

    def __init__(self, root, transforms=None):
    def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
        super().__init__(root=root)
        self.transforms = transforms

        self._flow_list = []
        self._image_list = []
        self._flow_list: List[str] = []
        self._image_list: List[List[str]] = []

    def _read_img(self, file_name):
    def _read_img(self, file_name: str) -> Image.Image:
        img = Image.open(file_name)
        if img.mode != "RGB":
            img = img.convert("RGB")
        return img

    @abstractmethod
    def _read_flow(self, file_name):
    def _read_flow(self, file_name: str):
        # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True
        pass

    def __getitem__(self, index):
    def __getitem__(self, index: int) -> Union[T1, T2]:
        img1 = self._read_img(self._image_list[index][0])
        img2 = self._read_img(self._image_list[index][1])
...
@@ -70,10 +75,10 @@ class FlowDataset(ABC, VisionDataset):
        else:
            return img1, img2, flow

    def __len__(self):
    def __len__(self) -> int:
        return len(self._image_list)

    def __rmul__(self, v):
    def __rmul__(self, v: int) -> torch.utils.data.ConcatDataset:
        return torch.utils.data.ConcatDataset([self] * v)
...
@@ -118,7 +123,13 @@ class Sintel(FlowDataset):
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(self, root, split="train", pass_name="clean", transforms=None):
    def __init__(
        self,
        root: str,
        split: str = "train",
        pass_name: str = "clean",
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))
...
@@ -139,7 +150,7 @@ class Sintel(FlowDataset):
            if split == "train":
                self._flow_list += sorted(glob(str(flow_root / scene / "*.flo")))

    def __getitem__(self, index):
    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
...
@@ -154,7 +165,7 @@ class Sintel(FlowDataset):
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name):
    def _read_flow(self, file_name: str) -> np.ndarray:
        return _read_flo(file_name)
...
@@ -180,7 +191,7 @@ class KittiFlow(FlowDataset):
    _has_builtin_flow_mask = True

    def __init__(self, root, split="train", transforms=None):
    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))
...
@@ -200,7 +211,7 @@ class KittiFlow(FlowDataset):
        if split == "train":
            self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png")))

    def __getitem__(self, index):
    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
...
@@ -215,7 +226,7 @@ class KittiFlow(FlowDataset):
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name):
    def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
        return _read_16bits_png_with_flow_and_valid_mask(file_name)
...
@@ -245,7 +256,7 @@ class FlyingChairs(FlowDataset):
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(self, root, split="train", transforms=None):
    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "val"))
...
@@ -268,7 +279,7 @@ class FlyingChairs(FlowDataset):
                self._flow_list += [flows[i]]
                self._image_list += [[images[2 * i], images[2 * i + 1]]]

    def __getitem__(self, index):
    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
...
@@ -283,7 +294,7 @@ class FlyingChairs(FlowDataset):
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name):
    def _read_flow(self, file_name: str) -> np.ndarray:
        return _read_flo(file_name)
...
@@ -316,7 +327,14 @@ class FlyingThings3D(FlowDataset):
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(self, root, split="train", pass_name="clean", camera="left", transforms=None):
    def __init__(
        self,
        root: str,
        split: str = "train",
        pass_name: str = "clean",
        camera: str = "left",
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))
...
@@ -359,7 +377,7 @@ class FlyingThings3D(FlowDataset):
                            self._image_list += [[images[i + 1], images[i]]]
                            self._flow_list += [flows[i + 1]]

    def __getitem__(self, index):
    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
...
@@ -374,7 +392,7 @@ class FlyingThings3D(FlowDataset):
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name):
    def _read_flow(self, file_name: str) -> np.ndarray:
        return _read_pfm(file_name)
...
@@ -401,7 +419,7 @@ class HD1K(FlowDataset):
    _has_builtin_flow_mask = True

    def __init__(self, root, split="train", transforms=None):
    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))
...
@@ -426,10 +444,10 @@ class HD1K(FlowDataset):
                "Could not find the HD1K images. Please make sure the directory structure is correct."
            )

    def _read_flow(self, file_name):
    def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
        return _read_16bits_png_with_flow_and_valid_mask(file_name)

    def __getitem__(self, index):
    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
...
@@ -445,7 +463,7 @@ class HD1K(FlowDataset):
        return super().__getitem__(index)


def _read_flo(file_name):
def _read_flo(file_name: str) -> np.ndarray:
    """Read .flo file in Middlebury format"""
    # Code adapted from:
    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
...
@@ -462,7 +480,7 @@ def _read_flo(file_name):
    return data.reshape(h, w, 2).transpose(2, 0, 1)


def _read_16bits_png_with_flow_and_valid_mask(file_name):
def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> Tuple[np.ndarray, np.ndarray]:
    flow_and_valid = _read_png_16(file_name).to(torch.float32)
    flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
...
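The T1/T2 aliases added above document the two return shapes: datasets with a built-in valid mask, such as KittiFlow, yield 4-tuples, while the others yield 3-tuples unless their transforms generate a mask. A minimal sketch, assuming the KITTI flow files are already extracted under a hypothetical local root:

from torchvision.datasets import KittiFlow

ds = KittiFlow(root="datasets/kitti", split="train")  # hypothetical path
img1, img2, flow, valid_flow_mask = ds[0]             # T1: the mask is built in for KittiFlow
# For Sintel or FlyingChairs the same indexing returns (img1, img2, flow), i.e. T2.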
torchvision/datasets/_stereo_matching.py
View file @
cc26cd81
...
...
@@ -6,7 +6,7 @@ import shutil
from
abc
import
ABC
,
abstractmethod
from
glob
import
glob
from
pathlib
import
Path
from
typing
import
Callable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Callable
,
cast
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
from
PIL
import
Image
...
...
@@ -14,6 +14,9 @@ from PIL import Image
from
.utils
import
_read_pfm
,
download_and_extract_archive
,
verify_str_arg
from
.vision
import
VisionDataset
T1
=
Tuple
[
Image
.
Image
,
Image
.
Image
,
Optional
[
np
.
ndarray
],
np
.
ndarray
]
T2
=
Tuple
[
Image
.
Image
,
Image
.
Image
,
Optional
[
np
.
ndarray
]]
__all__
=
()
_read_pfm_file
=
functools
.
partial
(
_read_pfm
,
slice_channels
=
1
)
...
...
@@ -24,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset):
_has_built_in_disparity_mask
=
False
def
__init__
(
self
,
root
:
str
,
transforms
:
Optional
[
Callable
]
=
None
):
def
__init__
(
self
,
root
:
str
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
"""
Args:
root(str): Root directory of the dataset.
...
...
@@ -58,7 +61,11 @@ class StereoMatchingDataset(ABC, VisionDataset):
img
=
img
.
convert
(
"RGB"
)
return
img
def
_scan_pairs
(
self
,
paths_left_pattern
:
str
,
paths_right_pattern
:
Optional
[
str
]
=
None
):
def
_scan_pairs
(
self
,
paths_left_pattern
:
str
,
paths_right_pattern
:
Optional
[
str
]
=
None
,
)
->
List
[
Tuple
[
str
,
Optional
[
str
]]]:
left_paths
=
list
(
sorted
(
glob
(
paths_left_pattern
)))
...
...
@@ -85,11 +92,11 @@ class StereoMatchingDataset(ABC, VisionDataset):
return
paths
@
abstractmethod
def
_read_disparity
(
self
,
file_path
:
str
)
->
Tuple
:
def
_read_disparity
(
self
,
file_path
:
str
)
->
Tuple
[
Optional
[
np
.
ndarray
],
Optional
[
np
.
ndarray
]]
:
# function that returns a disparity map and an occlusion map
pass
def
__getitem__
(
self
,
index
:
int
)
->
Tuple
:
def
__getitem__
(
self
,
index
:
int
)
->
Union
[
T1
,
T2
]
:
"""Return example at given index.
Args:
...
...
@@ -120,7 +127,7 @@ class StereoMatchingDataset(ABC, VisionDataset):
)
=
self
.
transforms
(
imgs
,
dsp_maps
,
valid_masks
)
if
self
.
_has_built_in_disparity_mask
or
valid_masks
[
0
]
is
not
None
:
return
imgs
[
0
],
imgs
[
1
],
dsp_maps
[
0
],
valid_masks
[
0
]
return
imgs
[
0
],
imgs
[
1
],
dsp_maps
[
0
],
cast
(
np
.
ndarray
,
valid_masks
[
0
]
)
else
:
return
imgs
[
0
],
imgs
[
1
],
dsp_maps
[
0
]
...
...
@@ -156,7 +163,7 @@ class CarlaStereo(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
def
__init__
(
self
,
root
:
str
,
transforms
:
Optional
[
Callable
]
=
None
):
def
__init__
(
self
,
root
:
str
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
super
().
__init__
(
root
,
transforms
)
root
=
Path
(
root
)
/
"carla-highres"
...
...
@@ -171,13 +178,13 @@ class CarlaStereo(StereoMatchingDataset):
         disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
         self._disparities = disparities

-    def _read_disparity(self, file_path: str) -> Tuple:
+    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
         disparity_map = _read_pfm_file(file_path)
         disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
         valid_mask = None
         return disparity_map, valid_mask

-    def __getitem__(self, index: int) -> Tuple:
+    def __getitem__(self, index: int) -> T1:
         """Return example at given index.

         Args:
...
@@ -189,7 +196,7 @@ class CarlaStereo(StereoMatchingDataset):
             If a ``valid_mask`` is generated within the ``transforms`` parameter,
             a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
         """
-        return super().__getitem__(index)
+        return cast(T1, super().__getitem__(index))


 class Kitti2012Stereo(StereoMatchingDataset):
...
@@ -233,7 +240,7 @@ class Kitti2012Stereo(StereoMatchingDataset):
     _has_built_in_disparity_mask = True

-    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
+    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
         super().__init__(root, transforms)

         verify_str_arg(split, "split", valid_values=("train", "test"))
...
@@ -250,7 +257,7 @@ class Kitti2012Stereo(StereoMatchingDataset):
         else:
             self._disparities = list((None, None) for _ in self._images)

-    def _read_disparity(self, file_path: str) -> Tuple:
+    def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
         # test split has no disparity maps
         if file_path is None:
             return None, None
...
@@ -261,7 +268,7 @@ class Kitti2012Stereo(StereoMatchingDataset):
         valid_mask = None
         return disparity_map, valid_mask

-    def __getitem__(self, index: int) -> Tuple:
+    def __getitem__(self, index: int) -> T1:
         """Return example at given index.

         Args:
...
@@ -274,7 +281,7 @@ class Kitti2012Stereo(StereoMatchingDataset):
             generate a valid mask.
             Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
         """
-        return super().__getitem__(index)
+        return cast(T1, super().__getitem__(index))


 class Kitti2015Stereo(StereoMatchingDataset):
...
@@ -321,7 +328,7 @@ class Kitti2015Stereo(StereoMatchingDataset):
     _has_built_in_disparity_mask = True

-    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
+    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
         super().__init__(root, transforms)

         verify_str_arg(split, "split", valid_values=("train", "test"))
...
@@ -338,7 +345,7 @@ class Kitti2015Stereo(StereoMatchingDataset):
         else:
             self._disparities = list((None, None) for _ in self._images)

-    def _read_disparity(self, file_path: str) -> Tuple:
+    def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
         # test split has no disparity maps
         if file_path is None:
             return None, None
...
@@ -349,7 +356,7 @@ class Kitti2015Stereo(StereoMatchingDataset):
         valid_mask = None
         return disparity_map, valid_mask

-    def __getitem__(self, index: int) -> Tuple:
+    def __getitem__(self, index: int) -> T1:
         """Return example at given index.

         Args:
...
@@ -362,7 +369,7 @@ class Kitti2015Stereo(StereoMatchingDataset):
             generate a valid mask.
             Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
         """
-        return super().__getitem__(index)
+        return cast(T1, super().__getitem__(index))


 class Middlebury2014Stereo(StereoMatchingDataset):
...
@@ -417,9 +424,9 @@ class Middlebury2014Stereo(StereoMatchingDataset):
         split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional"
         use_ambient_views (boolean, optional): Whether to use different expose or lightning views when possible.
             The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``.
-        calibration (string, optional): Wether or not to use the calibrated (default) or uncalibrated scenes.
+        calibration (string, optional): Whether or not to use the calibrated (default) or uncalibrated scenes.
         transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
-        download (boolean, optional): Wether or not to download the dataset in the ``root`` directory.
+        download (boolean, optional): Whether or not to download the dataset in the ``root`` directory.
     """

     splits = {
...
@@ -479,7 +486,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
         use_ambient_views: bool = False,
         transforms: Optional[Callable] = None,
         download: bool = False,
-    ):
+    ) -> None:
         super().__init__(root, transforms)

         verify_str_arg(split, "split", valid_values=("train", "test", "additional"))
...
@@ -558,7 +565,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
                 file_path = random.choice(ambient_file_paths)  # type: ignore
             return super()._read_img(file_path)

-    def _read_disparity(self, file_path: str) -> Tuple:
+    def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
         # test split has not disparity maps
         if file_path is None:
             return None, None
...
@@ -569,7 +576,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
         valid_mask = (disparity_map > 0).squeeze(0)  # mask out invalid disparities
         return disparity_map, valid_mask

-    def _download_dataset(self, root: str):
+    def _download_dataset(self, root: str) -> None:
         base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
         # train and additional splits have 2 different calibration settings
         root = Path(root) / "Middlebury2014"
...
@@ -608,7 +615,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
             # cleanup MiddEval3 directory
             shutil.rmtree(str(root / "MiddEval3"))

-    def __getitem__(self, index: int) -> Tuple:
+    def __getitem__(self, index: int) -> T2:
         """Return example at given index.

         Args:
...
@@ -619,7 +626,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
             The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
             ``valid_mask`` is implicitly ``None`` for `split=test`.
         """
-        return super().__getitem__(index)
+        return cast(T2, super().__getitem__(index))


 class CREStereo(StereoMatchingDataset):
...
@@ -670,7 +677,7 @@ class CREStereo(StereoMatchingDataset):
         self,
         root: str,
         transforms: Optional[Callable] = None,
-    ):
+    ) -> None:
         super().__init__(root, transforms)

         root = Path(root) / "CREStereo"
...
@@ -688,14 +695,14 @@ class CREStereo(StereoMatchingDataset):
         disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
         self._disparities += disparities

-    def _read_disparity(self, file_path: str) -> Tuple:
+    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
         disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
         # unsqueeze the disparity map into (C, H, W) format
         disparity_map = disparity_map[None, :, :] / 32.0
         valid_mask = None
         return disparity_map, valid_mask

-    def __getitem__(self, index: int) -> Tuple:
+    def __getitem__(self, index: int) -> T1:
         """Return example at given index.

         Args:
...
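The annotated Tuple[np.ndarray, None] matches what the body already did: decode the PNG, rescale by the dataset's fixed factor of 32, and add a leading channel dimension. A standalone sketch of that decoding step (the file path is a placeholder):

import numpy as np
from PIL import Image

def read_crestereo_disparity_sketch(file_path: str) -> np.ndarray:
    # CREStereo stores disparity as PNGs scaled by 32; return it as (1, H, W).
    disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
    return disparity_map[None, :, :] / 32.0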
@@ -707,13 +714,13 @@ class CREStereo(StereoMatchingDataset):
             ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
             generate a valid mask.
         """
-        return super().__getitem__(index)
+        return cast(T1, super().__getitem__(index))


 class FallingThingsStereo(StereoMatchingDataset):
     """`FallingThings <https://research.nvidia.com/publication/2018-06_falling-things-synthetic-dataset-3d-object-detection-and-pose-estimation>`_ dataset.

-    The dataset is expected to have the following structre: ::
+    The dataset is expected to have the following structure: ::

         root
             FallingThings
...
@@ -755,7 +762,7 @@ class FallingThingsStereo(StereoMatchingDataset):
         transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
     """

-    def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None):
+    def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None) -> None:
         super().__init__(root, transforms)

         root = Path(root) / "FallingThings"
...
@@ -782,14 +789,14 @@ class FallingThingsStereo(StereoMatchingDataset):
             right_disparity_pattern = str(root / s / split_prefix[s] / "*.right.depth.png")
             self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)

-    def _read_disparity(self, file_path: str) -> Tuple:
+    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
         # (H, W) image
         depth = np.asarray(Image.open(file_path))
         # as per https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt
         # in order to extract disparity from depth maps
         camera_settings_path = Path(file_path).parent / "_camera_settings.json"
         with open(camera_settings_path, "r") as f:
-            # inverse of depth-from-disparity equation: depth = (baseline * focal) / (disparity * pixel_constatnt)
+            # inverse of depth-from-disparity equation: depth = (baseline * focal) / (disparity * pixel_constant)
             intrinsics = json.load(f)
             focal = intrinsics["camera_settings"][0]["intrinsic_settings"]["fx"]
             baseline, pixel_constant = 6, 100  # pixel constant is inverted
...
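The comment documents depth = (baseline * focal) / (disparity * pixel_constant), so disparity follows by rearranging. Note the code's own remark that the pixel constant is inverted, so the real implementation's handling of pixel_constant may differ from the literal rearrangement below; the focal length and depth values are placeholders, not taken from a real _camera_settings.json:

import numpy as np

focal = 768.2                      # hypothetical fx value
baseline, pixel_constant = 6, 100  # constants used by the dataset code
depth = np.full((4, 4), 1000.0, dtype=np.float32)  # dummy (H, W) depth map

# Literal inversion of the documented relation:
# disparity = (baseline * focal) / (depth * pixel_constant)
disparity = (baseline * focal) / (depth * pixel_constant)
print(disparity.shape)  # (4, 4)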
@@ -799,7 +806,7 @@ class FallingThingsStereo(StereoMatchingDataset):
             valid_mask = None
             return disparity_map, valid_mask

-    def __getitem__(self, index: int) -> Tuple:
+    def __getitem__(self, index: int) -> T1:
         """Return example at given index.

         Args:
...
@@ -811,14 +818,14 @@ class FallingThingsStereo(StereoMatchingDataset):
             If a ``valid_mask`` is generated within the ``transforms`` parameter,
             a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
         """
-        return super().__getitem__(index)
+        return cast(T1, super().__getitem__(index))


 class SceneFlowStereo(StereoMatchingDataset):
     """Dataset interface for `Scene Flow <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ datasets.

     This interface provides access to the `FlyingThings3D, `Monkaa` and `Driving` datasets.
-    The dataset is expected to have the following structre: ::
+    The dataset is expected to have the following structure: ::

         root
             SceneFlow
...
@@ -874,7 +881,7 @@ class SceneFlowStereo(StereoMatchingDataset):
         variant: str = "FlyingThings3D",
         pass_name: str = "clean",
         transforms: Optional[Callable] = None,
-    ):
+    ) -> None:
         super().__init__(root, transforms)

         root = Path(root) / "SceneFlow"
...
@@ -905,13 +912,13 @@ class SceneFlowStereo(StereoMatchingDataset):
             right_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "right" / "*.pfm")
             self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)

-    def _read_disparity(self, file_path: str) -> Tuple:
+    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
         disparity_map = _read_pfm_file(file_path)
         disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
         valid_mask = None
         return disparity_map, valid_mask

-    def __getitem__(self, index: int) -> Tuple:
+    def __getitem__(self, index: int) -> T1:
         """Return example at given index.

         Args:
...
@@ -923,7 +930,7 @@ class SceneFlowStereo(StereoMatchingDataset):
             If a ``valid_mask`` is generated within the ``transforms`` parameter,
             a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
         """
-        return super().__getitem__(index)
+        return cast(T1, super().__getitem__(index))


 class SintelStereo(StereoMatchingDataset):
...
@@ -973,7 +980,7 @@ class SintelStereo(StereoMatchingDataset):
     _has_built_in_disparity_mask = True

-    def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None):
+    def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
         super().__init__(root, transforms)

         verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))
...
@@ -1014,7 +1021,7 @@ class SintelStereo(StereoMatchingDataset):
         return occlusion_path, outofframe_path

-    def _read_disparity(self, file_path: str) -> Tuple:
+    def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
         if file_path is None:
             return None, None
...
@@ -1024,7 +1031,7 @@ class SintelStereo(StereoMatchingDataset):
         disparity_map = r * 4 + g / (2**6) + b / (2**14)
         # reshape into (C, H, W) format
         disparity_map = np.transpose(disparity_map, (2, 0, 1))
-        # find the appropiate file paths
+        # find the appropriate file paths
         occlued_mask_path, out_of_frame_mask_path = self._get_occlussion_mask_paths(file_path)
         # occlusion masks
         valid_mask = np.asarray(Image.open(occlued_mask_path)) == 0
...
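Sintel packs disparity into the R, G and B channels of a PNG, which the code above recombines as r * 4 + g / 2**6 + b / 2**14. A standalone sketch of that decoding (the path and channel split are illustrative; the real reader also applies the occlusion and out-of-frame masks shown above):

import numpy as np
from PIL import Image

def read_sintel_disparity_sketch(file_path: str) -> np.ndarray:
    rgb = np.asarray(Image.open(file_path), dtype=np.float32)
    r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    # recombine the packed channels into a single disparity plane
    disparity = r * 4 + g / (2**6) + b / (2**14)
    return disparity[None, :, :]  # (1, H, W)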
@@ -1034,7 +1041,7 @@ class SintelStereo(StereoMatchingDataset):
         valid_mask = np.logical_and(off_mask, valid_mask)
         return disparity_map, valid_mask

-    def __getitem__(self, index: int) -> Tuple:
+    def __getitem__(self, index: int) -> T2:
         """Return example at given index.

         Args:
...
@@ -1045,13 +1052,13 @@ class SintelStereo(StereoMatchingDataset):
             The disparity is a numpy array of shape (1, H, W) and the images are PIL images whilst
             the valid_mask is a numpy array of shape (H, W).
         """
-        return super().__getitem__(index)
+        return cast(T2, super().__getitem__(index))


 class InStereo2k(StereoMatchingDataset):
     """`InStereo2k <https://github.com/YuhuaXu/StereoDataset>`_ dataset.

-    The dataset is expected to have the following structre: ::
+    The dataset is expected to have the following structure: ::

         root
             InStereo2k
...
@@ -1080,7 +1087,7 @@ class InStereo2k(StereoMatchingDataset):
         transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
     """

-    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
+    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
         super().__init__(root, transforms)

         root = Path(root) / "InStereo2k" / split
...
@@ -1095,14 +1102,14 @@ class InStereo2k(StereoMatchingDataset):
         right_disparity_pattern = str(root / "*" / "right_disp.png")
         self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)

-    def _read_disparity(self, file_path: str) -> Tuple:
+    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
         disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
         # unsqueeze disparity to (C, H, W)
         disparity_map = disparity_map[None, :, :] / 1024.0
         valid_mask = None
         return disparity_map, valid_mask

-    def __getitem__(self, index: int) -> Tuple:
+    def __getitem__(self, index: int) -> T1:
         """Return example at given index.

         Args:
...
@@ -1114,7 +1121,7 @@ class InStereo2k(StereoMatchingDataset):
             If a ``valid_mask`` is generated within the ``transforms`` parameter,
             a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
         """
-        return super().__getitem__(index)
+        return cast(T1, super().__getitem__(index))


 class ETH3DStereo(StereoMatchingDataset):
...
@@ -1169,7 +1176,7 @@ class ETH3DStereo(StereoMatchingDataset):
     _has_built_in_disparity_mask = True

-    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
+    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
         super().__init__(root, transforms)

         verify_str_arg(split, "split", valid_values=("train", "test"))
...
@@ -1189,7 +1196,7 @@ class ETH3DStereo(StereoMatchingDataset):
         disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm")
         self._disparities = self._scan_pairs(disparity_pattern, None)

-    def _read_disparity(self, file_path: str) -> Tuple:
+    def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
         # test split has no disparity maps
         if file_path is None:
             return None, None
...
@@ -1201,7 +1208,7 @@ class ETH3DStereo(StereoMatchingDataset):
         valid_mask = np.asarray(valid_mask).astype(bool)
         return disparity_map, valid_mask

-    def __getitem__(self, index: int) -> Tuple:
+    def __getitem__(self, index: int) -> T2:
         """Return example at given index.

         Args:
...
@@ -1214,4 +1221,4 @@ class ETH3DStereo(StereoMatchingDataset):
             generate a valid mask.
             Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
         """
-        return super().__getitem__(index)
+        return cast(T2, super().__getitem__(index))
torchvision/datasets/celeba.py
View file @ cc26cd81
...
@@ -23,10 +23,10 @@ class CelebA(VisionDataset):
             or ``landmarks``. Can also be a list to output a tuple with all specified target types.
             The targets represent:

-                - ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes
+                - ``attr`` (Tensor shape=(40,) dtype=int): binary (0, 1) labels for attributes
                 - ``identity`` (int): label for each person (data points with the same identity are the same person)
-                - ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
-                - ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
+                - ``bbox`` (Tensor shape=(4,) dtype=int): bounding box (x, y, width, height)
+                - ``landmarks`` (Tensor shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
                   righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)

             Defaults to ``attr``. If empty, ``None`` will be returned as target.
...
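The docstring now reflects that these targets come back as Tensors rather than np.array. A hedged usage sketch (assumes the CelebA files are already present under root, hence download=False):

from torchvision.datasets import CelebA

celeba = CelebA(root="data", split="train", target_type=["attr", "bbox", "landmarks"], download=False)
img, (attr, bbox, landmarks) = celeba[0]
# attr, bbox and landmarks are integer Tensors of shape (40,), (4,) and (10,)
print(attr.shape, bbox.shape, landmarks.shape)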
@@ -41,7 +41,7 @@ class CelebA(VisionDataset):
     """

     base_folder = "celeba"
-    # There currently does not appear to be a easy way to extract 7z in python (without introducing additional
+    # There currently does not appear to be an easy way to extract 7z in python (without introducing additional
    # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
    # right now.
    file_list = [
...