#include "../roi_align.h"

// Each preprocessor directive must sit on its own line. Otherwise the rest of
// the line is swallowed by the directive and the code below is never compiled.
#include <ATen/autocast_mode.h> // at::autocast::cached_cast
#include <torch/library.h> // TORCH_LIBRARY_IMPL

namespace vision {
namespace ops {

namespace {

// Autocast wrapper for torchvision::roi_align.
//
// While autocast is active, this kernel runs roi_align on float32 copies of
// `input` and `rois`, then casts the result back to `input`'s original dtype.
//
// Parameters are forwarded unchanged to `roi_align` (declared in
// ../roi_align.h):
//   input          - feature tensor; its scalar type decides the output dtype.
//   rois           - regions of interest tensor.
//   spatial_scale  - forwarded as-is.
//   pooled_height  - forwarded as-is.
//   pooled_width   - forwarded as-is.
//   sampling_ratio - forwarded as-is.
//   aligned        - forwarded as-is.
//
// Returns the roi_align result converted to `input.scalar_type()`.
//
// NOTE(review): presumably the float32 upcast exists because the underlying
// roi_align kernels should not run in reduced precision; confirm against
// the backend kernels.
at::Tensor roi_align_autocast(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio,
    bool aligned) {
  // Exclude the Autocast dispatch key for the rest of this scope. The nested
  // roi_align call then dispatches past this kernel instead of re-entering it.
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);

  // cached_cast only converts tensors autocast treats as eligible. Anything
  // else is passed through untouched.
  return roi_align(
             at::autocast::cached_cast(at::kFloat, input),
             at::autocast::cached_cast(at::kFloat, rois),
             spatial_scale,
             pooled_height,
             pooled_width,
             sampling_ratio,
             aligned)
      // Restore the caller's dtype so the op looks dtype-preserving to the
      // surrounding autocast region.
      .to(input.scalar_type());
}

} // namespace

// Register the wrapper as the Autocast-key implementation of the operator
// "roi_align" in the "torchvision" library.
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
  m.impl("roi_align", roi_align_autocast);
}

} // namespace ops
} // namespace vision