// SPDX-License-Identifier: MIT #pragma once #include #include "aiter_hip_common.h" // FIXME: use isGPUArch in torch2.6, not support torch_fp8 = Float8_e4m3fnuz for now!!!! //const auto torch_fp8 = false? at::ScalarType::Float8_e4m3fnuz : at::ScalarType::Float8_e4m3fn; const constexpr auto torch_fp8 = at::ScalarType::Float8_e4m3fn; // clang-format off template struct t2ck; template <> struct t2ck { using type = ck_tile::fp32_t; }; template <> struct t2ck { using type = ck_tile::fp16_t; }; template <> struct t2ck { using type = ck_tile::bf16_t; }; template <> struct t2ck { using type = ck_tile::index_t; }; template <> struct t2ck { using type = ck_tile::int8_t; }; // clang-format on // common utility functions #define FOREACH_BUFFER_TORCH_TYPE_MAP(F) \ F("fp32", torch::kFloat) \ F("fp16", torch::kHalf) \ F("bf16", torch::kBFloat16) \ F("int32", torch::kInt32) \ F("int8", torch::kInt8) \ F("uint8", torch::kUInt8) \ F("fp8", torch::kFloat8_e4m3fn) inline std::string torchDTypeToStr(caffe2::TypeMeta dtype) { #define TYPE_CASE(type, torch_type) \ case torch_type: \ { \ return type; \ } switch (dtype.toScalarType()) { FOREACH_BUFFER_TORCH_TYPE_MAP(TYPE_CASE); default: throw std::runtime_error("CKPyInterface: Unsupported data type " + std::to_string((int8_t)(dtype.toScalarType()))); } #undef TYPE_CASE }