// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/inner_product_dpp8.hpp"

namespace ck {

namespace dpp8 {

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
template <class ABDataType>
struct dpp_datatypes;

template <>
struct dpp_datatypes<half_t>
{
    // Dot product of `half2_t` and `half2_t` to get `float`. Reducing 2 elements from K in a
    // single instruction.
    using a_dtype                        = half_t;
    using b_dtype                        = half_t;
    using c_dtype                        = float;
    static constexpr index_t k_per_instr = 2;
};

template <index_t MPerThread,
          index_t NPerThread,
          index_t KPerThread,
          class BaseInputType,
          class AVecDataType,
          class BVecDataType,
          class CVecDataType,
35
          bool ShareA>
36
struct DppInstrRunner
37
{
38
39
40
41
    using datatypes_conf = dpp_datatypes<BaseInputType>;
    using ADataType      = typename datatypes_conf::a_dtype;
    using BDataType      = typename datatypes_conf::b_dtype;
    using CDataType      = typename datatypes_conf::c_dtype;
42
43
44
45
46
47
48
49
50
51
52
53

    __device__ void Run(const AVecDataType& a_vec, const BVecDataType& b_vec, CVecDataType& c_vec)
    {
        constexpr index_t num_c_elems_per_thread = ShareA ? MPerThread : NPerThread;

        const vector_type<ADataType, KPerThread> a_vector{a_vec};
        const vector_type<BDataType, KPerThread> b_vector{b_vec};

        static_for<0, num_c_elems_per_thread, 1>{}([&](auto c_idx) {
            float c = c_vec.template AsType<CDataType>()(c_idx);
            // Next `c_idx` implies that we need to pull data from the next lane.
            constexpr index_t source_lane = c_idx;
54
            static_for<0, KPerThread / datatypes_conf::k_per_instr, 1>{}([&](auto k_chunk) {
55
56
57
58
59
60
61
                const auto a_k_vec = a_vector.template AsType<AVecDataType>()[k_chunk];
                const auto b_k_vec = b_vector.template AsType<BVecDataType>()[k_chunk];
                ck::dpp8::
                    inner_product_dpp<AVecDataType, BVecDataType, CDataType, source_lane, ShareA>(
                        a_k_vec, b_k_vec, c);
            });
            c_vec.template AsType<CDataType>()(c_idx) = c;
62
        });
63
64
    }
};
65
66
67
68

} // namespace dpp8

} // namespace ck