#include "../nms.h"

#include <ATen/autocast_mode.h>
#include <torch/library.h>
#include <torch/types.h>

namespace vision {
namespace ops {

namespace {

// Autocast kernel for the `torchvision::nms` operator.
//
// Under autocast, callers may pass reduced-precision tensors. This wrapper
// runs `dets` and `scores` through at::autocast::cached_cast with target
// dtype at::kFloat (float32), then calls the regular `nms` entry point with
// the cast tensors. `iou_threshold` is passed through unchanged.
//
// Template parameters:
//   autocast_key - dispatch key this instantiation is registered under. It is
//                  excluded for the duration of the call (see the guard
//                  below) so the inner `nms` call does not come back here.
//   device_type  - device type forwarded to at::autocast::cached_cast.
//
// Args:
//   dets          - box tensor, forwarded to `nms` after the cast
//                   (expected layout is not checked here; see nms.h).
//   scores        - score tensor, forwarded to `nms` after the cast.
//   iou_threshold - forwarded to `nms` unchanged.
//
// Returns: the result of `nms` on the cast inputs.
template <c10::DispatchKey autocast_key, c10::DeviceType device_type>
at::Tensor nms_autocast(
    const at::Tensor& dets,
    const at::Tensor& scores,
    double iou_threshold) {
  // Exclude the autocast key while this scope is alive. Otherwise the nested
  // `nms` call would be dispatched straight back to this autocast kernel.
  c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);

  return nms(
      at::autocast::cached_cast(at::kFloat, dets, device_type),
      at::autocast::cached_cast(at::kFloat, scores, device_type),
      iou_threshold);
}

} // namespace

// Registers the autocast kernel under the `Autocast` key, paired with the
// CUDA device type. NOTE(review): in c10, DispatchKey::Autocast is the legacy
// alias of AutocastCUDA; verify against the PyTorch version in use.
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::nms"),
      TORCH_FN(
          (nms_autocast<c10::DispatchKey::Autocast, c10::DeviceType::CUDA>)));
}

// Registers the autocast kernel under the `AutocastCPU` key, paired with the
// CPU device type.
TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::nms"),
      TORCH_FN((nms_autocast<
                c10::DispatchKey::AutocastCPU,
                c10::DeviceType::CPU>)));
}

} // namespace ops
} // namespace vision