Commit 0a8342b8 authored by charlie's avatar charlie
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_shape_update

parents b31735e8 f55d7c24
......@@ -3,6 +3,7 @@
#include <migraphx/make_op.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/program.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/gpu/kernel.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
......@@ -109,6 +110,24 @@ int main() {}
)__migraphx__";
// NOLINTNEXTLINE
const std::string math_template = R"__migraphx__(
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/math.hpp>
extern "C" {
__global__ void kernel(${type}* p)
{
auto x = *p;
*p = migraphx::implicit_conversion(migraphx::${invoke});
}
}
int main() {}
)__migraphx__";
migraphx::src_file make_src_file(const std::string& name, const std::string& content)
{
return {name, std::make_pair(content.data(), content.data() + content.size())};
......@@ -248,4 +267,66 @@ TEST_CASE(compile_pointwise)
EXPECT(result == output_literal.get_argument());
}
TEST_CASE(compile_math)
{
std::vector<std::string> math_invoke = {
// clang-format off
"abs(x)",
"acos(x)",
"acosh(x)",
"asin(x)",
"asinh(x)",
"atan(x)",
"atanh(x)",
"ceil(x)",
"cos(x)",
"cosh(x)",
"erf(x)",
"exp(x)",
"floor(x)",
"isnan(x)",
"log(x)",
"max(x, x)",
"min(x, x)",
"pow(x, 0)",
"pow(x, x)",
"round(x)",
"rsqrt(x)",
"sin(x)",
"sinh(x)",
"sqrt(x)",
"tan(x)",
"tanh(x)",
"where(true, x, x)",
// clang-format on
};
std::vector<std::string> data_types;
auto vec_sizes = {2, 4, 6};
for(auto&& t : migraphx::shape::types())
{
if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
continue;
auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type)
name.insert(0, "migraphx::");
data_types.push_back(name);
migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) {
return "migraphx::vec<" + name + ", " + std::to_string(i) + ">";
});
}
migraphx::shape input{migraphx::shape::float_type, {5, 2}};
migraphx::gpu::hip_compile_options options;
options.global = 1024;
options.local = 1024;
options.inputs = {input};
options.output = input;
migraphx::par_for(math_invoke.size() * data_types.size(), 1, [&](auto i) {
const auto& t = data_types[i % data_types.size()];
const auto& invoke = math_invoke[i / data_types.size()];
auto src = migraphx::interpolate_string(math_template, {{"type", t}, {"invoke", invoke}});
auto co = migraphx::gpu::compile_hip_code_object(src, options);
(void)co;
});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment