// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_AMD_WMMA_HPP
#define CK_AMD_WMMA_HPP

#include "data_type.hpp"
// TODO: Add an architecture limitation so these intrinsics are only used on supported targets
namespace ck {

// wave32 only
// src: fp16, dst: fp32
// Primary template is declared but never defined: only the tile shape
// specialized below can be instantiated; any other <MPerWave, NPerWave>
// names an incomplete type and fails at compile time.
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32;

// 16x16 tile.
template <>
struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
{
    // Calls __builtin_amdgcn_wmma_f32_16x16x16_f16_w32 with reg_a, reg_b and the
    // current value of reg_c's first float8_t element (third operand), then
    // writes the builtin's result back into that same element.
    //
    // FloatC must provide `AsType<float8_t>()` whose result can be read with
    // operator[](Number<0>) and assigned through operator()(Number<0>).
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        // Take the incoming value through the read accessor first.
        const float8_t acc_in = reg_c.template AsType<float8_t>()[Number<0>{}];

        // Run the builtin and store its result in place.
        reg_c.template AsType<float8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(reg_a, reg_b, acc_in);
    }
};

// src: bf16, dst: fp32
// Primary template is declared but never defined: only the tile shape
// specialized below can be instantiated; any other <MPerWave, NPerWave>
// names an incomplete type and fails at compile time.
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32;

// 16x16 tile.
template <>
struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
{
    // Calls __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32 with reg_a, reg_b and
    // the current value of reg_c's first float8_t element (third operand), then
    // writes the builtin's result back into that same element.
    //
    // FloatC must provide `AsType<float8_t>()` whose result can be read with
    // operator[](Number<0>) and assigned through operator()(Number<0>).
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
        // Take the incoming value through the read accessor first.
        const float8_t acc_in = reg_c.template AsType<float8_t>()[Number<0>{}];

        // Run the builtin and store its result in place.
        reg_c.template AsType<float8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(reg_a, reg_b, acc_in);
    }
};

// src: fp16, dst: fp16
44
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
aska-0096's avatar
aska-0096 committed
45
46
struct intrin_wmma_f16_16x16x16_f16_w32;

47
48
template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
aska-0096's avatar
aska-0096 committed
49
50
{
    template <class FloatC>
51
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
aska-0096's avatar
aska-0096 committed
52
53
54
55
56
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
        reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
57
            reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
aska-0096's avatar
aska-0096 committed
58
59
60
    }
};

61
62
// src: bf16, dst: bf16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
aska-0096's avatar
aska-0096 committed
63
64
struct intrin_wmma_bf16_16x16x16_bf16_w32;

65
66
template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
aska-0096's avatar
aska-0096 committed
67
68
{
    template <class FloatC>
69
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
aska-0096's avatar
aska-0096 committed
70
71
72
73
74
75
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
        reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
76
                reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
aska-0096's avatar
aska-0096 committed
77
78
79
80
    }
};

// src: iu8, dst: i32
81
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
aska-0096's avatar
aska-0096 committed
82
83
struct intrin_wmma_i32_16x16x16_iu8_w32;

84
85
template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
aska-0096's avatar
aska-0096 committed
86
87
{
    template <class FloatC>
88
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
aska-0096's avatar
aska-0096 committed
89
90
91
92
93
94
95
96
97
98
99
100
101
102
    {
        reg_c.template AsType<int32x8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
                neg_a,
                bit_cast<int32x4_t>(reg_a),
                neg_b,
                bit_cast<int32x4_t>(reg_b),
                reg_c.template AsType<int32x8_t>()[Number<0>{}],
                clamp);
    }
};

} // namespace ck
#endif