"sims/net/vscode:/vscode.git/clone" did not exist on "1b258eee5c6f27f7553e9afdb92302d0911d06c6"
profile_transpose.cpp 2.56 KB
Newer Older
arai713's avatar
arai713 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdio>
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

#include "profiler/profile_transpose_impl.hpp"
#include "profiler_operation_registry.hpp"

/// Element-type combinations selectable through arg2 of the transpose profiler.
/// The underlying integer is exactly what the user types on the command line.
enum struct DataType
{
    F32_F32_F32_F32_F32 = 0, ///< float input and float output
    F16_F16_F16_F16_F16 = 1, ///< ck::half_t input and ck::half_t output
};

// Name and human-readable description of this profiler operation. They are
// spliced into the usage text and passed to REGISTER_PROFILER_OPERATION. They are
// kept as macros (not constexpr) so they can be concatenated with adjacent string
// literals at compile time.
#define OP_NAME "transpose"
#define OP_DESC "Transpose"

/// Prints the positional-argument usage of the transpose profiler to stdout.
///
/// Called when the command line has the wrong number of arguments or cannot
/// be parsed. arg7..arg11 are the five tensor lengths, given in N, C, D, H, W order.
static void print_helper_msg()
{
    printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
    printf("arg2: data type (0: fp32; 1: fp16)\n");
    printf("arg3: verification (0: no; 1: yes)\n");
    printf("arg4: initialization (0: no init; 1: integer value; 2: decimal value)\n");
    printf("arg5: print tensor value (0: no; 1: yes)\n");
    printf("arg6: time kernel (0=no, 1=yes)\n");
    printf("arg7 to arg11: N, C, D, H, W\n");
}

int profile_transpose(int argc, char* argv[])
{
34
    if(argc != 12)
arai713's avatar
arai713 committed
35
36
37
38
39
    {
        print_helper_msg();
        exit(1);
    }

40
41
42
43
44
45
46
47
48
49
    const auto data_type                   = static_cast<DataType>(std::stoi(argv[2]));
    const bool do_verification             = std::stoi(argv[3]);
    const int init_method                  = std::stoi(argv[4]);
    const bool do_log                      = std::stoi(argv[5]);
    const bool time_kernel                 = std::stoi(argv[6]);
    const std::vector<ck::index_t> lengths = {std::stoi(argv[7]),
                                              std::stoi(argv[8]),
                                              std::stoi(argv[9]),
                                              std::stoi(argv[10]),
                                              std::stoi(argv[11])};
arai713's avatar
arai713 committed
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81

    using F32 = float;
    using F16 = ck::half_t;

    auto profile = [&](auto a_type, auto b_type) {
        using ADataType              = decltype(a_type);
        using BDataType              = decltype(b_type);
        constexpr ck::index_t NumDim = 5;

        bool pass = ck::profiler::profile_transpose_impl<ADataType, BDataType, NumDim>(
            do_verification, init_method, do_log, time_kernel, lengths);

        return pass ? 0 : 1;
    };

    if(data_type == DataType::F32_F32_F32_F32_F32)
    {
        return profile(F32{}, F32{});
    }
    else if(data_type == DataType::F16_F16_F16_F16_F16)
    {
        return profile(F16{}, F16{});
    }
    else
    {
        std::cout << "this data_type & layout is not implemented" << std::endl;

        return 1;
    }
}

// Register profile_transpose with the profiler's operation registry under
// OP_NAME / OP_DESC.
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_transpose);