"src/turbomind/vscode:/vscode.git/clone" did not exist on "0632735591e5b7c34802a29e453e4ad6b7eb248f"
Unverified Commit 3e9711f0 authored by arai713's avatar arai713 Committed by GitHub
Browse files

CK Instance Gen (#1145)



* Format

* Format

* Format

* Remove const

* Use the right template

* Format

* Format

* add row/col instances

* Add missing file

* fixed

* fixing block to etile error

* Format

* Updates

* Format

* fixed rrr layout

* generating a sample JSON file: currently contains includes, prologue/epilogue and instances

* version where the json is passed into the instances to generate a key

* updated run function to just launch kernel

* updated run function: only contains kernel object, json file is updated but still needs to be cleaned up, added front-end API to parse JSON into character buffer

* adding in testing files

* cleaned up comments, still need to work on including header files

* removed unneeded files

* removed/commented out JSON implementation

* added fusion(prologue/epilogue) into instance generation

* working on instance selection

* added instance selection, need to fix instance validation

* removed block2etile map validity check for testing purposes

* test running: failing due to incorrect files/input

* all grid descs/ptrs completed, but device file not found

* Update test and embed modules

* Restore older version

* added convolution operation, written test, debugging generated code for compilation

* attempting to include CK in host directory: _Float16 error

* CK header file issues

* slight fix

* don't crash when hip can't report total memory

* dump generated code to a file

* changing sizes

* creating tensor descriptors using CK methods: set up grid desc manually, also trying to set up an argument pointer - this needs to be fixed

* some fixes to call the device code

* separating test files for conv and gemm

* completed arg ptr, now have linking errors

* clang format fix

* resolved linker issues in conv test

* remove dependency on libutility from ck

* resolved num dim error

* properly passing arg ptr, errors with passing typenames: redefinition/redeclaration

* undo the commenting of device function

* hand created kernel code to find rtc issues

* dump the full src to file

* resolved redeclaration errors, cleaned up errors for Amber's kernel code

* debugging purposes: redeclaration error

* config files

* resolved errors for NumTensor and redeclaration, formatted version.h

* resolved most errors in manually added kernel and my own. error with calling kernel object: overloaded function type

* WIP: close to getting kernel compiled

* WIP: fixing rtc errors

* fixed sequence errors, formatting, still one error with run fcn

* yay: kernel compiles and runs

* updated templated/generated version to run and compile

* minor fixes

* working generated example, resolved memory access error due to padding

* adding in reference kernel, validation failing against reference

* debugging: printing kernel args

* reduced error in results

* debugged reference kernel and output errors, added to generated version, currently debugging prologue function issues

* working validation (using reference convolution) with prologue function for both hard-coded and generated version

* WIP: create an alt version that creates Argument on the device

* wip: added new duplicate files, fixed fusion templating errors from working example, setting up kernel arguments

* wip: making necessary methods device code

* added grid descs, working on grid pointers, errors with stl numerics

* wip: updating kernel args - issue, replacing some std functions

* replaced std::accumulate call with temp hardcoded version

* wip: args causing memory issue

* Construct Argument object inside the kernel and use it to call convolution device function. Code runs and verification passes

* adding object file dump

* temporary hardcoding of grid size, can remove device op inst + arg ptr

* minor fix for grid size

* added modified example where arg ptr is created on the device for generated version as well

* removed device op instance and arg ptr from modified examples

* moving device op file for testing purposes and to properly build CK

* commenting out print-outs

* adjust compiler args to produce a valid ELF file

* temporary removal of validation

* reverting compiler args back for working example

* retrieve necessary arguments from generated template parameters in correct format

* calculating grid size on host-side, still need to clean up process, pass parameters to host functions properly

* scaled up factory functions/wrapper structs to implement host-side launch parameter calculations using CK host side functions - in hard-coded example

* temporary change to generate ELF format binary object file

* removed unnecessary code, added comments

* formatting fix

* cleaned up code, added new tests, restructured library: move helper into CK

* refactored launch parameter calculation to be more concise

* renamed files and variables for more clarity/uniformity

* more code cleaning, removed debug statements

* moved majority of my files into codegen directory, running properly

* updated Embed.cmake(string_view) in codegen directory

* updated host directory to match Embed.cmake as well

* added old tests in

* updated instance generation methods to be more concise

* removed layout from launch parameter calculation

* working test

* fixed issue with verification, all instances working

* updated verification in other tests

* removed duplicate matrix padder file, removed code dumps

* removed old hard-coded tests

* removed old host directory, all files in codegen directory now

* fixed copyright in files

* commenting out validation

* renamed files

* made changes for review: fixed copyright, renamed files for clarity, removed comments, refactored code

* updated headers

* removing duplicate file for fwd conv to gemm, merging with original file

* fix building codegen with clang++ directly

* resolving build error from conv_fwd_to_gemm

* fix for previous error

* renaming tests

* created common test file

* cleaned up code, added comments

* renamed device op

* fixed typos in comments

* removed extra space

* code cleanup: resolving Amber's comments

* removed wrapper struct for matrix padder, fixed template

* cleaned up if statements for better readability

---------
Co-authored-by: default avatarPaul <pfultz2@yahoo.com>
Co-authored-by: default avatarJing Zhang <jizha@amd.com>
Co-authored-by: default avatarM. Amber Hassaan <amber_474@yahoo.com>
Co-authored-by: default avatarillsilin <Illia.Silin@amd.com>
Co-authored-by: default avatarIllia Silin <98187287+illsilin@users.noreply.github.com>
parent cb138394
#pragma once
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <random>
#include <string>
#include <unordered_set>
#include <vector>
#include <test.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
/// Converts every entry returned by ck::host::GetHeaders() into an
/// rtc::src_file.
///
/// Each entry is treated as a pair; its `first` and `second` members are
/// forwarded, in that order, to the rtc::src_file initializer.
///
/// Declared `inline` because the definition lives in a header: without it,
/// including this header from more than one translation unit of the same
/// binary yields a multiple-definition link error.
///
/// @return one rtc::src_file per entry of ck::host::GetHeaders().
inline std::vector<rtc::src_file> get_headers_for_test()
{
    std::vector<rtc::src_file> result;
    auto hs = ck::host::GetHeaders();
    std::transform(
        hs.begin(), hs.end(), std::back_inserter(result), [](const auto& p) -> rtc::src_file {
            return {p.first, p.second};
        });
    return result;
}
/// Computes the number of elements spanned by a strided tensor description:
/// 1 + sum over i of (mLens[i] - 1) * mStrides[i].
///
/// Dimensions whose length is 0 are skipped and contribute nothing.
///
/// @tparam V        array-like type exposing Size() and operator[].
/// @param mLens     extent of each dimension.
/// @param mStrides  stride of each dimension, indexed like mLens.
/// @return the accumulated element count (at least 1).
template <typename V>
std::size_t GetSize(V mLens, V mStrides)
{
    std::size_t space = 1;
    for(std::size_t i = 0; i < mLens.Size(); ++i)
    {
        if(mLens[i] == 0)
            continue;
        // Widen both factors before multiplying: with a 32-bit element type
        // the product of a large extent and a large stride would otherwise
        // overflow before being added to the size_t accumulator.
        space += static_cast<std::size_t>(mLens[i] - 1) * static_cast<std::size_t>(mStrides[i]);
    }
    return space;
}
/// Creates a buffer holding GetSize(mLens, mStrides) elements and fills it
/// with pseudo-random values.
///
/// Values are drawn as doubles from a uniform distribution over [-1.0, 1.0)
/// using a std::mt19937 engine seeded with `seed`, then converted to T on
/// assignment, so a given seed always yields the same contents.
///
/// @tparam T        element type of the returned buffer.
/// @tparam V        array-like type accepted by GetSize.
/// @param mLens     extent of each dimension.
/// @param mStrides  stride of each dimension.
/// @param seed      seed for the random engine (defaults to 0).
/// @return the filled buffer.
template <class T, typename V>
rtc::buffer<T> generate_buffer(V mLens, V mStrides, std::size_t seed = 0)
{
    rtc::buffer<T> values(GetSize(mLens, mStrides));
    std::mt19937 engine(seed);
    // Upper bound spelled out; it equals the distribution's default of 1.0.
    std::uniform_real_distribution<double> sample(-1.0, 1.0);
    for(auto& element : values)
    {
        element = sample(engine);
    }
    return values;
}
/// Element-wise approximate comparison of two ranges.
///
/// Two elements x (from a) and y (from b) match when they compare equal, or
/// when |x - y| < atol + rtol * |y|. Ranges of different length never match.
///
/// The exact-equality short-circuit makes identical values match even when
/// both tolerances are zero, and makes equal infinities match (their
/// difference is NaN, which fails every `<` comparison). NaN still never
/// matches anything, including another NaN.
///
/// @param a     first range; elements must convert to double.
/// @param b     second range (the reference for the relative tolerance).
/// @param atol  absolute tolerance.
/// @param rtol  relative tolerance, scaled by |y|.
/// @return true when both ranges have the same length and every pair matches.
template <class T, class U>
bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01)
{
    return std::equal(a.begin(), a.end(), b.begin(), b.end(), [&](double x, double y) {
        if(x == y)
            return true;
        return std::fabs(x - y) < atol + rtol * std::fabs(y);
    });
}
/// Returns the name of the floating-point category of `x` as reported by
/// std::fpclassify: "inf", "nan", "normal", "subnormal" or "zero"; any other
/// category yields "unknown".
///
/// Declared `inline` because the definition lives in a header: without it,
/// including this header from more than one translation unit of the same
/// binary yields a multiple-definition link error.
inline std::string classify(double x)
{
    switch(std::fpclassify(x))
    {
    case FP_INFINITE: return "inf";
    case FP_NAN: return "nan";
    case FP_NORMAL: return "normal";
    case FP_SUBNORMAL: return "subnormal";
    case FP_ZERO: return "zero";
    default: return "unknown";
    }
}
/// Prints, on one line of std::cout, each distinct floating-point category
/// (as named by classify) that occurs among the elements of `x`.
///
/// Categories are collected in an unordered set, so their printed order is
/// whatever the set's iteration order happens to be. Every category is
/// followed by ", " and the line ends with std::endl.
template <class Buffer>
void print_classification(const Buffer& x)
{
    std::unordered_set<std::string> categories;
    for(auto it = x.begin(); it != x.end(); ++it)
    {
        categories.insert(classify(*it));
    }
    for(auto it = categories.begin(); it != categories.end(); ++it)
    {
        std::cout << *it << ", ";
    }
    std::cout << std::endl;
}
/// Prints the minimum, maximum, mean and population standard deviation of the
/// elements of `x` on a single line of std::cout.
///
/// Mean and standard deviation are accumulated in double precision; the
/// standard deviation divides by the element count (not count - 1).
///
/// An empty buffer has no extrema and an undefined mean, so it is reported
/// with "n/a" placeholders instead of dereferencing the end iterator and
/// dividing by zero.
template <class Buffer>
void print_statistics(const Buffer& x)
{
    if(x.size() == 0)
    {
        std::cout << "Min value: n/a, Max value: n/a, Mean: n/a, StdDev: n/a\n";
        return;
    }
    std::cout << "Min value: " << *std::min_element(x.begin(), x.end()) << ", ";
    std::cout << "Max value: " << *std::max_element(x.begin(), x.end()) << ", ";
    const double num_elements = x.size();
    const double mean =
        std::accumulate(x.begin(), x.end(), double{0.0}, std::plus<double>{}) / num_elements;
    // Sum of squared deviations from the mean, squared by plain multiplication.
    const double squared_deviation =
        std::accumulate(x.begin(), x.end(), double{0.0}, [&](double r, double v) {
            const double diff = v - mean;
            return r + diff * diff;
        });
    const double stddev = std::sqrt(squared_deviation / num_elements);
    std::cout << "Mean: " << mean << ", ";
    std::cout << "StdDev: " << stddev << "\n";
}
/// Prints a short preview of `x` on one line of std::cout.
///
/// Buffers of at most 10 elements are printed in full. Longer buffers show
/// their first five elements, the marker "..., ", and their last five
/// elements. Each element is converted to double before printing and is
/// followed by ", "; the line ends with std::endl.
template <class Buffer>
void print_preview(const Buffer& x)
{
    // Writes every element of [first, last) as a double followed by ", ".
    const auto emit = [](auto first, auto last) {
        for(; first != last; ++first)
        {
            const double value = *first;
            std::cout << value << ", ";
        }
    };
    if(x.size() > 10)
    {
        emit(x.begin(), x.begin() + 5);
        std::cout << "..., ";
        emit(x.end() - 5, x.end());
    }
    else
    {
        emit(x.begin(), x.end());
    }
    std::cout << std::endl;
}
/// Stateful checker that remembers one buffer and compares later buffers
/// against it.
///
/// While `data` is empty, a call stores its argument in `data` and reports
/// success. Any other call returns allclose(data, x) with the default
/// tolerances and leaves `data` untouched.
template <class T>
struct check_all
{
    // Stored reference buffer; empty until a call stores one.
    rtc::buffer<T> data{};

