CUDATypeConversion.cuh 534 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
#pragma once

#include <ATen/Half.h>
#include <ATen/cuda/CUDAHalf.cuh>

// Type traits to convert types to CUDA-specific types. Used primarily to
// convert at::Half to CUDA's half type. This makes the conversion explicit.

// These live in apex::cuda (not at::cuda) to avoid clashing with any similarly named definitions in ATen.
namespace apex {
namespace cuda {

// Primary trait: maps a type to the type CUDA code should use for it.
// Unless specialized below, every type maps to itself unchanged.
template <typename T>
struct TypeConversion {
  typedef T type;
};

// at::Half is mapped to CUDA's `half` type. Only the exact type
// at::Half is matched; cv-qualified variants fall back to the identity
// mapping of the primary template.
template <>
struct TypeConversion<at::Half> {
  typedef half type;
};

// Convenience alias: apex::cuda::type<T> names TypeConversion<T>::type.
template <typename T>
using type = typename TypeConversion<T>::type;

}  // namespace cuda
}  // namespace apex