common.hpp 2.3 KB
Newer Older
Po-Yen, Chen's avatar
Po-Yen, Chen committed
1
2
3
4
5
6
7
8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <cstddef>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <iterator>
#include <numeric>
#include <string>
Po-Yen, Chen's avatar
Po-Yen, Chen committed
11
12
13
14
15
16
17
18
19
20

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

21
22
// Short alias for ck::half_t, the element type name used by this example.
using F16 = ck::half_t;

Po-Yen, Chen's avatar
Po-Yen, Chen committed
23
24
25
26
27
28
29
30
31
32
33
34
// Run-time switches for the example; both can be overridden from the command line
// (arg1 = verification, arg2 = time kernel).
struct ExecutionConfig final
{
    // Verification flag; on unless the caller turns it off.
    bool do_verification{true};

    // Kernel-timing flag; off unless the caller turns it on.
    bool time_kernel{false};
};

// Description of the permute problem: the shape of the 4D tensor and the axes to permute.
// The defaults are used when no problem arguments are given on the command line.
struct Problem final
{
    // Shape of the 4D tensor.
    std::array<std::size_t, 4> shape{{4, 16, 32, 32}};

    // Axes to permute.
    std::array<std::size_t, 4> axes{{0, 2, 3, 1}};
};

35
36
37
38
39
// Shorthand for ck::Sequence: S<1, 2, 3> names ck::Sequence<1, 2, 3>.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Short alias for the PassThrough element-wise operation provided by CK.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

Po-Yen, Chen's avatar
Po-Yen, Chen committed
40
41
inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Problem& problem)
{
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
    constexpr int num_execution_config_args = 2;
    constexpr int num_problem_args          = 8;

    assert(num_problem_args == problem.shape.size() + problem.axes.size());

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 1 + num_execution_config_args)
    {
        config.do_verification = std::stoi(argv[1]);
        config.time_kernel     = std::stoi(argv[2]);
    }
    else if(argc == 1 + num_execution_config_args + num_problem_args)
    {
        config.do_verification = std::stoi(argv[1]);
        config.time_kernel     = std::stoi(argv[2]);

        // read shape
        for(std::size_t idx = 0; idx < problem.shape.size(); ++idx)
        {
            problem.shape[idx] = std::stoi(argv[idx + 3]);
        }

        // read axes
        for(std::size_t idx = 0; idx < problem.axes.size(); ++idx)
        {
            problem.axes[idx] = std::stoi(argv[idx + problem.shape.size() + 3]);
        }
    }
    else
    {
        std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
                  << "arg2: time kernel (0=no, 1=yes)" << std::endl
                  << "arg3 ~ arg6: shape for 4D tensor" << std::endl
                  << "arg7 ~ arg10: axes to permute" << std::endl;
        return false;
    }

Po-Yen, Chen's avatar
Po-Yen, Chen committed
82
83
    return true;
}