Commit 90db83e7 authored by Paul
Browse files

Format

parent 8e20f747
......@@ -9,7 +9,7 @@
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/common.hpp"
#include "ck/host/types.hpp"
namespace ck {
namespace host {
......@@ -31,6 +31,10 @@ struct Problem
std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string CDEElementOp = "ck::Tuple<>";
std::string GetIncludeHeader() const;
std::vector<Solution> GetSolutions(const std::string& arch) const;
};
} // namespace device_gemm_multiple_d
......
......@@ -28,6 +28,10 @@ struct Problem
std::string AElementOp = PassThrough;
std::string BElementOp = PassThrough;
std::string CDEElementOp = "ck::Tuple<>";
std::string GetIncludeHeader() const;
std::vector<Solution> GetSolutions(const std::string& arch) const;
};
} // namespace device_gemm_multiple_d
......
......@@ -4,11 +4,14 @@
#pragma once
#include <cstdint>
#include <unordered_set>
namespace ck {
namespace host {
// Ceiling division for unsigned sizes: evaluates (x + y - 1) / y.
// y must be non-zero, since the result is obtained by dividing by y.
std::size_t integer_divide_ceil(std::size_t x, std::size_t y);
// Returns a reference to the fixed set of GPU architecture names
// ("gfx90a", "gfx908", "gfx940", "gfx942") kept by the host library.
// NOTE(review): presumably these are the targets with XDL support, going by
// the function name -- confirm against the callers.
const std::unordered_set<std::string>& get_xdlop_archs();
} // namespace host
} // namespace ck
......@@ -11,5 +11,11 @@ std::size_t integer_divide_ceil(std::size_t x, std::size_t y)
return (x + y - std::size_t{1}) / y;
}
// Gives read-only access to the list of architecture names this library
// records as xdlop archs. The set is built once, on the first call, and the
// same object is handed back on every subsequent call.
const std::unordered_set<std::string>& get_xdlop_archs()
{
    static const std::unordered_set<std::string> xdlop_arch_names = {
        "gfx90a",
        "gfx908",
        "gfx940",
        "gfx942",
    };
    return xdlop_arch_names;
}
} // namespace host
} // namespace ck
......@@ -7,15 +7,6 @@
#include <test.hpp>
#include <rtc/compile_kernel.hpp>
// Kernel source template for a compile-only check of a generated instance.
// Placeholders: ${include} is the header placed in the #include directive and
// ${template} is the type whose nested DeviceOp alias is named in the kernel
// body. The kernel does nothing else; it only has to compile.
const std::string compile_check = R"__ck__(
#include <${include}>
extern "C" __global__ void f() {
using type = ${template}::DeviceOp;
}
)__ck__";
std::vector<rtc::src_file> get_headers_for_test()
{
std::vector<rtc::src_file> result;
......@@ -29,20 +20,41 @@ std::vector<rtc::src_file> get_headers_for_test()
return result;
}
TEST_CASE(test_operation)
// Kernel source template used to compile-check a generated GEMM instance.
//
// Placeholders (each written as ${name} and substituted before compilation):
//   ${include}  - header placed in the #include directive
//   ${template} - the instance type providing make_descriptor() and Run()
//   ${m}, ${n}, ${k} - problem dimensions used to build the packed descriptors
//                      with shapes (m, k), (n, k) and (m, n)
//
// The kernel builds a descriptor, static_asserts that it is valid and then
// calls Run with the three pointers, so an invalid configuration fails at
// compile time.
//
// Every placeholder must be individually closed: "${n, ${k}" would leave "n"
// (and likewise "m") unsubstituted and produce malformed kernel source.
const std::string gemm_compile_check = R"__ck__(
#include <${include}>
extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, const ck::half_t* c) {
using G = ${template};
constexpr auto desc = ${template}::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})),
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${n}, ${k})),
ck::make_tuple(),
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n})));
static_assert(desc.IsValid(), "Invalid ck gemm.");
${template}::Run(desc,
a,
b,
ck::make_tuple(),
c);
}
)__ck__";
TEST_CASE(test_problem_kernel)
{
ck::host::device_gemm_multiple_d::Problem prob;
prob.M = 256;
prob.N = 256;
prob.K = 256;
auto ops = ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle::CreateOperations(prob);
for(auto op : ops)
for(auto solution : prob.GetSolutions("gfx90a"))
{
auto solution = op.ToSolution();
std::string include =
"ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp";
auto src = ck::host::InterpolateString(
compile_check, {{"include", include}, {"template", solution.ToTemplateString()}});
auto src = ck::host::InterpolateString(compile_check,
{{"include", prob.GetIncludeHeader()},
{"template", solution.ToTemplateString()},
{"m", std::to_string(prob.M)},
{"n", std::to_string(prob.N)},
{"k", std::to_string(prob.K)}});
auto srcs = get_headers_for_test();
srcs.push_back({"main.cpp", src});
rtc::compile_options options;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment