utils.hpp 593 Bytes
Newer Older
1
2
#pragma once

3
#include <torch/extension.h>
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#include "cutlass/numeric_types.h"
#include "helper.h"

/// Type trait mapping an element type to the element type CUTLASS expects.
///
/// This primary template is the identity fallback: any `Element` without an
/// explicit specialization of `cutlass_dtype` maps to itself
/// (e.g. `cutlass_dtype<float>::type` is `float`).
template <typename Element>
struct cutlass_dtype { using type = Element; };

/// `half` maps to CUTLASS's 16-bit float element type `cutlass::half_t`.
template <> struct cutlass_dtype<half> { using type = cutlass::half_t; };

/// `nv_bfloat16` maps to CUTLASS's bfloat16 element type `cutlass::bfloat16_t`.
template <> struct cutlass_dtype<nv_bfloat16> { using type = cutlass::bfloat16_t; };

/// `__nv_fp8_e4m3` maps to CUTLASS's `cutlass::float_e4m3_t`.
/// NOTE(review): assumes the FP8 header declaring `__nv_fp8_e4m3` is
/// reachable through the includes above (e.g. via helper.h) — confirm.
template <> struct cutlass_dtype<__nv_fp8_e4m3> { using type = cutlass::float_e4m3_t; };

/// `__nv_fp8_e5m2` maps to CUTLASS's `cutlass::float_e5m2_t`.
/// NOTE(review): assumes the FP8 header declaring `__nv_fp8_e5m2` is
/// reachable through the includes above (e.g. via helper.h) — confirm.
template <> struct cutlass_dtype<__nv_fp8_e5m2> { using type = cutlass::float_e5m2_t; };

template <typename T>
Zeyu WANG's avatar
Zeyu WANG committed
33
using cutlass_dtype_t = typename cutlass_dtype<T>::type;