"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "2906315c6384903922ca3cbff7de3aeca62c61a3"
Commit 969be85c authored by Paul's avatar Paul
Browse files

Format

parent bfad5a5b
...@@ -45,36 +45,27 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -45,36 +45,27 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return false; return false;
auto a = ins->inputs().front()->get_shape(); auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape(); auto b = ins->inputs().back()->get_shape();
return (a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and return (a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and b.lens()[0] % 8 == 0 and
b.lens()[0] % 8 == 0 and b.lens()[1] % 8 == 0); b.lens()[1] % 8 == 0);
} }
struct find_ck_gemm struct find_ck_gemm
{ {
// Find a convolution followed by a pointwise operation. // Find a convolution followed by a pointwise operation.
auto matcher() const auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
{
return match::name("dot")(is_ck_gemm().bind("gemm"));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
mpm.get_module().replace_instruction(ins, ck_gemm{ins->get_operator()}, ins->inputs()); mpm.get_module().replace_instruction(ins, ck_gemm{ins->get_operator()}, ins->inputs());
} }
}; };
} // namespace } // namespace
void fuse_ck::apply(module_pass_manager& mpm) const { match::find_matches(mpm, find_ck_gemm{}); }
void fuse_ck::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_ck_gemm{});
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -77,11 +77,7 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p) ...@@ -77,11 +77,7 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
)__migraphx__"; )__migraphx__";
std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y - 1) / y; }
std::size_t int_div_ceil(std::size_t x, std::size_t y)
{
return (x + y - 1) / y;
}
std::size_t get_grid_size(std::size_t m, std::size_t mpb, std::size_t n, std::size_t npb) std::size_t get_grid_size(std::size_t m, std::size_t mpb, std::size_t n, std::size_t npb)
{ {
...@@ -100,60 +96,107 @@ namespace fs = std::filesystem; ...@@ -100,60 +96,107 @@ namespace fs = std::filesystem;
struct ck_gemm_compiler : compiler<ck_gemm_compiler> struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{ {
const std::vector<std::string> instances{ const std::vector<std::string> instances{
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8", " CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8", "PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, "
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8", "2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, "
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8", "32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8", " CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8", "PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, "
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8", "2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, "
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8", "32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8", " CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8", "PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, "
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8", "2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, "
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8", "32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8", " CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8", "PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, "
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8", "2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, "
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8"}; "32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
const std::vector<block_settings> params { "PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, "
{256, 256, 128}, "2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, "
{256, 256, 128}, "32, 1, 8>, 8",
{256, 128, 256}, " CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
{256, 128, 256}, "PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, "
{256, 128, 128}, "2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, "
{256, 128, 128}, "32, 1, 8>, 8",
{256, 128, 64}, " CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
{256, 128, 64}, "PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, "
{256, 64, 128}, "2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, "
{256, 64, 128}, "32, 1, 8>, 8",
{128, 128, 128}, " CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
{128, 128, 128}, "PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, "
{128, 128, 64}, "2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, "
{128, 128, 64}, "32, 1, 8>, 8",
{128, 64, 128}, " CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
{128, 64, 128}}; "PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, "
"32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, "
"32, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, "
"16, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, "
"16, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, "
"32, 1, 4>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, "
"32, 1, 4>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, "
"16, 1, 8>, 8",
" CKDeviceGemm< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, "
"PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, "
"2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, "
"16, 1, 8>, 8"};
const std::vector<block_settings> params{{256, 256, 128},
{256, 256, 128},
{256, 128, 256},
{256, 128, 256},
{256, 128, 128},
{256, 128, 128},
{256, 128, 64},
{256, 128, 64},
{256, 64, 128},
{256, 64, 128},
{128, 128, 128},
{128, 128, 128},
{128, 128, 64},
{128, 128, 64},
{128, 64, 128},
{128, 64, 128}};
std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; } std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
{ {
int i = 4; int i = 4;
if (contains(v, "tuning_val")) if(contains(v, "tuning_val"))
i = v.at("tuning_val").to<int>(); i = v.at("tuning_val").to<int>();
assert(i >= 0 and i < instances.size()); assert(i >= 0 and i < instances.size());
hip_compile_options options; hip_compile_options options;
auto out_s = inputs.back(); auto out_s = inputs.back();
auto b_s = params[i]; auto b_s = params[i];
auto block_size = b_s.bs; auto block_size = b_s.bs;
auto m_per_block = b_s.mpb; auto m_per_block = b_s.mpb;
auto n_per_block = b_s.npb; auto n_per_block = b_s.npb;
auto m = out_s.lens().front(); auto m = out_s.lens().front();
auto n = out_s.lens().back(); auto n = out_s.lens().back();
auto grid_size = get_grid_size(m, m_per_block, n, n_per_block); auto grid_size = get_grid_size(m, m_per_block, n, n_per_block);
options.set_launch_params(v, grid_size * block_size, block_size); options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs; options.inputs = inputs;
...@@ -161,18 +204,19 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -161,18 +204,19 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
options.kernel_name = "ck_gemm_kernel"; options.kernel_name = "ck_gemm_kernel";
options.virtual_inputs = inputs; options.virtual_inputs = inputs;
auto k = inputs.front().lens().back(); auto k = inputs.front().lens().back();
auto sa = inputs.front().strides().front(); auto sa = inputs.front().strides().front();
auto sb = inputs.at(1).strides().front(); auto sb = inputs.at(1).strides().front();
auto sc = inputs.back().strides().front(); auto sc = inputs.back().strides().front();
auto src = interpolate_string(ck_gemm_kernel, {{"instance", instances[i]}, auto src = interpolate_string(ck_gemm_kernel,
{"m", to_string(m)}, {{"instance", instances[i]},
{"k", to_string(k)}, {"m", to_string(m)},
{"n", to_string(n)}, {"k", to_string(k)},
{"sa", to_string(sa)}, {"n", to_string(n)},
{"sb", to_string(sb)}, {"sa", to_string(sa)},
{"sc", to_string(sc)}}); {"sb", to_string(sb)},
{"sc", to_string(sc)}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
...@@ -36,20 +36,20 @@ template <class G, class T, class U, class V, class W> ...@@ -36,20 +36,20 @@ template <class G, class T, class U, class V, class W>
__device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, W& p_t) __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, W& p_t)
{ {
constexpr G ckdg{}; constexpr G ckdg{};
using GridwiseGemm = decltype(ckdg.gridwisegemm); using GridwiseGemm = decltype(ckdg.gridwisegemm);
constexpr auto a_grid_desc_ak0_m_ak1 = ckdg.MakeAGridDescriptor_AK0_M_AK1(); constexpr auto a_grid_desc_ak0_m_ak1 = ckdg.MakeAGridDescriptor_AK0_M_AK1();
constexpr auto b_grid_desc_bk0_n_bk1 = ckdg.MakeBGridDescriptor_BK0_N_BK1(); constexpr auto b_grid_desc_bk0_n_bk1 = ckdg.MakeBGridDescriptor_BK0_N_BK1();
constexpr auto c_grid_desc_m_n = ckdg.MakeCGridDescriptor_M_N(); constexpr auto c_grid_desc_m_n = ckdg.MakeCGridDescriptor_M_N();
constexpr auto block_2_ctile_map = ckdg.MakeDefaultBlock2CTileMap(c_grid_desc_m_n); constexpr auto block_2_ctile_map = ckdg.MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
// static_assert(GridwiseGemm::CheckValidity( // static_assert(GridwiseGemm::CheckValidity(
// a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_m_n, block_2_ctile_map)); // a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_m_n, block_2_ctile_map));
constexpr auto c_grid_desc_mblock_mperblock_nblock_nperblock = constexpr auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
constexpr auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); constexpr auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
constexpr auto a_element_op = ckdg.a_element_op; constexpr auto a_element_op = ckdg.a_element_op;
constexpr auto b_element_op = ckdg.b_element_op; constexpr auto b_element_op = ckdg.b_element_op;
constexpr auto c_element_op = ckdg.c_element_op; constexpr auto c_element_op = ckdg.c_element_op;
......
...@@ -62,7 +62,6 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa ...@@ -62,7 +62,6 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
template <ck::index_t MPerBlock, ck::index_t NPerBlock, typename CGridDesc_M_N> template <ck::index_t MPerBlock, ck::index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt struct BlockToCTileMap_M00_N0_M01Adapt
{ {
...@@ -73,8 +72,9 @@ struct BlockToCTileMap_M00_N0_M01Adapt ...@@ -73,8 +72,9 @@ struct BlockToCTileMap_M00_N0_M01Adapt
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default; __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, __host__
ck::index_t M01 = 8) __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
ck::index_t M01 = 8)
: M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n) : M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
{ {
} }
...@@ -115,12 +115,13 @@ struct BlockToCTileMap_M00_N0_M01Adapt ...@@ -115,12 +115,13 @@ struct BlockToCTileMap_M00_N0_M01Adapt
template <typename CTileIdx, typename CTileDim> template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool constexpr ValidCTileIndex(const CTileIdx& /* c_tile_idx */, __host__ __device__ bool constexpr ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const const CTileDim& /* c_tile_dim */) const
{ {
return true; // always valid provided that user gets grid size from CalculateGridSize() return true; // always valid provided that user gets grid size from CalculateGridSize()
} }
__host__ __device__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const __host__ __device__ constexpr bool
CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{ {
return true; return true;
} }
...@@ -177,13 +178,11 @@ template <typename ALayout, ...@@ -177,13 +178,11 @@ template <typename ALayout,
ck::index_t StrideA, ck::index_t StrideA,
ck::index_t StrideB, ck::index_t StrideB,
ck::index_t StrideC, ck::index_t StrideC,
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler() ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
> struct CKDeviceGemm
struct CKDeviceGemm
{ {
//template<ck::index_t MRaw, ck::index_t KRaw, ck::index_t StrideA> // template<ck::index_t MRaw, ck::index_t KRaw, ck::index_t StrideA>
static constexpr auto static constexpr auto MakeAGridDescriptor_AK0_M_AK1()
MakeAGridDescriptor_AK0_M_AK1()
{ {
const auto a_grid_desc_mraw_kraw = [&]() { const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(ck::is_same_v<ck::tensor_layout::gemm::RowMajor, ALayout>) if constexpr(ck::is_same_v<ck::tensor_layout::gemm::RowMajor, ALayout>)
...@@ -287,9 +286,8 @@ struct CKDeviceGemm ...@@ -287,9 +286,8 @@ struct CKDeviceGemm
} }
} }
//template<ck::index_t KRaw, ck::index_t NRaw, ck::index_t StrideB> // template<ck::index_t KRaw, ck::index_t NRaw, ck::index_t StrideB>
static constexpr auto static constexpr auto MakeBGridDescriptor_BK0_N_BK1()
MakeBGridDescriptor_BK0_N_BK1()
{ {
const auto b_grid_desc_nraw_kraw = [&]() { const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
...@@ -393,9 +391,8 @@ struct CKDeviceGemm ...@@ -393,9 +391,8 @@ struct CKDeviceGemm
} }
} }
//template<ck::index_t MRaw, ck::index_t NRaw, ck::index_t StrideC> // template<ck::index_t MRaw, ck::index_t NRaw, ck::index_t StrideC>
static constexpr auto static constexpr auto MakeCGridDescriptor_M_N()
MakeCGridDescriptor_M_N()
{ {
const auto c_grid_desc_mraw_nraw = [&]() { const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value) if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value)
...@@ -463,7 +460,7 @@ struct CKDeviceGemm ...@@ -463,7 +460,7 @@ struct CKDeviceGemm
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1()); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1());
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N()); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N());
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
...@@ -515,7 +512,7 @@ struct CKDeviceGemm ...@@ -515,7 +512,7 @@ struct CKDeviceGemm
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
GridwiseGemm gridwisegemm{}; GridwiseGemm gridwisegemm{};
AElementwiseOperation a_element_op{}; AElementwiseOperation a_element_op{};
BElementwiseOperation b_element_op{}; BElementwiseOperation b_element_op{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment