Commit 028bb4b6 authored by Paul's avatar Paul
Browse files

Use the right template

parent e5080ac5
...@@ -26,9 +26,9 @@ const std::string gemm_compile_check = R"__ck__( ...@@ -26,9 +26,9 @@ const std::string gemm_compile_check = R"__ck__(
extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) { extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) {
using G = ${template}; using G = ${template};
constexpr auto desc = ${template}::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})), constexpr auto desc = ${template}::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})),
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${n, ${k})), ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${n}, ${k})),
ck::make_tuple(), ck::make_tuple(),
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m, ${n}))); ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n})));
static_assert(desc.IsValid(), "Invalid ck gemm."); static_assert(desc.IsValid(), "Invalid ck gemm.");
...@@ -49,7 +49,7 @@ TEST_CASE(test_problem_kernel) ...@@ -49,7 +49,7 @@ TEST_CASE(test_problem_kernel)
prob.K = 256; prob.K = 256;
for(auto solution : prob.GetSolutions("gfx90a")) for(auto solution : prob.GetSolutions("gfx90a"))
{ {
auto src = ck::host::InterpolateString(compile_check, auto src = ck::host::InterpolateString(gemm_compile_check,
{{"include", prob.GetIncludeHeader()}, {{"include", prob.GetIncludeHeader()},
{"template", solution.ToTemplateString()}, {"template", solution.ToTemplateString()},
{"m", std::to_string(prob.M)}, {"m", std::to_string(prob.M)},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment