#pragma once #include #include // Type traits to convert types to CUDA-specific types. Used primarily to // convert at::Half to CUDA's half type. This makes the conversion explicit. // Disambiguate from whatever is in aten namespace apex { namespace cuda { template struct TypeConversion { using type = T; }; template <> struct TypeConversion { using type = half; }; template using type = typename TypeConversion::type; }} // namespace apex::cuda