    bool operator()(const rtc::buffer<T>& x)
    {
        if(!data.empty())
        {
            return allclose(data, x);
        }
        data = x;
        return true;
    }
};
/// Wraps an already-computed pass/fail flag in a test predicate whose first
/// argument is the solution's template string.
///
/// @param solution  object providing ToTemplateString().
/// @param pass      result the returned predicate yields when invoked.
/// @return whatever test::make_predicate returns for that string and callable.
template <class Solution>
auto report(const Solution& solution, bool pass)
{
    // Only `pass` is used by the callable, so it is captured by value.
    return test::make_predicate(solution.ToTemplateString(), [pass] { return pass; });
}
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <test.hpp> #include <test.hpp>
#include <rtc/compile_kernel.hpp> #include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp> #include <rtc/hip.hpp>
#include <fstream>
using half = _Float16; using half = _Float16;
// using half = __fp16; // using half = __fp16;
...@@ -159,7 +160,10 @@ TEST_CASE(test_problem_kernel) ...@@ -159,7 +160,10 @@ TEST_CASE(test_problem_kernel)
auto b = to_gpu(generate_buffer<half>(1024 * 1024, 1)); auto b = to_gpu(generate_buffer<half>(1024 * 1024, 1));
auto c = to_gpu(generate_buffer<half>(1024 * 1024, 2)); auto c = to_gpu(generate_buffer<half>(1024 * 1024, 2));
for(auto solution : prob.GetSolutions("gfx90a")) std::string epilogue = "";
std::string prologue = "";
for(auto solution : prob.GetSolutions("gfx90a", prologue, epilogue))
{ {
auto src = ck::host::InterpolateString(gemm_compile_check, auto src = ck::host::InterpolateString(gemm_compile_check,
{{"include", prob.GetIncludeHeader()}, {{"include", prob.GetIncludeHeader()},
...@@ -178,6 +182,7 @@ TEST_CASE(test_problem_kernel) ...@@ -178,6 +182,7 @@ TEST_CASE(test_problem_kernel)
auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) * auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) *
ck::host::integer_divide_ceil(prob.N, n_per_block); ck::host::integer_divide_ceil(prob.N, n_per_block);
k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data()); k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data());
CHECK(report(solution, check(rtc::from_gpu(c)))); CHECK(report(solution, check(rtc::from_gpu(c))));
} }
} }
......
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/headers.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp"
#include "ck/tensor_operation/gpu/device/helper.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include <test.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include "common.hpp"
#include <fstream>
// Need this for verification
/**struct Epilogue
{
Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename D>
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
}
float alpha_;
float beta_;
};**/
// Source text of the translation unit compiled at run time. The ${include}
// and ${template} placeholders are substituted through
// ck::host::InterpolateString in the test case below.
const std::string conv_compile_check = R"__ck__(
#include <${include}>
${template};
)__ck__";
// Runtime-compilation test for the forward convolution problem using a 3x3
// filter with stride 2 and padding 1. Every solution returned for gfx908 is
// compiled, launched on the same device buffers, and its output is compared
// with the output of the first solution.
TEST_CASE(test_problem_kernel)
{
// set up problem specification
ck::host::conv::Problem_Conv_Fwd prob;
prob.NumDim = 2;
prob.G = 32;
prob.N = 256;
prob.C = 32;
prob.K = 64;
prob.Y = 3;
prob.X = 3;
prob.Hi = 28;
prob.Wi = 28;
// NOTE(review): with a 3x3 filter, stride 2 and padding 1 the usual
// output-size formula gives Ho = Wo = 14, not 28 - confirm the intended shape.
prob.Ho = 28;
prob.Wo = 28;
// Remembers the first output it sees and compares every later one to it.
check_all<ck::half_t> check;
// user provided fusion operations
std::string epilogue = R"(
struct Epilogue
{
__host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename D>
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
}
float alpha_;
float beta_;
};
)";
// No prologue is supplied for this test.
std::string prologue = "";
// length+stride arrays
// Lengths are listed as (G, N, C-or-K, spatial, spatial); the stride of 1 sits
// on the third entry. NOTE(review): presumably a channels-last memory layout -
// confirm against the device op.
ck::Array<ck::index_t, 5> in_lengths{static_cast<int>(prob.G),
static_cast<int>(prob.N),
static_cast<int>(prob.C),
static_cast<int>(prob.Hi),
static_cast<int>(prob.Wi)};
ck::Array<ck::index_t, 5> out_lengths{static_cast<int>(prob.G),
static_cast<int>(prob.N),
static_cast<int>(prob.K),
static_cast<int>(prob.Ho),
static_cast<int>(prob.Wo)};
ck::Array<ck::index_t, 5> wei_lengths{static_cast<int>(prob.G),
static_cast<int>(prob.K),
static_cast<int>(prob.C),
static_cast<int>(prob.Y),
static_cast<int>(prob.X)};
// d_lengths and d_strides are not referenced again in this test case.
ck::Array<ck::index_t, 5> d_lengths = {};
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
1,
static_cast<int>(prob.Wi * prob.G * prob.C),
static_cast<int>(prob.G * prob.C)};
ck::Array<ck::index_t, 5> out_strides{static_cast<int>(prob.K),
static_cast<int>(prob.Ho * prob.Wo * prob.G * prob.K),
1,
static_cast<int>(prob.Wo * prob.G * prob.K),
static_cast<int>(prob.G * prob.K)};
ck::Array<ck::index_t, 5> wei_strides{static_cast<int>(prob.K * prob.Y * prob.X * prob.C),
static_cast<int>(prob.Y * prob.X * prob.C),
1,
static_cast<int>(prob.X * prob.C),
static_cast<int>(prob.C)};
ck::Array<ck::index_t, 5> d_strides = {};
// Convolution parameters that distinguish this test file from its siblings.
ck::Array<ck::index_t, 2> conv_filter_strides = {2, 2};
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
ck::Array<ck::index_t, 2> input_left_pads = {1, 1};
ck::Array<ck::index_t, 2> input_right_pads = {1, 1};
// move the data onto the device
// Each buffer is filled with seeded pseudo-random values (seeds 0, 1, 2).
auto in_dev =
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(in_lengths, in_strides, 0));
auto wei_dev =
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(wei_lengths, wei_strides, 1));
auto out_dev =
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(out_lengths, out_strides, 2));
// CK Verification: Reference Kernel (currently disabled)
/**bool pass = true;
Tensor<ck::half_t> in_host(in_lengths, in_strides);
in_host.GenerateTensorValue(GeneratorTensor_1<ck::half_t>{1});
Tensor<ck::half_t> wei_host(wei_lengths, wei_strides);
wei_host.GenerateTensorValue(GeneratorTensor_1<ck::half_t>{1});
Tensor<ck::half_t> out_host(out_lengths, out_strides);
std::vector<ck::index_t> conv_filter_strides_ = {2, 2};
std::vector<ck::index_t> conv_filter_dilations_ = {1, 1};
std::vector<ck::index_t> input_left_pads_ = {1, 1};
std::vector<ck::index_t> input_right_pads_ = {1, 1};
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<
2,
ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Epilogue>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_host,
wei_host,
out_host,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
Epilogue{1.0f, 1.0f});
out_host.SetZero();
ref_invoker.Run(ref_argument);**/
// out_dev is created once above and reused by every iteration without being
// re-initialised in between.
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
{
// substitute instance values into the template
auto src = ck::host::InterpolateString(
conv_compile_check,
{{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}});
// The generated source is compiled as main.cpp alongside the CK headers.
auto srcs = get_headers_for_test();
srcs.push_back({"main.cpp", src});
rtc::compile_options options;
auto name = solution.GetTemplateParameter<std::string>("name");
options.kernel_name = "run_" + name;
auto k = rtc::compile_kernel(srcs, options);
// Grid size calculation
auto block_size = solution.GetTemplateParameter<ck::index_t>("BlockSize");
auto tmp = get_launch_params(solution, out_lengths, out_strides);
// in_lengths[1] holds prob.N.
auto grid_size = tmp * in_lengths[1];
// launch the kernel with arguments needed for the argument pointer
k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(),
wei_dev.data(),
out_dev.data(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
out_lengths,
out_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
// auto res = rtc::from_gpu(out_dev);
// pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
// assert(pass);
// Simple check: this checks that the output from each instance matches the output from the
// first instance
CHECK(report(solution, check(rtc::from_gpu(out_dev))));
}
}
// Program entry point: hands the command-line arguments to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/headers.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp"
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/helper.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include <test.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <fstream>
// need this for validation
/**struct Epilogue
{
Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename D>
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
}
float alpha_;
float beta_;
};**/
// Source text of the translation unit compiled at run time. The ${include}
// and ${template} placeholders are substituted through
// ck::host::InterpolateString in the test case below.
const std::string conv_compile_check = R"__ck__(
#include <${include}>
${template};
)__ck__";
// Runtime-compilation test for the forward convolution problem using a 3x3
// filter with stride 1 and no padding. Every solution returned for gfx908 is
// compiled, launched on the same device buffers, and its output is compared
// with the output of the first solution.
TEST_CASE(test_problem_kernel)
{
// set up problem specification
ck::host::conv::Problem_Conv_Fwd prob;
prob.NumDim = 2;
prob.G = 32;
prob.N = 256;
prob.C = 32;
prob.K = 64;
prob.Y = 3;
prob.X = 3;
prob.Hi = 28;
prob.Wi = 28;
// NOTE(review): with a 3x3 filter, stride 1 and no padding the usual
// output-size formula gives Ho = Wo = 26, not 28 - confirm the intended shape.
prob.Ho = 28;
prob.Wo = 28;
// Remembers the first output it sees and compares every later one to it.
check_all<ck::half_t> check;
// user provided fusion operations
std::string epilogue = R"(
struct Epilogue
{
__host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename D>
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
}
float alpha_;
float beta_;
};
)";
// No prologue is supplied for this test.
std::string prologue = "";
// length+stride arrays
// Lengths are listed as (G, N, C-or-K, spatial, spatial); the stride of 1 sits
// on the third entry. NOTE(review): presumably a channels-last memory layout -
// confirm against the device op.
ck::Array<ck::index_t, 5> in_lengths{static_cast<int>(prob.G),
static_cast<int>(prob.N),
static_cast<int>(prob.C),
static_cast<int>(prob.Hi),
static_cast<int>(prob.Wi)};
ck::Array<ck::index_t, 5> out_lengths{static_cast<int>(prob.G),
static_cast<int>(prob.N),
static_cast<int>(prob.K),
static_cast<int>(prob.Ho),
static_cast<int>(prob.Wo)};
ck::Array<ck::index_t, 5> wei_lengths{static_cast<int>(prob.G),
static_cast<int>(prob.K),
static_cast<int>(prob.C),
static_cast<int>(prob.Y),
static_cast<int>(prob.X)};
// d_lengths and d_strides are not referenced again in this test case.
ck::Array<ck::index_t, 5> d_lengths = {};
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
1,
static_cast<int>(prob.Wi * prob.G * prob.C),
static_cast<int>(prob.G * prob.C)};
ck::Array<ck::index_t, 5> out_strides{static_cast<int>(prob.K),
static_cast<int>(prob.Ho * prob.Wo * prob.G * prob.K),
1,
static_cast<int>(prob.Wo * prob.G * prob.K),
static_cast<int>(prob.G * prob.K)};
ck::Array<ck::index_t, 5> wei_strides{static_cast<int>(prob.K * prob.Y * prob.X * prob.C),
static_cast<int>(prob.Y * prob.X * prob.C),
1,
static_cast<int>(prob.X * prob.C),
static_cast<int>(prob.C)};
ck::Array<ck::index_t, 5> d_strides = {};
// Convolution parameters that distinguish this test file from its siblings.
ck::Array<ck::index_t, 2> conv_filter_strides = {1, 1};
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
ck::Array<ck::index_t, 2> input_left_pads = {0, 0};
ck::Array<ck::index_t, 2> input_right_pads = {0, 0};
// move the data onto the device
// Each buffer is filled with seeded pseudo-random values (seeds 0, 1, 2).
auto in_dev =
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(in_lengths, in_strides, 0));
auto wei_dev =
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(wei_lengths, wei_strides, 1));
auto out_dev =
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(out_lengths, out_strides, 2));
// CK Verification: Reference Kernel (currently disabled)
/**bool pass = true;
Tensor<ck::half_t> in_host(in_lengths, in_strides);
in_host.GenerateTensorValue(GeneratorTensor_1<ck::half_t>{1});
Tensor<ck::half_t> wei_host(wei_lengths, wei_strides);
wei_host.GenerateTensorValue(GeneratorTensor_1<ck::half_t>{1});
Tensor<ck::half_t> out_host(out_lengths, out_strides);
std::vector<ck::index_t> conv_filter_strides_ = {1, 1};
std::vector<ck::index_t> conv_filter_dilations_ = {1, 1};
std::vector<ck::index_t> input_left_pads_ = {0, 0};
std::vector<ck::index_t> input_right_pads_ = {0, 0};
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<
2,
ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Epilogue>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_host,
wei_host,
out_host,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
Epilogue{1.0f, 1.0f});
out_host.SetZero();
ref_invoker.Run(ref_argument);**/
// out_dev is created once above and reused by every iteration without being
// re-initialised in between.
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
{
// substitute instance values into the template
auto src = ck::host::InterpolateString(
conv_compile_check,
{{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}});
// The generated source is compiled as main.cpp alongside the CK headers.
auto srcs = get_headers_for_test();
srcs.push_back({"main.cpp", src});
rtc::compile_options options;
auto name = solution.GetTemplateParameter<std::string>("name");
options.kernel_name = "run_" + name;
auto k = rtc::compile_kernel(srcs, options);
// Grid size calculation
auto block_size = solution.GetTemplateParameter<ck::index_t>("BlockSize");
auto tmp = get_launch_params(solution, out_lengths, out_strides);
// in_lengths[1] holds prob.N.
auto grid_size = tmp * in_lengths[1];
// launch the kernel with arguments needed for the argument pointer
k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(),
wei_dev.data(),
out_dev.data(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
out_lengths,
out_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
// auto res = rtc::from_gpu(out_dev);
// pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
// assert(pass);
// Simple check: this checks that the output from each instance matches the output from the
// first instance
CHECK(report(solution, check(rtc::from_gpu(out_dev))));
}
}
// Program entry point: hands the command-line arguments to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/headers.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp"
#include "ck/tensor_operation/gpu/device/helper.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "common.hpp"
#include <test.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <fstream>
// need this for verification
/**struct Epilogue
{
Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename D>
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
}
float alpha_;
float beta_;
};**/
// Source text of the translation unit compiled at run time. The ${include}
// and ${template} placeholders are substituted through
// ck::host::InterpolateString in the test case below.
const std::string conv_compile_check = R"__ck__(
#include <${include}>
${template};
)__ck__";
// Runtime-compilation test for the forward convolution problem using a 3x3
// filter with stride 2 and no padding. Every solution returned for gfx908 is
// compiled, launched on the same device buffers, and its output is compared
// with the output of the first solution.
TEST_CASE(test_problem_kernel)
{
// set up problem specification
ck::host::conv::Problem_Conv_Fwd prob;
prob.NumDim = 2;
prob.G = 32;
prob.N = 256;
prob.C = 32;
prob.K = 64;
prob.Y = 3;
prob.X = 3;
prob.Hi = 28;
prob.Wi = 28;
// NOTE(review): with a 3x3 filter, stride 2 and no padding the usual
// output-size formula gives Ho = Wo = 13, not 28 - confirm the intended shape.
prob.Ho = 28;
prob.Wo = 28;
// Remembers the first output it sees and compares every later one to it.
check_all<ck::half_t> check;
// user provided fusion operations
std::string epilogue = R"(
struct Epilogue
{
__host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename D>
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
}
float alpha_;
float beta_;
};
)";
// No prologue is supplied for this test.
std::string prologue = "";
// length+stride arrays
// Lengths are listed as (G, N, C-or-K, spatial, spatial); the stride of 1 sits
// on the third entry. NOTE(review): presumably a channels-last memory layout -
// confirm against the device op.
ck::Array<ck::index_t, 5> in_lengths{static_cast<int>(prob.G),
static_cast<int>(prob.N),
static_cast<int>(prob.C),
static_cast<int>(prob.Hi),
static_cast<int>(prob.Wi)};
ck::Array<ck::index_t, 5> out_lengths{static_cast<int>(prob.G),
static_cast<int>(prob.N),
static_cast<int>(prob.K),
static_cast<int>(prob.Ho),
static_cast<int>(prob.Wo)};
ck::Array<ck::index_t, 5> wei_lengths{static_cast<int>(prob.G),
static_cast<int>(prob.K),
static_cast<int>(prob.C),
static_cast<int>(prob.Y),
static_cast<int>(prob.X)};
// d_lengths and d_strides are not referenced again in this test case.
ck::Array<ck::index_t, 5> d_lengths = {};
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
1,
static_cast<int>(prob.Wi * prob.G * prob.C),
static_cast<int>(prob.G * prob.C)};
ck::Array<ck::index_t, 5> out_strides{static_cast<int>(prob.K),
static_cast<int>(prob.Ho * prob.Wo * prob.G * prob.K),
1,
static_cast<int>(prob.Wo * prob.G * prob.K),
static_cast<int>(prob.G * prob.K)};
ck::Array<ck::index_t, 5> wei_strides{static_cast<int>(prob.K * prob.Y * prob.X * prob.C),
static_cast<int>(prob.Y * prob.X * prob.C),
1,
static_cast<int>(prob.X * prob.C),
static_cast<int>(prob.C)};
ck::Array<ck::index_t, 5> d_strides = {};
// Convolution parameters that distinguish this test file from its siblings.
ck::Array<ck::index_t, 2> conv_filter_strides = {2, 2};
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
ck::Array<ck::index_t, 2> input_left_pads = {0, 0};
ck::Array<ck::index_t, 2> input_right_pads = {0, 0};
// move the data onto the device
// Each buffer is filled with seeded pseudo-random values (seeds 0, 1, 2).
auto in_dev =
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(in_lengths, in_strides, 0));
auto wei_dev =
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(wei_lengths, wei_strides, 1));
auto out_dev =
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(out_lengths, out_strides, 2));
// CK Verification: Reference Kernel (currently disabled)
/**bool pass = true;
Tensor<ck::half_t> in_host(in_lengths, in_strides);
in_host.GenerateTensorValue(GeneratorTensor_1<ck::half_t>{1});
Tensor<ck::half_t> wei_host(wei_lengths, wei_strides);
wei_host.GenerateTensorValue(GeneratorTensor_1<ck::half_t>{1});
Tensor<ck::half_t> out_host(out_lengths, out_strides);
std::vector<ck::index_t> conv_filter_strides_ = {2, 2};
std::vector<ck::index_t> conv_filter_dilations_ = {1, 1};
std::vector<ck::index_t> input_left_pads_ = {0, 0};
std::vector<ck::index_t> input_right_pads_ = {0, 0};
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<
2,
ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Epilogue>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_host,
wei_host,
out_host,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
Epilogue{1.0f, 1.0f});
out_host.SetZero();
ref_invoker.Run(ref_argument);**/
// out_dev is created once above and reused by every iteration without being
// re-initialised in between.
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
{
// substitute instance values into the template
auto src = ck::host::InterpolateString(
conv_compile_check,
{{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}});
// The generated source is compiled as main.cpp alongside the CK headers.
auto srcs = get_headers_for_test();
srcs.push_back({"main.cpp", src});
rtc::compile_options options;
auto name = solution.GetTemplateParameter<std::string>("name");
options.kernel_name = "run_" + name;
auto k = rtc::compile_kernel(srcs, options);
// Grid size calculation
auto block_size = solution.GetTemplateParameter<ck::index_t>("BlockSize");
auto tmp = get_launch_params(solution, out_lengths, out_strides);
// in_lengths[1] holds prob.N.
auto grid_size = tmp * in_lengths[1];
// launch the kernel with arguments needed for the argument pointer
k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(),
wei_dev.data(),
out_dev.data(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
out_lengths,
out_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
// auto res = rtc::from_gpu(out_dev);
// pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
// assert(pass);
// Simple check: this checks that the output from each instance matches the output from the
// first instance
CHECK(report(solution, check(rtc::from_gpu(out_dev))));
}
}
// Program entry point: hands the command-line arguments to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/headers.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp"
#include "ck/tensor_operation/gpu/device/helper.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "common.hpp"
#include <test.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <fstream>
// need this for verification
/**struct Epilogue
{
Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename D>
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
}
float alpha_;
float beta_;
};**/
// Source text with ${include} and ${template} placeholders.
// NOTE(review): presumably filled in through ck::host::InterpolateString and
// compiled at run time, as in the sibling conv tests - confirm below.
const std::string conv_compile_check = R"__ck__(
#include <${include}>
${template};
)__ck__";
// End-to-end test: describes a 2D grouped forward convolution, asks the host
// library for every gfx908 solution using a user-supplied epilogue, compiles
// each solution at run time, launches it, and checks the output buffer.
TEST_CASE(test_problem_kernel)
{
// set up problem specification
ck::host::conv::Problem_Conv_Fwd prob;
prob.NumDim = 2;
prob.G = 32;
prob.N = 256;
prob.C = 32;
prob.K = 64;
prob.Y = 3;
prob.X = 3;
prob.Hi = 28;
prob.Wi = 28;
prob.Ho = 28;
prob.Wo = 28;
// checker object shared by every solution in the loop below (check_all is
// defined outside this file section)
check_all<ck::half_t> check;
// user provided fusion operations
// The epilogue is passed as source text and spliced into the generated kernel
// code, so it must stay a string rather than a host-side type.
std::string epilogue = R"(
struct Epilogue
{
__host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename D>
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
}
float alpha_;
float beta_;
};
)";
// no prologue fusion is requested
std::string prologue = "";
// length+stride arrays
// Lengths are listed in (G, N, C-or-K, spatial...) order.
ck::Array<ck::index_t, 5> in_lengths{static_cast<int>(prob.G),
static_cast<int>(prob.N),
static_cast<int>(prob.C),
static_cast<int>(prob.Hi),
static_cast<int>(prob.Wi)};
ck::Array<ck::index_t, 5> out_lengths{static_cast<int>(prob.G),
static_cast<int>(prob.N),
static_cast<int>(prob.K),
static_cast<int>(prob.Ho),
static_cast<int>(prob.Wo)};
ck::Array<ck::index_t, 5> wei_lengths{static_cast<int>(prob.G),
static_cast<int>(prob.K),
static_cast<int>(prob.C),
static_cast<int>(prob.Y),
static_cast<int>(prob.X)};
// d_lengths / d_strides are zero-initialized and not passed to the launch below
ck::Array<ck::index_t, 5> d_lengths = {};
// Strides use the same index order as the lengths; the third entry (C or K
// for input/output, C for weights) has stride 1.
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
1,
static_cast<int>(prob.Wi * prob.G * prob.C),
static_cast<int>(prob.G * prob.C)};
ck::Array<ck::index_t, 5> out_strides{static_cast<int>(prob.K),
static_cast<int>(prob.Ho * prob.Wo * prob.G * prob.K),
1,
static_cast<int>(prob.Wo * prob.G * prob.K),
static_cast<int>(prob.G * prob.K)};
ck::Array<ck::index_t, 5> wei_strides{static_cast<int>(prob.K * prob.Y * prob.X * prob.C),
static_cast<int>(prob.Y * prob.X * prob.C),
1,
static_cast<int>(prob.X * prob.C),
static_cast<int>(prob.C)};
ck::Array<ck::index_t, 5> d_strides = {};
// unit stride and dilation, padding of 1 on both sides of each spatial dim
ck::Array<ck::index_t, 2> conv_filter_strides = {1, 1};
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
ck::Array<ck::index_t, 2> input_left_pads = {1, 1};
ck::Array<ck::index_t, 2> input_right_pads = {1, 1};
// move the data onto the device
// The trailing 0 / 1 / 2 argument differs per buffer (presumably a seed --
// confirm against generate_buffer).
auto in_dev =
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(in_lengths, in_strides, 0));
auto wei_dev =
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(wei_lengths, wei_strides, 1));
auto out_dev =
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(out_lengths, out_strides, 2));
// CK Verification: Reference Kernel (currently disabled)
/**bool pass = true;
Tensor<ck::half_t> in_host(in_lengths, in_strides);
in_host.GenerateTensorValue(GeneratorTensor_1<ck::half_t>{1});
Tensor<ck::half_t> wei_host(wei_lengths, wei_strides);
wei_host.GenerateTensorValue(GeneratorTensor_1<ck::half_t>{1});
Tensor<ck::half_t> out_host(out_lengths, out_strides);
std::vector<ck::index_t> conv_filter_strides_ = {1, 1};
std::vector<ck::index_t> conv_filter_dilations_ = {1, 1};
std::vector<ck::index_t> input_left_pads_ = {1, 1};
std::vector<ck::index_t> input_right_pads_ = {1, 1};
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<
2,
ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Epilogue>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_host,
wei_host,
out_host,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
Epilogue{1.0f, 1.0f});
out_host.SetZero();
ref_invoker.Run(ref_argument);**/
// compile, launch and check every solution generated for gfx908
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
{
// substitute instance values into the template
auto src = ck::host::InterpolateString(
conv_compile_check,
{{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}});
// the generated source is compiled as main.cpp next to the test headers
auto srcs = get_headers_for_test();
srcs.push_back({"main.cpp", src});
rtc::compile_options options;
// the kernel entry point is named "run_" followed by the solution's name
auto name = solution.GetTemplateParameter<std::string>("name");
options.kernel_name = "run_" + name;
auto k = rtc::compile_kernel(srcs, options);
// Grid size calculation
auto block_size = solution.GetTemplateParameter<ck::index_t>("BlockSize");
auto tmp = get_launch_params(solution, out_lengths, out_strides);
// NOTE(review): in_lengths[1] holds prob.N while in_lengths[0] holds prob.G;
// confirm that N, and not the group count, is the intended multiplier.
auto grid_size = tmp * in_lengths[1];
// launch the kernel with arguments needed for the argument pointer
// The first launch argument is nullptr (presumably the stream -- confirm in
// rtc); the second is grid_size * block_size and the third is block_size.
k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(),
wei_dev.data(),
out_dev.data(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
out_lengths,
out_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
// auto res = rtc::from_gpu(out_dev);
// pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
// assert(pass);
// Simple check: this checks that the output from each instance matches the output from the
// first instance
CHECK(report(solution, check(rtc::from_gpu(out_dev))));
}
}
// Program entry point: hands the command-line arguments to the test driver.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
...@@ -56,6 +56,8 @@ void write_string(const std::string& filename, const std::string_view& buffer) ...@@ -56,6 +56,8 @@ void write_string(const std::string& filename, const std::string_view& buffer)
} }
std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip --cuda-device-only"; } std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip --cuda-device-only"; }
// TODO: undo after extracting the codeobj
// std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip"; }
kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options) kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options)
{ {
...@@ -89,6 +91,12 @@ kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options ...@@ -89,6 +91,12 @@ kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options
auto obj = read_buffer(out_path.string()); auto obj = read_buffer(out_path.string());
std::ofstream ofh("obj.o", std::ios::binary);
for(auto i : obj)
ofh << i;
ofh.close();
// int s = std::system(("/usr/bin/cp " + out_path.string() + " codeobj.bin").c_str());
// assert(s == 0);
return kernel{obj.data(), options.kernel_name}; return kernel{obj.data(), options.kernel_name};
} }
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <rtc/manage_ptr.hpp> #include <rtc/manage_ptr.hpp>
#include <stdexcept> #include <stdexcept>
#include <cassert> #include <cassert>
#include <iostream>
namespace rtc { namespace rtc {
...@@ -49,7 +50,10 @@ std::size_t get_available_gpu_memory() ...@@ -49,7 +50,10 @@ std::size_t get_available_gpu_memory()
size_t total; size_t total;
auto status = hipMemGetInfo(&free, &total); auto status = hipMemGetInfo(&free, &total);
if(status != hipSuccess) if(status != hipSuccess)
throw std::runtime_error("Failed getting available memory: " + hip_error(status)); {
std::cerr << "Failed getting available memory: " + hip_error(status) << std::endl;
return (8ull * 1024ull * 1024ull * 1024ull);
}
return free; return free;
} }
......
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include <fstream>
#include <variant>
// functions to return the corresponding structs based on generated template parameters
// Closed set of convolution output layout tags (1D, 2D and 3D variants) held
// in a single variant type.
// NOTE(review): no use of this alias is visible in this part of the file --
// confirm whether it is still needed.
using layouts = std::variant<ck::tensor_layout::convolution::GNWK,
ck::tensor_layout::convolution::GNHWK,
ck::tensor_layout::convolution::NHWGK,
ck::tensor_layout::convolution::GNDHWK,
ck::tensor_layout::convolution::NDHWGK>;
// return the layout type: currently this is the only type supported in MIOpen
// Maps a fully qualified layout name to the matching CK layout tag object.
//
// Only "ck::tensor_layout::convolution::NHWGK" is recognized.
//
// Marked inline because this header defines the function in full; without it
// a second translation unit including the header would emit a duplicate
// definition. The name is taken by const reference to avoid copying it.
//
// @param type  fully qualified layout name
// @return      ck::tensor_layout::convolution::NHWGK{}
// @throws      std::runtime_error for any other name
inline auto layout_type(const std::string& type)
{
    if(type != "ck::tensor_layout::convolution::NHWGK")
    {
        throw std::runtime_error("Incorrect layout");
    }
    return ck::tensor_layout::convolution::NHWGK{};
}
// return the right gemm spec based on the generated template parameters
// Converts the textual GemmSpecialization template parameter of a generated
// solution into the corresponding enumerator.
//
// Recognized names are the fully qualified Default and MNKPadding enumerators.
//
// Marked inline because this header defines the function in full; without it
// a second translation unit including the header would emit a duplicate
// definition. The name is taken by const reference to avoid copying it.
//
// @param type  fully qualified enumerator name
// @return      the matching GemmSpecialization value
// @throws      std::runtime_error (message includes the name) when unrecognized
inline ck::tensor_operation::device::GemmSpecialization gemm_type(const std::string& type)
{
    using ck::tensor_operation::device::GemmSpecialization;

    if(type == "ck::tensor_operation::device::GemmSpecialization::Default")
    {
        return GemmSpecialization::Default;
    }
    if(type == "ck::tensor_operation::device::GemmSpecialization::MNKPadding")
    {
        return GemmSpecialization::MNKPadding;
    }
    throw std::runtime_error("Incorrect gemm spec: " + type);
}
// return the type of convolution
// Converts the textual ConvSpecialization template parameter of a generated
// solution into the corresponding enumerator.
//
// Recognized names are the fully qualified Default, Filter1x1Pad0,
// Filter1x1Stride1Pad0 and OddC enumerators.
//
// Marked inline because this header defines the function in full; without it
// a second translation unit including the header would emit a duplicate
// definition. The name is taken by const reference to avoid copying it.
//
// @param type  fully qualified enumerator name
// @return      the matching ConvolutionForwardSpecialization value
// @throws      std::runtime_error (message includes the name) when unrecognized
inline ck::tensor_operation::device::ConvolutionForwardSpecialization
conv_type(const std::string& type)
{
    using ck::tensor_operation::device::ConvolutionForwardSpecialization;

    if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Default")
    {
        return ConvolutionForwardSpecialization::Default;
    }
    if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0")
    {
        return ConvolutionForwardSpecialization::Filter1x1Pad0;
    }
    if(type ==
       "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0")
    {
        return ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
    }
    if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC")
    {
        return ConvolutionForwardSpecialization::OddC;
    }
    throw std::runtime_error("Incorrect conv spec: " + type);
}
// Function to call on MatrixPadder via a wrapper struct
// NOTE: CK only uses MNKPadding for forward convolution
// Builds an MNKPadding MatrixPadder with the given per-tile sizes and passes
// it, together with the raw descriptor, to grid_desc.
//
// @param mpb   value stored in the padder's MPerTile_
// @param npb   value stored in the padder's NPerTile_
// @param kpb   value stored in the padder's KPerTile_
// @param gemm  requested GEMM specialization; must be MNKPadding
// @param conv  descriptor forwarded to grid_desc
// @return      whatever grid_desc returns for the padder and descriptor
// @throws      std::runtime_error when gemm is not MNKPadding
template <typename CDesc_MRaw_NRaw>
auto pad(ck::index_t mpb,
         ck::index_t npb,
         ck::index_t kpb,
         ck::tensor_operation::device::GemmSpecialization gemm,
         CDesc_MRaw_NRaw conv)
{
    using ck::tensor_operation::device::GemmSpecialization;

    // Reject every specialization other than MNKPadding up front.
    if(gemm != GemmSpecialization::MNKPadding)
    {
        throw std::runtime_error("Incorrect template parameters, check gemm spec");
    }

    using Padder = ck::tensor_operation::device::
        MatrixPadder<GemmSpecialization::MNKPadding, ck::index_t, ck::index_t, ck::index_t>;

    Padder padder;
    padder.MPerTile_ = mpb;
    padder.NPerTile_ = npb;
    padder.KPerTile_ = kpb;

    auto padded_desc = grid_desc(padder, conv);
    return padded_desc;
}
// Functions to call on TransformConvFwdToGemm through wrapper: different functions based on num
// dims
// FIXME: add a way to properly pass in the layout
// Runs TransformConv::transform_func for a 2D convolution, selecting the
// TransformConvFwdToGemm instantiation that matches the run-time
// specialization.
//
// Marked inline because this header defines the function in full; without it
// a second translation unit including the header would emit a duplicate
// definition.
//
// @param num_dim      number of spatial dimensions; must be 2
// @param spec         forward-convolution specialization to instantiate
// @param out_lengths  output lengths forwarded to transform_func
// @param out_strides  output strides forwarded to transform_func
// @return             the result of transform_func
// @throws             std::runtime_error when num_dim is not 2 or spec is not
//                     one of the four handled enumerators
inline auto transform_conv(ck::index_t num_dim,
                           ck::tensor_operation::device::ConvolutionForwardSpecialization spec,
                           ck::Array<ck::index_t, 5> out_lengths,
                           ck::Array<ck::index_t, 5> out_strides)
{
    using ConvSpec = ck::tensor_operation::device::ConvolutionForwardSpecialization;

    // Shared tail of every branch: only the type of conv_fwd differs.
    const auto lower = [&](auto conv_fwd) {
        auto res = ck::tensor_operation::TransformConv();
        return res.transform_func(out_lengths, out_strides, conv_fwd);
    };

    if(num_dim == 2)
    {
        switch(spec)
        {
        case ConvSpec::Default:
            return lower(ck::tensor_operation::TransformConvFwdToGemm<2, ConvSpec::Default>{});
        case ConvSpec::Filter1x1Pad0:
            return lower(
                ck::tensor_operation::TransformConvFwdToGemm<2, ConvSpec::Filter1x1Pad0>{});
        case ConvSpec::Filter1x1Stride1Pad0:
            return lower(
                ck::tensor_operation::TransformConvFwdToGemm<2, ConvSpec::Filter1x1Stride1Pad0>{});
        case ConvSpec::OddC:
            return lower(ck::tensor_operation::TransformConvFwdToGemm<2, ConvSpec::OddC>{});
        default: break;
        }
    }
    throw std::runtime_error("Incorrect conv spec");
}
// Runs TransformConv::transform_func for a 3D convolution, selecting the
// TransformConvFwdToGemm instantiation that matches the run-time
// specialization.
//
// Marked inline because this header defines the function in full; without it
// a second translation unit including the header would emit a duplicate
// definition.
//
// @param num_dim      number of spatial dimensions; must be 3
// @param spec         forward-convolution specialization to instantiate
// @param out_lengths  output lengths forwarded to transform_func
// @param out_strides  output strides forwarded to transform_func
// @return             the result of transform_func
// @throws             std::runtime_error when num_dim is not 3 or spec is not
//                     one of the four handled enumerators
inline auto transform_conv_3d(ck::index_t num_dim,
                              ck::tensor_operation::device::ConvolutionForwardSpecialization spec,
                              ck::Array<ck::index_t, 6> out_lengths,
                              ck::Array<ck::index_t, 6> out_strides)
{
    using ConvSpec = ck::tensor_operation::device::ConvolutionForwardSpecialization;

    // Shared tail of every branch: only the type of conv_fwd differs.
    const auto lower = [&](auto conv_fwd) {
        auto res = ck::tensor_operation::TransformConv();
        return res.transform_func(out_lengths, out_strides, conv_fwd);
    };

    if(num_dim == 3)
    {
        switch(spec)
        {
        case ConvSpec::Default:
            return lower(ck::tensor_operation::TransformConvFwdToGemm<3, ConvSpec::Default>{});
        case ConvSpec::Filter1x1Pad0:
            return lower(
                ck::tensor_operation::TransformConvFwdToGemm<3, ConvSpec::Filter1x1Pad0>{});
        case ConvSpec::Filter1x1Stride1Pad0:
            return lower(
                ck::tensor_operation::TransformConvFwdToGemm<3, ConvSpec::Filter1x1Stride1Pad0>{});
        case ConvSpec::OddC:
            return lower(ck::tensor_operation::TransformConvFwdToGemm<3, ConvSpec::OddC>{});
        default: break;
        }
    }
    throw std::runtime_error("Incorrect conv spec");
}
// Runs TransformConv::transform_func for a 1D convolution, selecting the
// TransformConvFwdToGemm instantiation that matches the run-time
// specialization.
//
// Marked inline because this header defines the function in full; without it
// a second translation unit including the header would emit a duplicate
// definition.
//
// @param num_dim      number of spatial dimensions; must be 1
// @param spec         forward-convolution specialization to instantiate
// @param out_lengths  output lengths forwarded to transform_func
// @param out_strides  output strides forwarded to transform_func
// @return             the result of transform_func
// @throws             std::runtime_error when num_dim is not 1 or spec is not
//                     one of the four handled enumerators
inline auto transform_conv_1d(ck::index_t num_dim,
                              ck::tensor_operation::device::ConvolutionForwardSpecialization spec,
                              ck::Array<ck::index_t, 4> out_lengths,
                              ck::Array<ck::index_t, 4> out_strides)
{
    using ConvSpec = ck::tensor_operation::device::ConvolutionForwardSpecialization;

    // Shared tail of every branch: only the type of conv_fwd differs.
    const auto lower = [&](auto conv_fwd) {
        auto res = ck::tensor_operation::TransformConv();
        return res.transform_func(out_lengths, out_strides, conv_fwd);
    };

    if(num_dim == 1)
    {
        switch(spec)
        {
        case ConvSpec::Default:
            return lower(ck::tensor_operation::TransformConvFwdToGemm<1, ConvSpec::Default>{});
        case ConvSpec::Filter1x1Pad0:
            return lower(
                ck::tensor_operation::TransformConvFwdToGemm<1, ConvSpec::Filter1x1Pad0>{});
        case ConvSpec::Filter1x1Stride1Pad0:
            return lower(
                ck::tensor_operation::TransformConvFwdToGemm<1, ConvSpec::Filter1x1Stride1Pad0>{});
        case ConvSpec::OddC:
            return lower(ck::tensor_operation::TransformConvFwdToGemm<1, ConvSpec::OddC>{});
        default: break;
        }
    }
    throw std::runtime_error("Incorrect dims or conv spec");
}
// Returns CalculateGridSize(matrix_padder) from the
// BlockToCTileMap_M00_N0_M01Adapt instantiation that matches the run-time
// (m_per_block, n_per_block) pair.
//
// Supported pairs: 32x64, 32x128, 64x32, 64x64, 64x128, 128x32, 128x64,
// 128x128, 128x256 and 256x128.
//
// @param m_per_block    M tile size of the solution
// @param n_per_block    N tile size of the solution
// @param matrix_padder  descriptor handed to the map's constructor and to
//                       CalculateGridSize
// @throws               std::runtime_error for any other tile-size pair
template <typename CGridDesc_M_N>
auto block_2_etile(ck::index_t m_per_block, ck::index_t n_per_block, CGridDesc_M_N matrix_padder)
{
    // Tile sizes are template arguments of the map, so each supported pair
    // needs its own instantiation; the branches are grouped by M tile size.
    if(m_per_block == 32)
    {
        if(n_per_block == 64)
        {
            ck::BlockToCTileMap_M00_N0_M01Adapt<32, 64, CGridDesc_M_N> tile_map(matrix_padder);
            return tile_map.CalculateGridSize(matrix_padder);
        }
        if(n_per_block == 128)
        {
            ck::BlockToCTileMap_M00_N0_M01Adapt<32, 128, CGridDesc_M_N> tile_map(matrix_padder);
            return tile_map.CalculateGridSize(matrix_padder);
        }
    }
    if(m_per_block == 64)
    {
        if(n_per_block == 32)
        {
            ck::BlockToCTileMap_M00_N0_M01Adapt<64, 32, CGridDesc_M_N> tile_map(matrix_padder);
            return tile_map.CalculateGridSize(matrix_padder);
        }
        if(n_per_block == 64)
        {
            ck::BlockToCTileMap_M00_N0_M01Adapt<64, 64, CGridDesc_M_N> tile_map(matrix_padder);
            return tile_map.CalculateGridSize(matrix_padder);
        }
        if(n_per_block == 128)
        {
            ck::BlockToCTileMap_M00_N0_M01Adapt<64, 128, CGridDesc_M_N> tile_map(matrix_padder);
            return tile_map.CalculateGridSize(matrix_padder);
        }
    }
    if(m_per_block == 128)
    {
        if(n_per_block == 32)
        {
            ck::BlockToCTileMap_M00_N0_M01Adapt<128, 32, CGridDesc_M_N> tile_map(matrix_padder);
            return tile_map.CalculateGridSize(matrix_padder);
        }
        if(n_per_block == 64)
        {
            ck::BlockToCTileMap_M00_N0_M01Adapt<128, 64, CGridDesc_M_N> tile_map(matrix_padder);
            return tile_map.CalculateGridSize(matrix_padder);
        }
        if(n_per_block == 128)
        {
            ck::BlockToCTileMap_M00_N0_M01Adapt<128, 128, CGridDesc_M_N> tile_map(matrix_padder);
            return tile_map.CalculateGridSize(matrix_padder);
        }
        if(n_per_block == 256)
        {
            ck::BlockToCTileMap_M00_N0_M01Adapt<128, 256, CGridDesc_M_N> tile_map(matrix_padder);
            return tile_map.CalculateGridSize(matrix_padder);
        }
    }
    if(m_per_block == 256)
    {
        if(n_per_block == 128)
        {
            ck::BlockToCTileMap_M00_N0_M01Adapt<256, 128, CGridDesc_M_N> tile_map(matrix_padder);
            return tile_map.CalculateGridSize(matrix_padder);
        }
    }
    throw std::runtime_error("Incorrect template parameters");
}
// wrapper functions by dims to get grid size - uses above 3 functions
// TODO: eventually remove the 1d/2d versions as CK will only support 3d convolutions
// Computes the grid-size value for a 1D convolution solution.
//
// Reads NumDim, MPerBlock, NPerBlock, KPerBlock, GemmSpecialization and
// ConvSpecialization from the solution, then chains transform_conv_1d -> pad
// -> block_2_etile.
//
// Marked inline because this header defines the function in full; without it
// a second translation unit including the header would emit a duplicate
// definition.
//
// @param solution     generated solution whose template parameters are read
// @param out_lengths  output lengths passed to transform_conv_1d
// @param out_strides  output strides passed to transform_conv_1d
// @return             the value returned by block_2_etile
// @throws             std::runtime_error propagated from the helpers when a
//                     parameter combination is not supported
inline auto get_launch_params_1d(ck::host::Solution solution,
                                 ck::Array<ck::index_t, 4> out_lengths,
                                 ck::Array<ck::index_t, 4> out_strides)
{
    // Fetch all template parameters first, in a fixed order.
    const auto num_dim     = solution.GetTemplateParameter<ck::index_t>("NumDim");
    const auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
    const auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
    const auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
    const auto gemm_name   = solution.GetTemplateParameter<std::string>("GemmSpecialization");
    const auto conv_name   = solution.GetTemplateParameter<std::string>("ConvSpecialization");

    // Translate the textual specializations into enumerators.
    const auto gemm_spec = gemm_type(gemm_name);
    const auto conv_spec = conv_type(conv_name);

    auto conv_to_gemm = transform_conv_1d(num_dim, conv_spec, out_lengths, out_strides);
    auto padded_desc  = pad(m_per_block, n_per_block, k_per_block, gemm_spec, conv_to_gemm);
    return block_2_etile(m_per_block, n_per_block, padded_desc);
}
// Computes the grid-size value for a 2D convolution solution.
//
// Reads NumDim, MPerBlock, NPerBlock, KPerBlock, GemmSpecialization and
// ConvSpecialization from the solution, then chains transform_conv -> pad ->
// block_2_etile.
//
// Marked inline because this header defines the function in full; without it
// a second translation unit including the header would emit a duplicate
// definition.
//
// @param solution     generated solution whose template parameters are read
// @param out_lengths  output lengths passed to transform_conv
// @param out_strides  output strides passed to transform_conv
// @return             the value returned by block_2_etile
// @throws             std::runtime_error propagated from the helpers when a
//                     parameter combination is not supported
inline auto get_launch_params(ck::host::Solution solution,
                              ck::Array<ck::index_t, 5> out_lengths,
                              ck::Array<ck::index_t, 5> out_strides)
{
    // Fetch all template parameters first, in a fixed order.
    const auto num_dim     = solution.GetTemplateParameter<ck::index_t>("NumDim");
    const auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
    const auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
    const auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
    const auto gemm_name   = solution.GetTemplateParameter<std::string>("GemmSpecialization");
    const auto conv_name   = solution.GetTemplateParameter<std::string>("ConvSpecialization");

    // Translate the textual specializations into enumerators.
    const auto gemm_spec = gemm_type(gemm_name);
    const auto conv_spec = conv_type(conv_name);

    auto conv_to_gemm = transform_conv(num_dim, conv_spec, out_lengths, out_strides);
    auto padded_desc  = pad(m_per_block, n_per_block, k_per_block, gemm_spec, conv_to_gemm);
    return block_2_etile(m_per_block, n_per_block, padded_desc);
}
// Computes the grid-size value for a 3D convolution solution.
//
// Reads NumDim, MPerBlock, NPerBlock, KPerBlock, GemmSpecialization and
// ConvSpecialization from the solution, then chains transform_conv_3d -> pad
// -> block_2_etile.
//
// Marked inline because this header defines the function in full; without it
// a second translation unit including the header would emit a duplicate
// definition.
//
// @param solution     generated solution whose template parameters are read
// @param out_lengths  output lengths passed to transform_conv_3d
// @param out_strides  output strides passed to transform_conv_3d
// @return             the value returned by block_2_etile
// @throws             std::runtime_error propagated from the helpers when a
//                     parameter combination is not supported
inline auto get_launch_params_3d(ck::host::Solution solution,
                                 ck::Array<ck::index_t, 6> out_lengths,
                                 ck::Array<ck::index_t, 6> out_strides)
{
    // Fetch all template parameters first, in a fixed order.
    const auto num_dim     = solution.GetTemplateParameter<ck::index_t>("NumDim");
    const auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
    const auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
    const auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
    const auto gemm_name   = solution.GetTemplateParameter<std::string>("GemmSpecialization");
    const auto conv_name   = solution.GetTemplateParameter<std::string>("ConvSpecialization");

    // Translate the textual specializations into enumerators.
    const auto gemm_spec = gemm_type(gemm_name);
    const auto conv_spec = conv_type(conv_name);

    auto conv_to_gemm = transform_conv_3d(num_dim, conv_spec, out_lengths, out_strides);
    auto padded_desc  = pad(m_per_block, n_per_block, k_per_block, gemm_spec, conv_to_gemm);
    return block_2_etile(m_per_block, n_per_block, padded_desc);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace {
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
* \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
* computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
*
*/
// Device-side body of the grouped forward-convolution kernel.
// It derives a group index from the block id, shifts the A / B / Ds / E base
// pointers by the offsets compute_ptr_offset_of_batch reports for that group,
// and forwards everything to GridwiseGemm::Run<HasMainKBlockLoop>.
// isMultiA / isMultiB select whether p_as_grid / p_bs_grid are indexed per
// tensor or treated as single pointers.
template <typename GridwiseGemm,
typename AsPointer, // tuples if multi AB, pointers if no
typename BsPointer,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
typename ComputePtrOffsetOfBatch,
bool HasMainKBlockLoop,
bool isMultiA,
bool isMultiB>
__device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
AsPointer p_as_grid,
BsPointer p_bs_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const index_t batch_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_,
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
// The real body is compiled for non-device passes and for gfx908 / gfx90a /
// gfx94 device passes; every other target gets the #else stub below.
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
// offset base pointer for each work-group
// g_idx = block id / (grid size / batch_count)
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
// shared-memory scratch buffer whose size is reported by GridwiseGemm
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// per-D-tensor pointers shifted by this group's offsets
DsPointer p_ds_grid_grp;
static constexpr index_t NumDTensor =
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
if constexpr(isMultiA || isMultiB)
{
// multi A/B: each entry of p_as_grid / p_bs_grid gets its own offset
AsPointer p_as_grid_grp;
BsPointer p_bs_grid_grp;
const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx);
static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
static_for<0, NumATensor, 1>{}(
[&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; });
const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx);
static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
static_for<0, NumBTensor, 1>{}(
[&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; });
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_as_grid_grp,
p_bs_grid_grp,
p_ds_grid_grp,
p_e_grid + e_batch_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock_,
block_2_ctile_map);
}
else
{
// single A/B: p_as_grid / p_bs_grid are advanced by one scalar offset each
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_as_grid + a_batch_offset,
p_bs_grid + b_batch_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock_,
block_2_ctile_map);
}
#else
// Other targets: no work is done; every parameter is assigned to `ignore`.
ignore = p_as_grid;
ignore = p_bs_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = batch_count;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
#endif
}
// __global__ entry point for the grouped forward convolution.
//
// It is a thin wrapper: all template arguments and all kernel arguments are
// forwarded unchanged to device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.
//
// The output pointer is forwarded as a pointer. The callee's fourth parameter
// is `EDataType* __restrict__`, so passing the dereferenced value `*p_e_grid`
// would not compile once this kernel is instantiated.
//
// When CK_USE_LAUNCH_BOUNDS is set, the kernel carries
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU).
template <typename GridwiseGemm,
          typename AsPointer, // tuples if multi AB, pointers if no
          typename BsPointer,
          typename DsPointer,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          typename Block2ETileMap,
          typename ComputePtrOffsetOfBatch,
          bool HasMainKBlockLoop,
          bool isMultiA,
          bool isMultiB>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
            AsPointer p_as_grid,
            BsPointer p_bs_grid,
            DsPointer p_ds_grid,
            EDataType* __restrict__ p_e_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const index_t batch_count,
            const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                ds_grid_desc_mblock_mperblock_nblock_nperblock,
            const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                e_grid_desc_mblock_mperblock_nblock_nperblock_,
            const Block2ETileMap block_2_ctile_map,
            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
    device_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
        GridwiseGemm,
        AsPointer, // tuples if multi AB, pointers if no
        BsPointer,
        DsPointer,
        EDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        AGridDesc_AK0_M_AK1,
        BGridDesc_BK0_N_BK1,
        DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
        Block2ETileMap,
        ComputePtrOffsetOfBatch,
        HasMainKBlockLoop,
        isMultiA,
        isMultiB>(p_as_grid,
                  p_bs_grid,
                  p_ds_grid,
                  p_e_grid, // pointer, not the pointee
                  a_element_op,
                  b_element_op,
                  cde_element_op,
                  batch_count,
                  a_grid_desc_k0_m_k1,
                  b_grid_desc_k0_n_k1,
                  ds_grid_desc_mblock_mperblock_nblock_nperblock,
                  e_grid_desc_mblock_mperblock_nblock_nperblock_,
                  block_2_ctile_map,
                  compute_ptr_offset_of_batch);
}
} // namespace
// Detection archetype: names the return type of Candidate::IsTuple() and is
// ill-formed when an lvalue of Candidate has no callable IsTuple() member.
template <typename Candidate>
using is_tuple = decltype(std::declval<Candidate&>().IsTuple());
//
// @brief Device Convolution operation.
//
// Supports:
// @li Forward convolution with up to 3 spatial dimensions
// @li Input tensor in GNWC data format
// @li Weight tensor in GKXC data format
// @li Output tensor in GNWK data format
//
// 1D:
// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
// 2D:
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
// 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
//
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename ComputeDataType =
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
Number<0>,
ADataType>()), // ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
ComputeDataType>
{
// Shorthand for this device-op type, used to qualify calls to its own static helpers.
using DeviceOp = CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
// True when the A / B data type is detected as a tuple (i.e. several A / B tensors).
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
// Tensor counts for A, B and D, as reported by GetNumABTensors / DsDataType::Size().
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
static constexpr index_t NumDTensor = DsDataType::Size();
// Compile-time index constants.
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// Stateless helper that builds the raw GEMM-view descriptors of the convolution tensors.
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
// Padder constructed from the per-block tile sizes; applied to every raw descriptor below.
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
// Returns the A (input) descriptor in GEMM form: the raw descriptor produced by
// conv_to_gemm_transformer for layout ALay, passed through matrix_padder.PadADescriptor_M_K.
template <typename ALay>
__host__ __device__ static auto
MakeAGridDescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                        const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
                        const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
                        const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
                        const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
                        const ck::Array<index_t, NDimSpatial>& input_left_pads,
                        const ck::Array<index_t, NDimSpatial>& input_right_pads)
{
    // Build the unpadded descriptor and hand it straight to the padder.
    return matrix_padder.PadADescriptor_M_K(
        conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
                                                                    a_g_n_c_wis_strides,
                                                                    b_g_k_c_xs_lengths,
                                                                    b_g_k_c_xs_strides,
                                                                    e_g_n_k_wos_lengths,
                                                                    e_g_n_k_wos_strides,
                                                                    conv_filter_strides,
                                                                    conv_filter_dilations,
                                                                    input_left_pads,
                                                                    input_right_pads));
}
// Returns the B (weight) descriptor in GEMM form: the raw descriptor produced by
// conv_to_gemm_transformer for layout BLay, passed through matrix_padder.PadBDescriptor_N_K.
template <typename BLay>
__host__ __device__ static auto
MakeBGridDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{
    return matrix_padder.PadBDescriptor_N_K(
        conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
                                                                    b_g_k_c_xs_strides));
}
// Returns the E (output) descriptor in GEMM form: the raw descriptor produced by
// conv_to_gemm_transformer for layout ELay, passed through matrix_padder.PadCDescriptor_M_N.
// Also reused for each D tensor (with that tensor's own layout and strides).
template <typename ELay>
__host__ __device__ static auto
MakeEGridDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{
    return matrix_padder.PadCDescriptor_M_N(
        conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
                                                                    e_g_n_k_wos_strides));
}
// Builds a tuple with one descriptor per D tensor.
// Every D tensor is described with the lengths of E (so a D with different strides is
// logically broadcast to E's shape); only the strides are taken per D tensor.
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
    const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
    const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
{
    // Descriptor factory for the D tensor selected by the compile-time index d_idx.
    const auto make_d_desc = [&](auto d_idx) {
        using DLayout = remove_cvref_t<tuple_element_t<d_idx.value, DsLayout>>;

        return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(e_g_n_k_wos_lengths,
                                                          ds_g_n_k_wos_strides[d_idx]);
    };

    return generate_tuple(make_d_desc, Number<NumDTensor>{});
}
// desc for problem definition
// Each alias is the (cv/ref-stripped) return type of the matching Make*Descriptor helper,
// obtained by decltype on a call with empty-brace arguments.
using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
// If we are using multiAB and one of the template datatype parameters is not a tuple, convert
// it to it
// (i.e. when exactly one of A/B is a tuple, the other one is wrapped into a one-element Tuple).
using GemmADataType = std::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
// Template-argument list shared by both gridwise GEMM candidates below.
// AccDataType and CShuffleDataType are not declared in this part of the file; presumably they
// are template parameters of the enclosing struct.
// NOTE(review): no #undef for this macro is visible within the struct - confirm whether it is
// intended to stay defined after this header.
#define GridwiseGemmTemplateParameters \
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
// Use appropriate gridwise gemm
// (the MultipleABD variant as soon as A or B is a tuple, the MultipleD variant otherwise).
using GridwiseGemm =
std::conditional_t<isMultiA || isMultiB,
GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmTemplateParameters>,
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;
// If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers.
// Otherwise a single const void* is passed.
using APointers =
std::conditional_t<isMultiA, ck::Array<const void*, NumATensor>&, const void*>;
using BPointers =
std::conditional_t<isMultiB, ck::Array<const void*, NumBTensor>&, const void*>;
// Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not
// in initializer list what is required for single const pointer).
using AGridPointer = remove_cvref_t<
decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm, ADataType > ())>;
using BGridPointer = remove_cvref_t<
decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm, BDataType > ())>;
// desc for blockwise copy
// Types are taken from the return types of the GridwiseGemm factory functions.
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument
// Bundles everything needed to describe one grouped forward convolution problem: typed
// tensor pointers, GEMM-view descriptors, the block-to-tile map, per-group pointer strides,
// the element-wise operators, and a copy of the raw lengths/strides/conv parameters.
struct Argument
{
// Builds all descriptors and typed pointers from the raw problem description.
// p_as / p_bs are either a single const void* or a ck::Array of pointers, depending on
// isMultiA / isMultiB (see APointers / BPointers).
// Index 0 of each lengths/strides array is read as the group dimension: lengths[0] gives
// num_group_ and strides[0] give the per-group (batch) strides.
__device__ __host__ Argument(
APointers p_as,
BPointers p_bs,
const ck::Array<const void*, NumDTensor>& p_ds,
void* p_e,
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
const ck::Array<index_t, NDimSpatial>& input_left_pads,
const ck::Array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
// A/B/D pointers start empty and are filled in the constructor body.
: p_as_grid_{},
p_bs_grid_{},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
// The initializers below read members initialized above (a_grid_desc_m_k_,
// b_grid_desc_n_k_, e_grid_desc_m_n_), so they rely on the member declaration order.
a_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
// Block-level Ds/E descriptors stay default-constructed unless CheckValidity succeeds in
// the constructor body.
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
compute_ptr_offset_of_batch_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
// A/B/E Batch Stride
if constexpr(isMultiA || isMultiB)
{
static_for<0, NumATensor, 1>{}([&](auto i) {
// Init compute_ptr_offset_of_batch_ for multiple AB
// (every A tensor gets the same group stride, a_g_n_c_wis_strides[0]).
compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
// type is not tuple)
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmADataType>>;
// It is possible that one of the AB is a pointer and one is a tuple.
// Then also use multiAB but we have to cast single pointer instead of tuple of
// pointer.
if constexpr(isMultiA)
{
// p_as is tuple
p_as_grid_(i) = static_cast<const DataType*>(p_as[i.value]);
}
else
{
// if MultiB and not MultiA then p_as is single pointer
p_as_grid_(i) = static_cast<const DataType*>(p_as);
}
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
// Init compute_ptr_offset_of_batch_ for multiple AB
// (every B tensor gets the same group stride, b_g_k_c_xs_strides[0]).
compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
// It is possible that one of the AB is a pointer and one is a tuple.
// Then also use multiAB but we have to cast single pointer instead of tuple of
// pointer.
if constexpr(isMultiB)
{
// p_bs is tuple
p_bs_grid_(i) = static_cast<const DataType*>(p_bs[i.value]);
}
else
{
// if MultiA and not MultiB then p_bs is single pointer
p_bs_grid_(i) = static_cast<const DataType*>(p_bs);
}
});
}
else
{
// Single A and single B: the batch strides are assigned directly (not per index) here.
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
// p_as and p_bs are pointers
p_as_grid_(I0) = static_cast<const ADataType*>(p_as);
p_bs_grid_(I0) = static_cast<const BDataType*>(p_bs);
}
// populate pointer, batch stride, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
// D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
// D desc
// Built from the lengths of E combined with the strides of this D tensor;
// ds_g_n_k_wos_lengths is only stored, not used for the descriptor.
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]);
});
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
// populate desc for Ds/E
// Both branches below fill the block-level Ds/E descriptors only when
// GridwiseGemm::CheckValidity accepts the problem.
if constexpr(isMultiA || isMultiB)
{
// NOTE(review): despite the "_ak0_m_ak1" / "_bk0_n_bk1" names, these tuples hold
// copies of the M_K / N_K descriptors (a_grid_desc_m_k_ / b_grid_desc_n_k_).
const auto as_grid_desc_ak0_m_ak1 =
generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 =
generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number<NumBTensor>{});
if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
}
}
else
{
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
}
}
}
// private:
// pointers (tuple if multi AB, pointer if no)
AGridPointer p_as_grid_;
BGridPointer p_bs_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// tensor descriptors for problem definition
index_t num_group_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
compute_ptr_offset_of_batch_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// for checking IsSupportedArgument()
// (verbatim copies of the constructor inputs)
ck::Array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
ck::Array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
ck::Array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
ck::Array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
ck::Array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
ck::Array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
ck::Array<index_t, NDimSpatial> conv_filter_strides_;
ck::Array<index_t, NDimSpatial> conv_filter_dilations_;
ck::Array<index_t, NDimSpatial> input_left_pads_;
ck::Array<index_t, NDimSpatial> input_right_pads_;
};
// Factory for Argument: forwards every parameter, unchanged and in the same order, to the
// Argument constructor and returns the result by value.
static __device__ __host__ auto MakeArgument(
    APointers p_as,
    BPointers p_bs,
    const ck::Array<const void*, NumDTensor>& p_ds,
    void* p_e,
    const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
    const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
    const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
    const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
    const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
    const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
    const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
    const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
    const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
    const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
    const ck::Array<index_t, NDimSpatial>& input_left_pads,
    const ck::Array<index_t, NDimSpatial>& input_right_pads,
    const AElementwiseOperation& a_element_op,
    const BElementwiseOperation& b_element_op,
    const CDEElementwiseOperation& cde_element_op)
{
    // Tensor pointers, then A/B/Ds/E lengths and strides, then the convolution
    // parameters, then the element-wise operators.
    return Argument(p_as,
                    p_bs,
                    p_ds,
                    p_e,
                    a_g_n_c_wis_lengths,
                    a_g_n_c_wis_strides,
                    b_g_k_c_xs_lengths,
                    b_g_k_c_xs_strides,
                    ds_g_n_k_wos_lengths,
                    ds_g_n_k_wos_strides,
                    e_g_n_k_wos_lengths,
                    e_g_n_k_wos_strides,
                    conv_filter_strides,
                    conv_filter_dilations,
                    input_left_pads,
                    input_right_pads,
                    a_element_op,
                    b_element_op,
                    cde_element_op);
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -180,6 +180,19 @@ struct MatrixPadder : public GemmPadder<GemmSpec, MPerTileType, NPerTileType, KP ...@@ -180,6 +180,19 @@ struct MatrixPadder : public GemmPadder<GemmSpec, MPerTileType, NPerTileType, KP
{ {
}; };
// Codegen helper: applies the given MatrixPadder to a raw (unpadded) descriptor via
// PadCDescriptor_M_N and returns the resulting descriptor by value.
template <GemmSpecialization GemmSpec,
          typename MPerTileType,
          typename NPerTileType,
          typename KPerTileType,
          typename CDesc_MRaw_NRaw>
auto grid_desc(MatrixPadder<GemmSpec, MPerTileType, NPerTileType, KPerTileType> matrix_padder,
               CDesc_MRaw_NRaw conv_desc)
{
    return matrix_padder.PadCDescriptor_M_N(conv_desc);
}
// M/N/KPerTileType could be index_t or Number<> // M/N/KPerTileType could be index_t or Number<>
template <bool PadM, template <bool PadM,
bool PadN, bool PadN,
......
...@@ -14,6 +14,17 @@ ...@@ -14,6 +14,17 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
// Device-usable replacement for std::accumulate with multiplication: returns `init`
// multiplied by the `count` consecutive elements starting at `first`.
//
// first - iterator to the first factor; only `++first` and `*first` are used, so a plain
//         forward iterator (or raw pointer) is sufficient
// count - number of elements to multiply; zero or a negative value leaves `init` untouched
// init  - starting value of the product; its type T is also the return type
template <typename T, typename ForwardIterator, typename Size>
__host__ __device__ auto mult_accumulate_n(ForwardIterator first, Size count, T init)
{
    // Count down instead of comparing against `first + count`: that expression needs a
    // random-access iterator and would never be reached for a negative count.
    for(auto remaining = static_cast<long long>(count); remaining > 0; --remaining, ++first)
    {
        init *= *first;
    }
    return init;
}
template <index_t NDimSpatial, device::ConvolutionForwardSpecialization ConvForwardSpecialization> template <index_t NDimSpatial, device::ConvolutionForwardSpecialization ConvForwardSpecialization>
struct TransformConvFwdToGemm struct TransformConvFwdToGemm
{ {
...@@ -607,6 +618,559 @@ struct TransformConvFwdToGemm ...@@ -607,6 +618,559 @@ struct TransformConvFwdToGemm
return out_gemmm_gemmn_desc; return out_gemmm_gemmn_desc;
} }
// Overloaded functions for hipRTC purposes
// ck::Array overload (for hipRTC) building the A / input descriptor of a 1D convolution in
// GEMM form [GemmM, GemmK]. Enabled only for NDimSpatial == 1 with layout G_NW_C, NWGC or GNWC.
//
// The lengths/strides arrays are indexed as [G, N, C (or K), W...]. The B strides and
// C strides are not needed and therefore unnamed. The C dimension of the input is always
// described with the compile-time stride I1.
//
// Depending on ConvForwardSpecialization:
//  - Filter1x1Stride1Pad0: a plain 2D descriptor (N * Wo, C);
//  - Filter1x1Pad0:        input embedded with the convolution stride, then N and Wo merged;
//  - otherwise:            input padded, embedded with dilation/stride, then (N, Wo) merged
//                          into GemmM and (X, C) merged into GemmK.
template <typename ALayout,
          typename std::enable_if<NDimSpatial == 1 &&
                                      (is_same_v<ALayout, tensor_layout::convolution::G_NW_C> ||
                                       is_same_v<ALayout, tensor_layout::convolution::NWGC> ||
                                       is_same_v<ALayout, tensor_layout::convolution::GNWC>),
                                  bool>::type = false>
__host__ __device__ static auto
MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                    const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                    const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                    const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
                    const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
                    const ck::Array<index_t, NDimSpatial>& input_left_pads,
                    const ck::Array<index_t, NDimSpatial>& input_right_pads)
{
    const index_t N = a_g_n_c_wis_lengths[1];
    const index_t C = a_g_n_c_wis_lengths[2];

    const index_t Wi = a_g_n_c_wis_lengths[3];

    const index_t Wo = c_g_n_k_wos_lengths[3];

    const index_t ConvStrideW = conv_filter_strides[0];

    if constexpr(ConvForwardSpecialization ==
                 device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
    {
        // Product of all output spatial lengths times N. Uses the device-callable helper
        // mult_accumulate_n (as the B-descriptor overloads do) rather than
        // ck::accumulate_n + std::multiplies, so this overload stays usable in device code.
        const index_t NHoWo =
            N * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);

        // Row stride is the caller-supplied stride of the last spatial dimension.
        const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
        const auto CStride     = I1;

        const auto in_gemmm_gemmk_desc =
            make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));

        return in_gemmm_gemmk_desc;
    }
    else if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Pad0)
    {
        // Strides are taken from the caller-supplied array.
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t WiStride = a_g_n_c_wis_strides[3];
        const auto CStride     = I1;

        const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
            make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));

        // Wi -> Wo: step through the input with the convolution stride.
        const auto in_n_wo_c_desc = transform_tensor_descriptor(
            in_n_wi_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        // (N, Wo) -> GemmM, C -> GemmK
        const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
            in_n_wo_c_desc,
            make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
            make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return in_gemmm_gemmk_desc;
    }
    else
    {
        const index_t X = b_g_k_c_xs_lengths[3];

        const index_t ConvDilationW = conv_filter_dilations[0];

        const index_t InLeftPadW = input_left_pads[0];

        const index_t InRightPadW = input_right_pads[0];

        // Strides are taken from the caller-supplied array.
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t WiStride = a_g_n_c_wis_strides[3];
        const auto CStride     = I1;

        const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
            make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));

        // Apply left/right padding to the spatial dimension.
        const auto in_n_wip_c_desc = transform_tensor_descriptor(
            in_n_wi_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        // Split the padded spatial dimension into (X, Wo) using dilation and stride.
        const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
            in_n_wip_c_desc,
            make_tuple(
                make_pass_through_transform(N),
                make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        // (N, Wo) -> GemmM, (X, C) -> GemmK
        const auto in_gemmm_gemmk_desc =
            transform_tensor_descriptor(in_n_x_wo_c_desc,
                                        make_tuple(make_merge_transform(make_tuple(N, Wo)),
                                                   make_merge_transform(make_tuple(X, C))),
                                        make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return in_gemmm_gemmk_desc;
    }
}
// ck::Array overload (for hipRTC) building the A / input descriptor of a 2D convolution in
// GEMM form [GemmM, GemmK]. Enabled only for NDimSpatial == 2 with layout G_NHW_C, NHWGC or
// GNHWC.
//
// The lengths/strides arrays are indexed as [G, N, C (or K), H, W]. The B strides and
// C strides are not needed and therefore unnamed. The C dimension of the input is always
// described with the compile-time stride I1.
//
// Depending on ConvForwardSpecialization:
//  - Filter1x1Stride1Pad0: a plain 2D descriptor (N * Ho * Wo, C);
//  - Filter1x1Pad0:        input embedded with the convolution strides, then (N, Ho, Wo)
//                          merged;
//  - otherwise:            input padded, embedded with dilation/stride, then (N, Ho, Wo)
//                          merged into GemmM and (Y, X, C) merged into GemmK.
template <typename ALayout,
          typename std::enable_if<
              NDimSpatial == 2 && (is_same_v<ALayout, tensor_layout::convolution::G_NHW_C> ||
                                   is_same_v<ALayout, tensor_layout::convolution::NHWGC> ||
                                   is_same_v<ALayout, tensor_layout::convolution::GNHWC>),
              bool>::type = false>
__host__ __device__ static auto
MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                    const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                    const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                    const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
                    const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
                    const ck::Array<index_t, NDimSpatial>& input_left_pads,
                    const ck::Array<index_t, NDimSpatial>& input_right_pads)
{
    const index_t N = a_g_n_c_wis_lengths[1];
    const index_t C = a_g_n_c_wis_lengths[2];

    const index_t Hi = a_g_n_c_wis_lengths[3];
    const index_t Wi = a_g_n_c_wis_lengths[4];

    const index_t Ho = c_g_n_k_wos_lengths[3];
    const index_t Wo = c_g_n_k_wos_lengths[4];

    const index_t ConvStrideH = conv_filter_strides[0];
    const index_t ConvStrideW = conv_filter_strides[1];

    if constexpr(ConvForwardSpecialization ==
                 device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
    {
        // Product of all output spatial lengths times N. Uses the device-callable helper
        // mult_accumulate_n (as the B-descriptor overloads do) rather than
        // ck::accumulate_n + std::multiplies, so this overload stays usable in device code.
        const index_t NHoWo =
            N * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);

        // Row stride is the caller-supplied stride of the last spatial dimension.
        const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
        const auto CStride     = I1;

        const auto in_gemmm_gemmk_desc =
            make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));

        return in_gemmm_gemmk_desc;
    }
    else if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Pad0)
    {
        // Strides are taken from the caller-supplied array.
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t HiStride = a_g_n_c_wis_strides[3];
        const index_t WiStride = a_g_n_c_wis_strides[4];
        const auto CStride     = I1;

        const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
            make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));

        // Hi -> Ho, Wi -> Wo: step through the input with the convolution strides.
        const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
            in_n_hi_wi_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                       make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // (N, Ho, Wo) -> GemmM, C -> GemmK
        const auto in_gemmm_gemmk_desc =
            transform_tensor_descriptor(in_n_ho_wo_c_desc,
                                        make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                                                   make_pass_through_transform(C)),
                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return in_gemmm_gemmk_desc;
    }
    else
    {
        const index_t Y = b_g_k_c_xs_lengths[3];
        const index_t X = b_g_k_c_xs_lengths[4];

        const index_t ConvDilationH = conv_filter_dilations[0];
        const index_t ConvDilationW = conv_filter_dilations[1];

        const index_t InLeftPadH = input_left_pads[0];
        const index_t InLeftPadW = input_left_pads[1];

        const index_t InRightPadH = input_right_pads[0];
        const index_t InRightPadW = input_right_pads[1];

        // Strides are taken from the caller-supplied array.
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t HiStride = a_g_n_c_wis_strides[3];
        const index_t WiStride = a_g_n_c_wis_strides[4];
        const auto CStride     = I1;

        const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
            make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));

        // Apply left/right padding to both spatial dimensions.
        const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
            in_n_hi_wi_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // Split each padded spatial dimension into (filter, output) using dilation and stride.
        const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
            in_n_hip_wip_c_desc,
            make_tuple(
                make_pass_through_transform(N),
                make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

        // (N, Ho, Wo) -> GemmM, (Y, X, C) -> GemmK
        const auto in_gemmm_gemmk_desc =
            transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc,
                                        make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                                                   make_merge_transform(make_tuple(Y, X, C))),
                                        make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return in_gemmm_gemmk_desc;
    }
}
template <typename ALayout,
typename std::enable_if<
NDimSpatial == 3 && (is_same_v<ALayout, tensor_layout::convolution::G_NDHW_C> ||
is_same_v<ALayout, tensor_layout::convolution::NDHWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNDHWC>),
bool>::type = false>
static auto
MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
const ck::Array<index_t, NDimSpatial>& input_left_pads,
const ck::Array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Di = a_g_n_c_wis_lengths[3];
const index_t Hi = a_g_n_c_wis_lengths[4];
const index_t Wi = a_g_n_c_wis_lengths[5];
const index_t Do = c_g_n_k_wos_lengths[3];
const index_t Ho = c_g_n_k_wos_lengths[4];
const index_t Wo = c_g_n_k_wos_lengths[5];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
// This is different
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride));
return in_gemmm_gemmk_desc;
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t DiStride = a_g_n_c_wis_strides[3];
const index_t HiStride = a_g_n_c_wis_strides[4];
const index_t WiStride = a_g_n_c_wis_strides[5];
const auto CStride = I1;
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
else
{
const index_t Z = b_g_k_c_xs_lengths[3];
const index_t Y = b_g_k_c_xs_lengths[4];
const index_t X = b_g_k_c_xs_lengths[5];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t DiStride = a_g_n_c_wis_strides[3];
const index_t HiStride = a_g_n_c_wis_strides[4];
const index_t WiStride = a_g_n_c_wis_strides[5];
const auto CStride = I1;
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
}
// Weight (B) descriptor for the GKXC / GKYXC / GKZYXC layouts.
//
// Produces a packed 2-D descriptor whose first length is b_g_k_c_xs_lengths[1] (K) and
// whose second length is b_g_k_c_xs_lengths[2] (C) multiplied by the value that
// mult_accumulate_n yields over the NDimSpatial entries starting at index 3
// (presumably their product, given the initial value 1 -- confirm against the helper).
// The strides argument is accepted for signature parity with the strided overload and
// is not read.
template <typename BLayout,
          typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKXC> ||
                                      is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
                                      is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
                                  bool>::type = false>
__host__ __device__ static auto
MakeBDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */)
{
    // Spatial filter extent, accumulated over the entries that follow index 2.
    const index_t filter_extent =
        mult_accumulate_n<index_t>(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1);
    const index_t channel_count = b_g_k_c_xs_lengths[2];
    const index_t filter_count  = b_g_k_c_xs_lengths[1];

    // GEMM N = K, GEMM K = filter_extent * C, laid out packed.
    return make_naive_tensor_descriptor_packed(
        make_tuple(filter_count, filter_extent * channel_count));
}
// Weight (B) descriptor for the strided layouts G_K_X_C / G_K_YX_C / G_K_ZYX_C and
// KXGC / KYXGC / KZYXGC.
//
// A (K, YX, C) descriptor is first built with strides
//   K  -> b_g_k_c_xs_strides[1]
//   YX -> b_g_k_c_xs_strides[2 + NDimSpatial]  (the last entry of the stride array)
//   C  -> I1
// and its YX and C dimensions are then merged, giving a 2-D (K, YX * C) view.
template <
    typename BLayout,
    typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::G_K_X_C> ||
                                is_same_v<BLayout, tensor_layout::convolution::G_K_YX_C> ||
                                is_same_v<BLayout, tensor_layout::convolution::G_K_ZYX_C> ||
                                is_same_v<BLayout, tensor_layout::convolution::KXGC> ||
                                is_same_v<BLayout, tensor_layout::convolution::KYXGC> ||
                                is_same_v<BLayout, tensor_layout::convolution::KZYXGC>,
                            bool>::type = false>
