Unverified Commit 190a5f8a authored by mcarilli's avatar mcarilli Committed by GitHub
Browse files

compiles (#2646)

parent bb88c452
...@@ -49,8 +49,8 @@ at::Tensor ROIAlign_autocast( ...@@ -49,8 +49,8 @@ at::Tensor ROIAlign_autocast(
const bool aligned) { const bool aligned) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return roi_align( return roi_align(
autocast::_cast(at::kFloat, input), at::autocast::cached_cast(at::kFloat, input),
autocast::_cast(at::kFloat, rois), at::autocast::cached_cast(at::kFloat, rois),
spatial_scale, spatial_scale,
pooled_height, pooled_height,
pooled_width, pooled_width,
......
#pragma once #pragma once
#if defined(WITH_CUDA) || defined(WITH_HIP) #if defined(WITH_CUDA) || defined(WITH_HIP)
namespace autocast { #include <ATen/autocast_mode.h>
inline bool is_eligible(const at::Tensor& arg) {
return (
arg.is_cuda() && arg.is_floating_point() &&
(arg.scalar_type() != at::kDouble));
}
// Overload to catch Tensor args
inline at::Tensor _cast(at::ScalarType to_type, const at::Tensor& arg) {
if (is_eligible(arg) && (arg.scalar_type() != to_type)) {
return arg.to(to_type);
} else {
return arg;
}
}
// Template to catch non-Tensor args
template <typename T>
inline T _cast(at::ScalarType to_type, T arg) {
return arg;
}
} // namespace autocast
#endif #endif
...@@ -28,8 +28,8 @@ at::Tensor nms_autocast( ...@@ -28,8 +28,8 @@ at::Tensor nms_autocast(
const double iou_threshold) { const double iou_threshold) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return nms( return nms(
autocast::_cast(at::kFloat, dets), at::autocast::cached_cast(at::kFloat, dets),
autocast::_cast(at::kFloat, scores), at::autocast::cached_cast(at::kFloat, scores),
iou_threshold); iou_threshold);
} }
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment