// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <algorithm>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/stringutils.hpp"

using ck::host::Transform;

struct Emitters
{
arai713's avatar
arai713 committed
17
    // retrieve the hard-coded instances provided, template them, and then store them in a map
Paul Fultz II's avatar
Paul Fultz II committed
18
19
20
    std::unordered_map<std::string, std::function<std::vector<std::string>()>> m;

    template <class T>
arai713's avatar
arai713 committed
21
    void Register(const std::string& name, const std::string& prologue, const std::string& epilogue)
Paul Fultz II's avatar
Paul Fultz II committed
22
    {
arai713's avatar
arai713 committed
23
24
        m[name] = [&] {
            auto configs = T::CreateOperations(prologue, epilogue);
Paul Fultz II's avatar
Paul Fultz II committed
25
26
27
28
29

            return Transform(configs, [](const auto& ops) { return ToTuple(ops); });
        };
    }

arai713's avatar
arai713 committed
30
    // takes in an operation instance and uses it to substitute the correct values into the template
Paul Fultz II's avatar
Paul Fultz II committed
31
32
33
34
35
36
37
38
    template <class T>
    static std::string ToTuple(const T& ops)
    {
        auto templates = Transform(
            ops, [](const auto& op) { return "    " + op.ToSolution().ToTemplateString(); });
        return "std::tuple<\n" + ck::host::JoinStrings(templates, ",\n") + ">";
    }

arai713's avatar
arai713 committed
39
    // Join together all the strings in the map
Paul Fultz II's avatar
Paul Fultz II committed
40
41
42
43
44
45
46
47
48
49
50
51
    std::string Emit(const std::string& name) { return ck::host::JoinStrings(m.at(name)(), "\n"); }

    std::vector<std::string> List() const
    {
        return Transform(m, [](auto&& p) { return p.first; });
    }
};

int main(int argc, const char* argv[])
{
    std::string prog = argv[0];
    std::vector<std::string> args(argv + 1, argv + argc);
arai713's avatar
arai713 committed
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80

    // Specify problem type and problem size
    ck::host::device_gemm_multiple_d::Problem prob;
    prob.M = 1024;
    prob.N = 1024;
    prob.K = 1024;

    // user provided fusion
    std::string prologue = "";
    std::string epilogue = R"(
struct Epilogue
{
    __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};

    template <typename E, typename D>
    __host__ __device__ constexpr void operator()(E& e, const D& d) const;

    template <>
    __host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
                                                                          const ck::half_t& d) const
    {
        e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
    }

    float alpha_;
    float beta_;
};)";

    // Load in operations into the Register
Paul Fultz II's avatar
Paul Fultz II committed
81
82
    Emitters e;
    e.Register<ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle>(
arai713's avatar
arai713 committed
83
        "DeviceGemmMultipleD_Xdl_CShuffle", prologue, epilogue);
Paul Fultz II's avatar
Paul Fultz II committed
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101

    if(args.empty() or std::any_of(args.begin(), args.end(), [](auto arg) {
           return arg == "-h" or arg == "--help";
       }))
    {
        std::cout << "USAGE:" << std::endl;
        std::cout << "    " << prog << " [TEMPLATE]" << std::endl;
        std::cout << std::endl;
        std::cout << "FLAGS:" << std::endl;
        std::cout << "    -h, --help                     Show help" << std::endl;
        std::cout << std::endl;
        std::cout << "TEMPLATES:" << std::endl;
        for(auto x : e.List())
            std::cout << "    " << x << std::endl;
        std::cout << std::endl;
        return 0;
    }

arai713's avatar
arai713 committed
102
    // print out all the instances for the operation that was chosen at the command line
Paul Fultz II's avatar
Paul Fultz II committed
103
104
105
106
107
    for(auto name : args)
        std::cout << e.Emit(name) << std::endl;

    return 0;
}