Commit d5e056c7 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Add description for profiler operations

parent 8116d2b3
...@@ -166,4 +166,4 @@ int profile_softmax(int argc, char* argv[]) ...@@ -166,4 +166,4 @@ int profile_softmax(int argc, char* argv[])
// return 0; // return 0;
// } // }
REGISTER_PROFILER_OPERATION("softmax", profile_softmax) REGISTER_PROFILER_OPERATION("softmax", "Softmax", profile_softmax);
...@@ -8,27 +8,7 @@ ...@@ -8,27 +8,7 @@
static void print_helper_message() static void print_helper_message()
{ {
// clang-format off std::cout << "arg1: tensor operation " << ProfilerOperationRegistry::GetInstance() << std::endl;
printf("arg1: tensor operation (gemm: GEMM\n"
" gemm_splitk: Split-K GEMM\n"
" gemm_bilinear: GEMM+Bilinear\n"
" gemm_add_add_fastgelu: GEMM+Add+Add+FastGeLU\n"
" gemm_reduce: GEMM+Reduce\n"
" gemm_bias_add_reduce: GEMM+Bias+Add+Reduce\n"
" batched_gemm: Batched GEMM\n"
" batched_gemm_gemm: Batched+GEMM+GEMM\n"
" batched_gemm_add_relu_gemm_add: Batched+GEMM+bias+gelu+GEMM+bias\n"
" batched_gemm_reduce: Batched GEMM+Reduce\n"
" grouped_gemm: Grouped GEMM\n"
" conv_fwd: Convolution Forward\n"
" conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n"
" conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n"
" conv_bwd_data: Convolution Backward Data\n"
" grouped_conv_fwd: Grouped Convolution Forward\n"
" grouped_conv_bwd_weight: Grouped Convolution Backward Weight\n"
" softmax: Softmax\n"
" reduce: Reduce\n");
// clang-format on
} }
int main(int argc, char* argv[]) int main(int argc, char* argv[])
...@@ -37,10 +17,13 @@ int main(int argc, char* argv[]) ...@@ -37,10 +17,13 @@ int main(int argc, char* argv[])
{ {
print_helper_message(); print_helper_message();
} }
else if(auto operation = ProfilerOperationRegistry::GetInstance().Get(argv[1]); operation.has_value()) else if(const auto operation = ProfilerOperationRegistry::GetInstance().Get(argv[1]);
operation.has_value())
{ {
return (*operation)(argc, argv); return (*operation)(argc, argv);
} else { }
else
{
std::cerr << "cannot find operation: " << argv[1] << std::endl; std::cerr << "cannot find operation: " << argv[1] << std::endl;
return EXIT_FAILURE; return EXIT_FAILURE;
} }
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <functional> #include <functional>
#include <iterator>
#include <map> #include <map>
#include <optional> #include <optional>
#include <string_view> #include <string_view>
...@@ -14,7 +15,18 @@ class ProfilerOperationRegistry final ...@@ -14,7 +15,18 @@ class ProfilerOperationRegistry final
using Operation = std::function<int(int, char*[])>; using Operation = std::function<int(int, char*[])>;
private: private:
std::unordered_map<std::string_view, Operation> operations_; struct Entry final
{
explicit Entry(std::string_view description, Operation operation) noexcept
: description_(description), operation_(operation)
{
}
std::string_view description_;
Operation operation_;
};
std::map<std::string_view, Entry> entries_;
public: public:
static ProfilerOperationRegistry& GetInstance() static ProfilerOperationRegistry& GetInstance()
...@@ -25,22 +37,37 @@ class ProfilerOperationRegistry final ...@@ -25,22 +37,37 @@ class ProfilerOperationRegistry final
std::optional<Operation> Get(std::string_view name) const std::optional<Operation> Get(std::string_view name) const
{ {
const auto found = operations_.find(name); const auto found = entries_.find(name);
if(found == end(operations_)) if(found == end(entries_))
{ {
return std::nullopt; return std::nullopt;
} }
return found->second; return (found->second).operation_;
} }
bool Add(std::string_view name, Operation operation) bool Add(std::string_view name, std::string_view description, Operation operation)
{ {
return operations_.try_emplace(name, std::move(operation)).second; return entries_
.emplace(std::piecewise_construct,
std::forward_as_tuple(name),
std::forward_as_tuple(description, std::move(operation)))
.second;
} }
};
#define REGISTER_PROFILER_OPERATION(name, operation) \ friend std::ostream& operator<<(std::ostream& stream, const ProfilerOperationRegistry& registry)
namespace { \ {
const bool result = ::ProfilerOperationRegistry::GetInstance().Add(name, operation); \ stream << "{\n";
for(auto& [name, entry] : registry.entries_)
{
stream << "\t" << name << ": " << entry.description_ << "\n";
}
stream << "}";
return stream;
} }
};
#define REGISTER_PROFILER_OPERATION(name, description, operation) \
static const bool result = \
::ProfilerOperationRegistry::GetInstance().Add(name, description, operation)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment