// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <algorithm>
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>

#include "ck/ck.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/tensor_operation_instance/gpu/softmax.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"

namespace ck {
namespace profiler {

enum struct SoftmaxDataType
{
    F32_F32, // in, out
    F16_F16,
    BF16_BF16,
    INT8_INT8,
};

// clang-format off
template <typename SoftmaxDataType> std::string type_to_string();
template <> std::string type_to_string<float>()   { return "f32"; }
template <> std::string type_to_string<half_t>()  { return "f16"; }
template <> std::string type_to_string<bhalf_t>() { return "bf16"; }
template <> std::string type_to_string<int8_t>()  { return "int8"; }
template <> std::string type_to_string<int32_t>() { return "int32"; }
// clang-format on

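// Profile all registered DeviceSoftmax instances for the problem described by in_length,
// in_strides and reduce_dims. Each supported instance is run on device memory, optionally
// verified against the CPU reference softmax (alpha-scaled softmax blended with the beta-scaled
// prior output values), and timed so that the fastest instance can be reported. Returns true if
// no instance failed verification.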
template <typename InDataType, typename AccDataType, typename OutDataType, index_t Rank>
bool profile_softmax_impl(int do_verification,
                          int init_method,
                          bool do_log,
                          bool time_kernel,
                          std::vector<index_t> in_length,
                          std::vector<index_t> in_strides,
                          std::vector<index_t> reduce_dims,
                          AccDataType alpha,
                          AccDataType beta)
{
    if(Rank != in_length.size())
    {
        throw std::runtime_error("Input tensor rank is different from template argument Rank!");
    }

    Tensor<InDataType> in = in_strides.empty() ? Tensor<InDataType>(in_length)
                                               : Tensor<InDataType>(in_length, in_strides);
    Tensor<OutDataType> out(in.mDesc);
    Tensor<OutDataType> prior_out(in.mDesc);

    switch(init_method)
    {
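    // init_method 0: leave tensors uninitialized; 1: fill with small integer values (exact
    // comparisons); anything else: fill with uniform random reals.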
    case 0: break;
    case 1:
        ck::utils::FillUniformDistributionIntegerValue<InDataType>{-5.f, 5.f}(in.begin(), in.end());
        ck::utils::FillUniformDistributionIntegerValue<OutDataType>{-5.f, 5.f}(prior_out.begin(),
                                                                               prior_out.end());
        break;
    default:
        ck::utils::FillUniformDistribution<InDataType>{0.0f, 1.0f}(in);
        ck::utils::FillUniformDistribution<OutDataType>{-0.5f, 0.5f}(prior_out);
    }

    Tensor<OutDataType> out_ref(prior_out);

    if(do_verification)
    {
        using ReferenceSoftmax =
            tensor_operation::host::ReferenceSoftmax<InDataType, OutDataType, AccDataType>;
        ReferenceSoftmax{}.MakeInvoker().Run({in, out_ref, alpha, beta, reduce_dims});
    }

    DeviceMem in_dev(in.GetElementSpaceSizeInBytes());
    DeviceMem out_dev(out.GetElementSpaceSizeInBytes());
    in_dev.ToDevice(in.data());

    std::vector<index_t> in_tensor_lengths(in.GetLengths().begin(), in.GetLengths().end());
    std::vector<index_t> in_tensor_strides(in.GetStrides().begin(), in.GetStrides().end());

    // add device softmax instances
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using DeviceOp    = tensor_operation::device::
        DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>;

    // get device op instances
    const auto instances = tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();
    std::cout << "found " << instances.size() << " instances" << std::endl;

    if(instances.size() <= 0)
    {
        throw std::runtime_error("wrong! no device softmax instance found");
    }

    std::string best_instance_name;
    float best_avg_time   = std::numeric_limits<float>::max();
    float best_gb_per_sec = 0;
    std::vector<bool> instance_pass;

    for(auto& inst_ptr : instances)
    {
        // Is it the user's responsibility (beyond invoking IsSupportedArgument()) to check whether
        // the problem matches the kernel instance, e.g. a rank-3 problem against a rank-4 kernel?
        if(!(inst_ptr->GetNumReduceDim() == static_cast<index_t>(reduce_dims.size())))
        {
            continue;
        }

        auto argument_ptr = inst_ptr->MakeArgumentPointer(in_tensor_lengths,
                                                          in_tensor_strides,
                                                          reduce_dims,
                                                          &alpha,
                                                          &beta,
                                                          in_dev.GetDeviceBuffer(),
                                                          out_dev.GetDeviceBuffer(),
                                                          PassThrough{},
                                                          PassThrough{});

        if(!inst_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: ";
            LogRange(std::cout << "input lengths = [", in_length, ", ")
                << "], "
                << "scalar = [" << alpha << ", " << beta << "]";
            LogRange(std::cout << ", reduce dims = [", reduce_dims, ", ") << "]." << std::endl;
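            // A skipped instance is not a failure; record it as passing so it does not affect the
            // overall result.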
            instance_pass.push_back(true);
            continue;
        }

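        // Restore the prior output on the device before each run so that the beta-scaled blend
        // starts from the same state for every instance.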
        out_dev.ToDevice(prior_out.data());
        auto invoker_ptr = inst_ptr->MakeInvokerPointer();
        float avg_time   = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

        if(time_kernel)
        {
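            // Bandwidth model: the input is read once; the output is written once and, when
            // beta != 0, the prior output is read as well.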
            std::size_t num_bytes =
                in.GetElementSize() * sizeof(InDataType) +
                (beta == 0.0f ? 1 : 2) * out.GetElementSize() * sizeof(OutDataType);
            float gb_per_sec = num_bytes / 1.E6 / avg_time;

            std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, "
                      << inst_ptr->GetTypeString() << std::endl;

            if(avg_time < best_avg_time)
            {
                best_instance_name = inst_ptr->GetTypeString();
                best_avg_time      = avg_time;
                best_gb_per_sec    = gb_per_sec;
            }
        }

        if(do_verification)
        {
            out_dev.FromDevice(out.data());
            bool pass = true;
            if(std::is_same<InDataType, int8_t>::value)
            {
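                // Integer outputs: tolerate an absolute difference of 1 to absorb rounding of the
                // quantized result.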
                pass = pass && ck::utils::check_err(
                                   out.mData, out_ref.mData, "Error: Incorrect results!", 0, 1);
                if(do_log)
                {
                    LogRangeAsType<int>(std::cout << "in  : ", in.mData, ",") << std::endl;
                    LogRangeAsType<int>(std::cout << "out_ref  : ", out_ref.mData, ",")
                        << std::endl;
                    LogRangeAsType<int>(std::cout << "out  : ", out.mData, ",") << std::endl;
                }
            }
            else
            {
                pass = pass && ck::utils::check_err(out.mData, out_ref.mData);
                if(do_log)
                {
                    LogRangeAsType<float>(std::cout << "in  : ", in.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "out_ref  : ", out_ref.mData, ",")
                        << std::endl;
                    LogRangeAsType<float>(std::cout << "out  : ", out.mData, ",") << std::endl;
                }
            }

            if(!pass)
            {
                std::cout << inst_ptr->GetTypeString() << " failed verification: ";
                LogRange(std::cout << "input lengths = [", in_length, ", ")
                    << "], "
                    << "scalar = [" << alpha << ", " << beta << "]." << std::endl;
            }
            instance_pass.push_back(pass);
        }
    }
    if(time_kernel)
    {
        std::cout << "Best Perf for datatype = " << type_to_string<InDataType>() << "_"
                  << type_to_string<OutDataType>() << ", ";
        LogRange(std::cout << "length = ", in_tensor_lengths, ",") << ", ";
        LogRange(std::cout << "stride = ", in_tensor_strides, ",") << ", ";
        LogRange(std::cout << "reduce dims ", reduce_dims, ",") << ", ";
        std::cout << "alpha = " << alpha << ", "
                  << "beta = " << beta << ", " << best_avg_time << " ms, " << best_gb_per_sec
                  << " GB/s, " << best_instance_name << std::endl;
    }
    return std::all_of(
        std::begin(instance_pass), std::end(instance_pass), [](bool p) { return p; });
}

} // namespace profiler
} // namespace ck
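
// Illustrative usage (a sketch; the driver code and the problem sizes below are assumptions, not
// part of this header):
//
//   std::vector<ck::index_t> lengths{8, 128, 1024};
//   std::vector<ck::index_t> strides{};   // empty -> packed layout
//   std::vector<ck::index_t> reduce_dims{2};
//
//   bool all_pass = ck::profiler::profile_softmax_impl<ck::half_t, float, ck::half_t, 3>(
//       /*do_verification=*/1,
//       /*init_method=*/1,
//       /*do_log=*/false,
//       /*time_kernel=*/true,
//       lengths,
//       strides,
//       reduce_dims,
//       /*alpha=*/1.f,
//       /*beta=*/0.f);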