main.cpp 3.36 KB
Newer Older
Paul Fultz II's avatar
Paul Fultz II committed
1
2
3
4
5
6
7

#include <algorithm>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/stringutils.hpp"

using ck::host::Transform;

struct Emitters
{
arai713's avatar
arai713 committed
15
    // retrieve the hard-coded instances provided, template them, and then store them in a map
Paul Fultz II's avatar
Paul Fultz II committed
16
17
18
    std::unordered_map<std::string, std::function<std::vector<std::string>()>> m;

    template <class T>
arai713's avatar
arai713 committed
19
    void Register(const std::string& name, const std::string& prologue, const std::string& epilogue)
Paul Fultz II's avatar
Paul Fultz II committed
20
    {
arai713's avatar
arai713 committed
21
22
        m[name] = [&] {
            auto configs = T::CreateOperations(prologue, epilogue);
Paul Fultz II's avatar
Paul Fultz II committed
23
24
25
26
27

            return Transform(configs, [](const auto& ops) { return ToTuple(ops); });
        };
    }

arai713's avatar
arai713 committed
28
    // takes in an operation instance and uses it to substitute the correct values into the template
Paul Fultz II's avatar
Paul Fultz II committed
29
30
31
32
33
34
35
36
    template <class T>
    static std::string ToTuple(const T& ops)
    {
        auto templates = Transform(
            ops, [](const auto& op) { return "    " + op.ToSolution().ToTemplateString(); });
        return "std::tuple<\n" + ck::host::JoinStrings(templates, ",\n") + ">";
    }

arai713's avatar
arai713 committed
37
    // Join together all the strings in the map
Paul Fultz II's avatar
Paul Fultz II committed
38
39
40
41
42
43
44
45
46
47
48
49
    std::string Emit(const std::string& name) { return ck::host::JoinStrings(m.at(name)(), "\n"); }

    std::vector<std::string> List() const
    {
        return Transform(m, [](auto&& p) { return p.first; });
    }
};

int main(int argc, const char* argv[])
{
    std::string prog = argv[0];
    std::vector<std::string> args(argv + 1, argv + argc);
arai713's avatar
arai713 committed
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78

    // Specify problem type and problem size
    ck::host::device_gemm_multiple_d::Problem prob;
    prob.M = 1024;
    prob.N = 1024;
    prob.K = 1024;

    // user provided fusion
    std::string prologue = "";
    std::string epilogue = R"(
struct Epilogue
{
    __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};

    template <typename E, typename D>
    __host__ __device__ constexpr void operator()(E& e, const D& d) const;

    template <>
    __host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
                                                                          const ck::half_t& d) const
    {
        e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
    }

    float alpha_;
    float beta_;
};)";

    // Load in operations into the Register
Paul Fultz II's avatar
Paul Fultz II committed
79
80
    Emitters e;
    e.Register<ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle>(
arai713's avatar
arai713 committed
81
        "DeviceGemmMultipleD_Xdl_CShuffle", prologue, epilogue);
Paul Fultz II's avatar
Paul Fultz II committed
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99

    if(args.empty() or std::any_of(args.begin(), args.end(), [](auto arg) {
           return arg == "-h" or arg == "--help";
       }))
    {
        std::cout << "USAGE:" << std::endl;
        std::cout << "    " << prog << " [TEMPLATE]" << std::endl;
        std::cout << std::endl;
        std::cout << "FLAGS:" << std::endl;
        std::cout << "    -h, --help                     Show help" << std::endl;
        std::cout << std::endl;
        std::cout << "TEMPLATES:" << std::endl;
        for(auto x : e.List())
            std::cout << "    " << x << std::endl;
        std::cout << std::endl;
        return 0;
    }

arai713's avatar
arai713 committed
100
    // print out all the instances for the operation that was chosen at the command line
Paul Fultz II's avatar
Paul Fultz II committed
101
102
103
104
105
    for(auto name : args)
        std::cout << e.Emit(name) << std::endl;

    return 0;
}