__host__ __device__ static auto
MakeBDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{
    // Strides of the three logical dimensions.
    const index_t filter_stride  = b_g_k_c_xs_strides[1];
    const index_t spatial_stride = b_g_k_c_xs_strides[2 + NDimSpatial];
    const auto channel_stride    = I1;

    // Lengths of the three logical dimensions.
    const index_t filter_count  = b_g_k_c_xs_lengths[1];
    const index_t channel_count = b_g_k_c_xs_lengths[2];
    const index_t filter_extent =
        mult_accumulate_n<index_t>(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1);

    const auto strided_k_yx_c_desc = make_naive_tensor_descriptor(
        make_tuple(filter_count, filter_extent, channel_count),
        make_tuple(filter_stride, spatial_stride, channel_stride));

    // Keep K as GEMM N; fold (YX, C) into the single GEMM K dimension.
    return transform_tensor_descriptor(
        strided_k_yx_c_desc,
        make_tuple(make_pass_through_transform(filter_count),
                   make_merge_transform(make_tuple(filter_extent, channel_count))),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));
}
// Output (C) descriptor for the GNWK / GNHWK / GNDHWK layouts.
//
// Produces a packed 2-D descriptor of lengths (N * spatial, K), where N is
// c_g_n_k_wos_lengths[1], K is c_g_n_k_wos_lengths[2] and "spatial" is the value
// mult_accumulate_n yields over the NDimSpatial entries starting at index 3.
// The strides argument is not read.
template <typename CLayout,
          typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GNWK> ||
                                      is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
                                      is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
                                  bool>::type = false>
__host__ __device__ static auto
MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
{
    const index_t batch_count  = c_g_n_k_wos_lengths[1];
    const index_t filter_count = c_g_n_k_wos_lengths[2];

    // GEMM M: batch size times the accumulated output spatial lengths.
    const index_t gemm_m =
        batch_count *
        mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);

    return make_naive_tensor_descriptor_packed(make_tuple(gemm_m, filter_count));
}
// Output (C) descriptor for the strided layouts G_NW_K / G_NHW_K / G_NDHW_K and
// NWGK / NHWGK / NDHWGK.
//
// Produces a 2-D descriptor of lengths (N * spatial, K). The row stride is
// c_g_n_k_wos_strides[NDimSpatial + 2] (the last entry of the stride array) and the
// column stride is I1.
template <
    typename CLayout,
    typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
                                is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
                                is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
                                is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
                                is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
                                is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
                            bool>::type = false>
__host__ __device__ static auto
MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
{
    // Strides: innermost-spatial stride for the merged row dimension, I1 for K.
    const index_t row_stride = c_g_n_k_wos_strides[NDimSpatial + 2];
    const auto col_stride    = I1;

    const index_t batch_count  = c_g_n_k_wos_lengths[1];
    const index_t filter_count = c_g_n_k_wos_lengths[2];

    // GEMM M: batch size times the accumulated output spatial lengths.
    const index_t gemm_m =
        batch_count *
        mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);

    return make_naive_tensor_descriptor(make_tuple(gemm_m, filter_count),
                                        make_tuple(row_stride, col_stride));
}
// Output-bias descriptor for the G_K layout.
//
// Produces a 2-D descriptor of lengths (N * spatial, K). The row stride is I0 and the
// column stride is c_g_n_k_wos_strides[2], so -- assuming I0 is the compile-time
// constant zero -- every row addresses the same K entries (a broadcast along GEMM M).
template <typename CLayout,
          typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
                                  bool>::type = false>
__host__ __device__ static auto
MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
{
    const index_t bias_stride  = c_g_n_k_wos_strides[2];
    const index_t batch_count  = c_g_n_k_wos_lengths[1];
    const index_t filter_count = c_g_n_k_wos_lengths[2];

    // GEMM M: batch size times the accumulated output spatial lengths.
    const index_t gemm_m =
        batch_count *
        mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);

    return make_naive_tensor_descriptor(make_tuple(gemm_m, filter_count),
                                        make_tuple(I0, bias_stride));
}
};
// wrapper class to call member functions on TransformConvToGemm struct at runtime
// TODO: figure out a way to properly pass in layout as an argument
struct TransformConv
{
// Default constructor; the body is empty, so construction performs no work.
TransformConv() {}
// Builds the output (C) GEMM descriptor by forwarding to
// TransformConvFwdToGemm::MakeCDescriptor_M_N with a layout tag chosen from
// NDimSpatial:
//   1 -> NWGK, 2 -> NHWGK, 3 -> NDHWGK.
//
// out_lengths / out_strides are passed through unchanged to MakeCDescriptor_M_N.
//
// The function returns `auto` through an ordinary (runtime) if-chain, so every branch
// is instantiated and all return statements must deduce the same type. The chain
// previously had no final `else`, which let unsupported NDimSpatial values flow off the
// end of a non-void function (undefined behaviour). Unsupported values are now rejected
// at compile time and the last branch is unconditional.
template <index_t NDimSpatial,
          device::ConvolutionForwardSpecialization ConvForwardSpecialization>
auto
transform_func(ck::Array<index_t, NDimSpatial + 3> out_lengths,
               ck::Array<index_t, NDimSpatial + 3> out_strides,
               TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization> conv_fwd_to_gemm)
{
    static_assert(NDimSpatial >= 1 && NDimSpatial <= 3,
                  "TransformConv::transform_func supports only 1, 2 or 3 spatial dimensions");

    if(NDimSpatial == 2)
    {
        return conv_fwd_to_gemm
            .template MakeCDescriptor_M_N<tensor_layout::convolution::NHWGK>(out_lengths,
                                                                             out_strides);
    }
    else if(NDimSpatial == 3)
    {
        return conv_fwd_to_gemm
            .template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK>(out_lengths,
                                                                              out_strides);
    }
    else
    {
        // NDimSpatial == 1 is the only remaining value allowed by the static_assert.
        return conv_fwd_to_gemm
            .template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK>(out_lengths,
                                                                            out_strides);
    }
}
}; };
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -36,6 +36,8 @@ struct Array ...@@ -36,6 +36,8 @@ struct Array
return *this; return *this;
} }
// Returns a const pointer to the element at index 0 of mData.
__host__ __device__ constexpr const TData* begin() const { return &mData[0]; }
// Returns a const pointer to index NSize of mData -- presumably one past the last
// stored element, pairing with begin(); TODO confirm mData holds exactly NSize elements.
__host__ __device__ constexpr const TData* end() const { return &mData[NSize]; }
}; };
// empty Array // empty Array
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment