#pragma once #include #include "cutlass/numeric_types.h" #include "helper.h" template struct cutlass_dtype { using type = T; }; template <> struct cutlass_dtype { using type = cutlass::half_t; }; template <> struct cutlass_dtype { using type = cutlass::bfloat16_t; }; template <> struct cutlass_dtype<__nv_fp8_e4m3> { using type = cutlass::float_e4m3_t; }; template <> struct cutlass_dtype<__nv_fp8_e5m2> { using type = cutlass::float_e5m2_t; }; template using cutlass_dtype_t = typename cutlass_dtype::type;