// py_itfs_common.h
// Committed by Xiaowei.zhang.
// SPDX-License-Identifier: MIT
 

#pragma once
#include <torch/all.h>
#include "aiter_hip_common.h"
// FIXME: use isGPUArch in torch2.6, not support torch_fp8 = Float8_e4m3fnuz for now!!!!
//const auto torch_fp8 = false? at::ScalarType::Float8_e4m3fnuz : at::ScalarType::Float8_e4m3fn;
const constexpr auto torch_fp8 = at::ScalarType::Float8_e4m3fn;

// clang-format off
// Type trait mapping an element type T to its ck_tile counterpart, exposed as
// t2ck<T>::type. The primary template is only declared, never defined, so asking
// for t2ck<T>::type with a T that has no specialization below is a compile error.
template <typename T> struct t2ck;
// One specialization per supported element type.
template <> struct t2ck<float> { using type = ck_tile::fp32_t; };
template <> struct t2ck<c10::Half> { using type = ck_tile::fp16_t; };
template <> struct t2ck<c10::BFloat16> { using type = ck_tile::bf16_t; };
template <> struct t2ck<int32_t> { using type = ck_tile::index_t; };
template <> struct t2ck<int8_t> { using type = ck_tile::int8_t; };
// clang-format on

// common utility functions

// X-macro listing the (short name string, torch dtype constant) pairs known to
// this header. F is invoked once per pair as F("name", torch::kXxx). Nothing is
// emitted between invocations, so F must expand to a self-contained construct
// (for example a complete `case` label with its body).
// NOTE(review): the "fp8" entry names torch::kFloat8_e4m3fn directly instead of
// torch_fp8; keep the two in sync if torch_fp8 ever changes.
#define FOREACH_BUFFER_TORCH_TYPE_MAP(F) \
    F("fp32", torch::kFloat)             \
    F("fp16", torch::kHalf)              \
    F("bf16", torch::kBFloat16)          \
    F("int32", torch::kInt32)            \
    F("int8", torch::kInt8)              \
    F("uint8", torch::kUInt8)            \
    F("fp8", torch::kFloat8_e4m3fn)

// Translates a torch dtype into its short name string ("fp32", "fp16", ...).
//
// The dtype is converted with toScalarType() and looked up among the entries of
// FOREACH_BUFFER_TORCH_TYPE_MAP; the name paired with the matching entry is returned.
// Throws std::runtime_error when no entry matches; the message ends with the
// scalar type's numeric value.
inline std::string torchDTypeToStr(caffe2::TypeMeta dtype)
{
    const auto scalar_type = dtype.toScalarType();

// Turns one map entry into a `case` that returns the entry's name.
#define TYPE_CASE(type, torch_type) \
    case torch_type: return type;

    switch (scalar_type)
    {
        FOREACH_BUFFER_TORCH_TYPE_MAP(TYPE_CASE)
    default:
        break;
    }

#undef TYPE_CASE

    // Reached only when the dtype is not listed in the map.
    throw std::runtime_error("CKPyInterface: Unsupported data type " +
                             std::to_string(static_cast<int>(scalar_type)));
}