Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
668348ed
Unverified
Commit
668348ed
authored
Oct 30, 2023
by
Nicolas Hug
Committed by
GitHub
Oct 30, 2023
Browse files
PSRoiPool: SymInt support + meta-implem (#8062)
parent
85c586c6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
124 additions
and
30 deletions
+124
-30
torchvision/_meta_registrations.py
torchvision/_meta_registrations.py
+34
-0
torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp
torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp
+28
-28
torchvision/csrc/ops/ps_roi_pool.cpp
torchvision/csrc/ops/ps_roi_pool.cpp
+43
-2
torchvision/csrc/ops/ps_roi_pool.h
torchvision/csrc/ops/ps_roi_pool.h
+19
-0
No files found.
torchvision/_meta_registrations.py
View file @
668348ed
...
@@ -126,6 +126,40 @@ def meta_roi_pool_backward(
...
@@ -126,6 +126,40 @@ def meta_roi_pool_backward(
return
grad
.
new_empty
((
batch_size
,
channels
,
height
,
width
))
return
grad
.
new_empty
((
batch_size
,
channels
,
height
,
width
))
@
register_meta
(
"ps_roi_pool"
)
def
meta_ps_roi_pool
(
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
):
torch
.
_check
(
rois
.
size
(
1
)
==
5
,
lambda
:
"rois must have shape as Tensor[K, 5]"
)
torch
.
_check
(
input
.
dtype
==
rois
.
dtype
,
lambda
:
(
"Expected tensor for input to have the same type as tensor for rois; "
f
"but type
{
input
.
dtype
}
does not equal
{
rois
.
dtype
}
"
),
)
channels
=
input
.
size
(
1
)
torch
.
_check
(
channels
%
(
pooled_height
*
pooled_width
)
==
0
,
"input channels must be a multiple of pooling height * pooling width"
,
)
num_rois
=
rois
.
size
(
0
)
out_size
=
(
num_rois
,
channels
//
(
pooled_height
*
pooled_width
),
pooled_height
,
pooled_width
)
return
input
.
new_empty
(
out_size
),
torch
.
empty
(
out_size
,
device
=
"meta"
,
dtype
=
torch
.
int32
)
@
register_meta
(
"_ps_roi_pool_backward"
)
def
meta_ps_roi_pool_backward
(
grad
,
rois
,
channel_mapping
,
spatial_scale
,
pooled_height
,
pooled_width
,
batch_size
,
channels
,
height
,
width
):
torch
.
_check
(
grad
.
dtype
==
rois
.
dtype
,
lambda
:
(
"Expected tensor for grad to have the same type as tensor for rois; "
f
"but type
{
grad
.
dtype
}
does not equal
{
rois
.
dtype
}
"
),
)
return
grad
.
new_empty
((
batch_size
,
channels
,
height
,
width
))
@
torch
.
_custom_ops
.
impl_abstract
(
"torchvision::nms"
)
@
torch
.
_custom_ops
.
impl_abstract
(
"torchvision::nms"
)
def
meta_nms
(
dets
,
scores
,
iou_threshold
):
def
meta_nms
(
dets
,
scores
,
iou_threshold
):
torch
.
_check
(
dets
.
dim
()
==
2
,
lambda
:
f
"boxes should be a 2d tensor, got
{
dets
.
dim
()
}
D"
)
torch
.
_check
(
dets
.
dim
()
==
2
,
lambda
:
f
"boxes should be a 2d tensor, got
{
dets
.
dim
()
}
D"
)
...
...
torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp
View file @
668348ed
...
@@ -15,15 +15,15 @@ class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
...
@@ -15,15 +15,15 @@ class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
const
torch
::
autograd
::
Variable
&
input
,
const
torch
::
autograd
::
Variable
&
input
,
const
torch
::
autograd
::
Variable
&
rois
,
const
torch
::
autograd
::
Variable
&
rois
,
double
spatial_scale
,
double
spatial_scale
,
int64_
t
pooled_height
,
c10
::
SymIn
t
pooled_height
,
int64_
t
pooled_width
)
{
c10
::
SymIn
t
pooled_width
)
{
ctx
->
saved_data
[
"spatial_scale"
]
=
spatial_scale
;
ctx
->
saved_data
[
"spatial_scale"
]
=
spatial_scale
;
ctx
->
saved_data
[
"pooled_height"
]
=
pooled_height
;
ctx
->
saved_data
[
"pooled_height"
]
=
pooled_height
;
ctx
->
saved_data
[
"pooled_width"
]
=
pooled_width
;
ctx
->
saved_data
[
"pooled_width"
]
=
pooled_width
;
ctx
->
saved_data
[
"input_shape"
]
=
input
.
sizes
();
ctx
->
saved_data
[
"input_shape"
]
=
input
.
sym_
sizes
();
at
::
AutoDispatchBelowADInplaceOrView
g
;
at
::
AutoDispatchBelowADInplaceOrView
g
;
auto
result
=
auto
result
=
ps_roi_pool_symint
(
ps_roi_pool
(
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
);
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
);
auto
output
=
std
::
get
<
0
>
(
result
);
auto
output
=
std
::
get
<
0
>
(
result
);
auto
channel_mapping
=
std
::
get
<
1
>
(
result
);
auto
channel_mapping
=
std
::
get
<
1
>
(
result
);
...
@@ -40,18 +40,18 @@ class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
...
@@ -40,18 +40,18 @@ class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
auto
saved
=
ctx
->
get_saved_variables
();
auto
saved
=
ctx
->
get_saved_variables
();
auto
rois
=
saved
[
0
];
auto
rois
=
saved
[
0
];
auto
channel_mapping
=
saved
[
1
];
auto
channel_mapping
=
saved
[
1
];
auto
input_shape
=
ctx
->
saved_data
[
"input_shape"
].
to
Int
List
();
auto
input_shape
=
ctx
->
saved_data
[
"input_shape"
].
toList
();
auto
grad_in
=
detail
::
_ps_roi_pool_backward
(
auto
grad_in
=
detail
::
_ps_roi_pool_backward
_symint
(
grad_output
[
0
],
grad_output
[
0
],
rois
,
rois
,
channel_mapping
,
channel_mapping
,
ctx
->
saved_data
[
"spatial_scale"
].
toDouble
(),
ctx
->
saved_data
[
"spatial_scale"
].
toDouble
(),
ctx
->
saved_data
[
"pooled_height"
].
toInt
(),
ctx
->
saved_data
[
"pooled_height"
].
to
Sym
Int
(),
ctx
->
saved_data
[
"pooled_width"
].
toInt
(),
ctx
->
saved_data
[
"pooled_width"
].
to
Sym
Int
(),
input_shape
[
0
],
input_shape
[
0
]
.
get
().
toSymInt
()
,
input_shape
[
1
],
input_shape
[
1
]
.
get
().
toSymInt
()
,
input_shape
[
2
],
input_shape
[
2
]
.
get
().
toSymInt
()
,
input_shape
[
3
]);
input_shape
[
3
]
.
get
().
toSymInt
()
);
return
{
return
{
grad_in
,
grad_in
,
...
@@ -72,14 +72,14 @@ class PSROIPoolBackwardFunction
...
@@ -72,14 +72,14 @@ class PSROIPoolBackwardFunction
const
torch
::
autograd
::
Variable
&
rois
,
const
torch
::
autograd
::
Variable
&
rois
,
const
torch
::
autograd
::
Variable
&
channel_mapping
,
const
torch
::
autograd
::
Variable
&
channel_mapping
,
double
spatial_scale
,
double
spatial_scale
,
int64_
t
pooled_height
,
c10
::
SymIn
t
pooled_height
,
int64_
t
pooled_width
,
c10
::
SymIn
t
pooled_width
,
int64_
t
batch_size
,
c10
::
SymIn
t
batch_size
,
int64_
t
channels
,
c10
::
SymIn
t
channels
,
int64_
t
height
,
c10
::
SymIn
t
height
,
int64_
t
width
)
{
c10
::
SymIn
t
width
)
{
at
::
AutoDispatchBelowADInplaceOrView
g
;
at
::
AutoDispatchBelowADInplaceOrView
g
;
auto
grad_in
=
detail
::
_ps_roi_pool_backward
(
auto
grad_in
=
detail
::
_ps_roi_pool_backward
_symint
(
grad
,
grad
,
rois
,
rois
,
channel_mapping
,
channel_mapping
,
...
@@ -105,8 +105,8 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autograd(
...
@@ -105,8 +105,8 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autograd(
const
at
::
Tensor
&
input
,
const
at
::
Tensor
&
input
,
const
at
::
Tensor
&
rois
,
const
at
::
Tensor
&
rois
,
double
spatial_scale
,
double
spatial_scale
,
int64_
t
pooled_height
,
c10
::
SymIn
t
pooled_height
,
int64_
t
pooled_width
)
{
c10
::
SymIn
t
pooled_width
)
{
auto
result
=
PSROIPoolFunction
::
apply
(
auto
result
=
PSROIPoolFunction
::
apply
(
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
);
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
);
...
@@ -118,12 +118,12 @@ at::Tensor ps_roi_pool_backward_autograd(
...
@@ -118,12 +118,12 @@ at::Tensor ps_roi_pool_backward_autograd(
const
at
::
Tensor
&
rois
,
const
at
::
Tensor
&
rois
,
const
at
::
Tensor
&
channel_mapping
,
const
at
::
Tensor
&
channel_mapping
,
double
spatial_scale
,
double
spatial_scale
,
int64_
t
pooled_height
,
c10
::
SymIn
t
pooled_height
,
int64_
t
pooled_width
,
c10
::
SymIn
t
pooled_width
,
int64_
t
batch_size
,
c10
::
SymIn
t
batch_size
,
int64_
t
channels
,
c10
::
SymIn
t
channels
,
int64_
t
height
,
c10
::
SymIn
t
height
,
int64_
t
width
)
{
c10
::
SymIn
t
width
)
{
return
PSROIPoolBackwardFunction
::
apply
(
return
PSROIPoolBackwardFunction
::
apply
(
grad
,
grad
,
rois
,
rois
,
...
...
torchvision/csrc/ops/ps_roi_pool.cpp
View file @
668348ed
...
@@ -20,6 +20,19 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
...
@@ -20,6 +20,19 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
return
op
.
call
(
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
);
return
op
.
call
(
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
);
}
}
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
ps_roi_pool_symint
(
const
at
::
Tensor
&
input
,
const
at
::
Tensor
&
rois
,
double
spatial_scale
,
c10
::
SymInt
pooled_height
,
c10
::
SymInt
pooled_width
)
{
C10_LOG_API_USAGE_ONCE
(
"torchvision.csrc.ops.ps_roi_pool.ps_roi_pool"
);
static
auto
op
=
c10
::
Dispatcher
::
singleton
()
.
findSchemaOrThrow
(
"torchvision::ps_roi_pool"
,
""
)
.
typed
<
decltype
(
ps_roi_pool_symint
)
>
();
return
op
.
call
(
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
);
}
namespace
detail
{
namespace
detail
{
at
::
Tensor
_ps_roi_pool_backward
(
at
::
Tensor
_ps_roi_pool_backward
(
...
@@ -50,13 +63,41 @@ at::Tensor _ps_roi_pool_backward(
...
@@ -50,13 +63,41 @@ at::Tensor _ps_roi_pool_backward(
width
);
width
);
}
}
at
::
Tensor
_ps_roi_pool_backward_symint
(
const
at
::
Tensor
&
grad
,
const
at
::
Tensor
&
rois
,
const
at
::
Tensor
&
channel_mapping
,
double
spatial_scale
,
c10
::
SymInt
pooled_height
,
c10
::
SymInt
pooled_width
,
c10
::
SymInt
batch_size
,
c10
::
SymInt
channels
,
c10
::
SymInt
height
,
c10
::
SymInt
width
)
{
static
auto
op
=
c10
::
Dispatcher
::
singleton
()
.
findSchemaOrThrow
(
"torchvision::_ps_roi_pool_backward"
,
""
)
.
typed
<
decltype
(
_ps_roi_pool_backward_symint
)
>
();
return
op
.
call
(
grad
,
rois
,
channel_mapping
,
spatial_scale
,
pooled_height
,
pooled_width
,
batch_size
,
channels
,
height
,
width
);
}
}
// namespace detail
}
// namespace detail
TORCH_LIBRARY_FRAGMENT
(
torchvision
,
m
)
{
TORCH_LIBRARY_FRAGMENT
(
torchvision
,
m
)
{
m
.
def
(
TORCH_SELECTIVE_SCHEMA
(
m
.
def
(
TORCH_SELECTIVE_SCHEMA
(
"torchvision::ps_roi_pool(Tensor input, Tensor rois, float spatial_scale,
i
nt pooled_height,
i
nt pooled_width) -> (Tensor, Tensor)"
));
"torchvision::ps_roi_pool(Tensor input, Tensor rois, float spatial_scale,
SymI
nt pooled_height,
SymI
nt pooled_width) -> (Tensor, Tensor)"
));
m
.
def
(
TORCH_SELECTIVE_SCHEMA
(
m
.
def
(
TORCH_SELECTIVE_SCHEMA
(
"torchvision::_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale,
i
nt pooled_height,
i
nt pooled_width,
i
nt batch_size,
i
nt channels,
i
nt height,
i
nt width) -> Tensor"
));
"torchvision::_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale,
SymI
nt pooled_height,
SymI
nt pooled_width,
SymI
nt batch_size,
SymI
nt channels,
SymI
nt height,
SymI
nt width) -> Tensor"
));
}
}
}
// namespace ops
}
// namespace ops
...
...
torchvision/csrc/ops/ps_roi_pool.h
View file @
668348ed
...
@@ -13,6 +13,13 @@ VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
...
@@ -13,6 +13,13 @@ VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
int64_t
pooled_height
,
int64_t
pooled_height
,
int64_t
pooled_width
);
int64_t
pooled_width
);
VISION_API
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
ps_roi_pool_symint
(
const
at
::
Tensor
&
input
,
const
at
::
Tensor
&
rois
,
double
spatial_scale
,
c10
::
SymInt
pooled_height
,
c10
::
SymInt
pooled_width
);
namespace
detail
{
namespace
detail
{
at
::
Tensor
_ps_roi_pool_backward
(
at
::
Tensor
_ps_roi_pool_backward
(
...
@@ -27,6 +34,18 @@ at::Tensor _ps_roi_pool_backward(
...
@@ -27,6 +34,18 @@ at::Tensor _ps_roi_pool_backward(
int64_t
height
,
int64_t
height
,
int64_t
width
);
int64_t
width
);
at
::
Tensor
_ps_roi_pool_backward_symint
(
const
at
::
Tensor
&
grad
,
const
at
::
Tensor
&
rois
,
const
at
::
Tensor
&
channel_mapping
,
double
spatial_scale
,
c10
::
SymInt
pooled_height
,
c10
::
SymInt
pooled_width
,
c10
::
SymInt
batch_size
,
c10
::
SymInt
channels
,
c10
::
SymInt
height
,
c10
::
SymInt
width
);
}
// namespace detail
}
// namespace detail
}
// namespace ops
}
// namespace ops
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment