Commit 5ed0cbe4 authored by Paul
Browse files

Add gemm_batcher

parent 07b8f71c
......@@ -67,7 +67,7 @@ extern "C" {
__global__ void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
ck_gemm<CK_DeviceGemmMultipleD<${instance}>>(xs...);
ck_gemm<CK_DeviceGemmMultipleD<${instance}>, ${blocks_per_batch}>(xs...);
});
}
......@@ -230,9 +230,11 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
gemm_type += "Padding";
ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
auto blocks_per_batch = ip.get_grid_size(config);
hip_compile_options options;
auto block_size = ip.get_block_size();
auto grid_size = ip.get_grid_size(config);
auto grid_size = blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs;
options.output = c_shape;
......@@ -246,6 +248,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{{"instance", ip.str()},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"blocks_per_batch", to_string(blocks_per_batch)},
{"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}});
......
......@@ -30,7 +30,7 @@
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/ck.hpp>
#include <migraphx/kernels/ck_gemm_includes.hpp>
#include <migraphx/kernels/print.hpp>
#include <migraphx/kernels/gemm_batcher.hpp>
namespace migraphx {
......@@ -46,7 +46,7 @@ using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>
ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
template <class G, class E, class A, class B, class... Ds>
__device__ void ck_gemm(E e, A a, B b, Ds... ds)
__device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
{
constexpr const G gemm{};
......@@ -95,5 +95,13 @@ __device__ void ck_gemm(E e, A a, B b, Ds... ds)
block_2_etile_map);
}
// Batched GEMM entry point.
// G is the GEMM type handed on to ck_gemm_matrix; BlocksPerBatch is a
// compile-time value passed along as _c<BlocksPerBatch>.
// The arguments xs are given to gemm_batch_args together with the current
// index, and whatever argument pack it supplies to the callback is forwarded
// unchanged to ck_gemm_matrix<G>.
// NOTE(review): gemm_batch_args presumably narrows each tensor to the batch
// owned by the calling block -- confirm in gemm_batcher.hpp.
template <class G, index_int BlocksPerBatch, class... Ts>
__device__ void ck_gemm(Ts... xs)
{
    auto run_matrix = [](auto... batch_xs) { ck_gemm_matrix<G>(batch_xs...); };
    gemm_batch_args(make_index(), _c<BlocksPerBatch>, xs...)(run_matrix);
}
} // namespace migraphx
#endif
......@@ -129,6 +129,11 @@ struct index
return blockDim.x;
}
#endif
// Group count: the result of nglobal() divided by the result of max_nlocal().
// The return type is whatever that division yields for the operand types.
constexpr auto ngroup() const
{
    const auto total     = nglobal();
    const auto per_group = max_nlocal();
    return total / per_group;
}
template <class N, class Stride>
static constexpr auto max_stride_iterations(N n, Stride stride)
{
......@@ -172,6 +177,12 @@ struct index
{
for_stride(local, n, nlocal(), f);
}
// Forwards to for_stride with `group` as the first argument, followed by n,
// the value of ngroup(), and the callback f.
// NOTE(review): by analogy with the sibling call that passes local/nlocal(),
// this presumably visits group, group + ngroup(), ... below n -- confirm
// against for_stride.
template <class F, class N>
__device__ void group_stride(N n, F f) const
{
    const auto step = ngroup();
    for_stride(group, n, step, f);
}
};
inline __device__ __attribute__((const)) index make_index()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment