Commit 5ed0cbe4 authored by Paul

Add gemm_batcher

parent 07b8f71c
@@ -67,7 +67,7 @@ extern "C" {
 __global__ void ${kernel}(${params})
 {
     transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
-        ck_gemm<CK_DeviceGemmMultipleD<${instance}>>(xs...);
+        ck_gemm<CK_DeviceGemmMultipleD<${instance}>, ${blocks_per_batch}>(xs...);
     });
 }
@@ -230,9 +230,11 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
             gemm_type += "Padding";
         ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
+        auto blocks_per_batch = ip.get_grid_size(config);
         hip_compile_options options;
         auto block_size = ip.get_block_size();
-        auto grid_size = ip.get_grid_size(config);
+        auto grid_size = blocks_per_batch;
         options.set_launch_params(v, grid_size * block_size, block_size);
         options.inputs = inputs;
         options.output = c_shape;
@@ -246,6 +248,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         {{"instance", ip.str()},
          {"params", enum_params(inputs.size(), "void * private_p")},
          {"args", enum_params(inputs.size(), "private_p")},
+         {"blocks_per_batch", to_string(blocks_per_batch)},
          {"preamble", v.get("preamble", std::string{})},
          {"kernel", options.kernel_name}});
...
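The compiler now computes blocks_per_batch once, uses it to size the launch, and also interpolates it into the kernel source, so the per-batch block count is a compile-time constant inside the kernel. For illustration only (the kernel name, the three-argument parameter list, the elided instance, and the value 64 are assumed here, not taken from this diff), the interpolated source would come out roughly as:

extern "C" {
__global__ void ck_gemm_kernel(void * private_p0, void * private_p1, void * private_p2)
{
    transform_args(make_tensors(), rotate_last())(private_p0, private_p1, private_p2)(
        [](auto... xs) { ck_gemm<CK_DeviceGemmMultipleD<...>, 64>(xs...); });
}
}

Note that the launch itself is still sized for a single GEMM (grid_size stays equal to blocks_per_batch); the group_stride loop added to index.hpp below is what lets that fixed set of workgroups sweep every batch.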
@@ -30,7 +30,7 @@
 #include <migraphx/kernels/tensor_view.hpp>
 #include <migraphx/kernels/ck.hpp>
 #include <migraphx/kernels/ck_gemm_includes.hpp>
-#include <migraphx/kernels/print.hpp>
+#include <migraphx/kernels/gemm_batcher.hpp>
 
 namespace migraphx {
@@ -46,7 +46,7 @@ using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>
                                           ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
 
 template <class G, class E, class A, class B, class... Ds>
-__device__ void ck_gemm(E e, A a, B b, Ds... ds)
+__device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
 {
     constexpr const G gemm{};
@@ -95,5 +95,13 @@ __device__ void ck_gemm(E e, A a, B b, Ds... ds)
                       block_2_etile_map);
 }
 
+template <class G, index_int BlocksPerBatch, class... Ts>
+__device__ void ck_gemm(Ts... xs)
+{
+    gemm_batch_args(make_index(), _c<BlocksPerBatch>, xs...)([](auto... ys) {
+        ck_gemm_matrix<G>(ys...);
+    });
+}
+
 } // namespace migraphx
 #endif
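The renamed ck_gemm_matrix keeps the single-matrix CK call, while the new ck_gemm entry point routes its arguments through gemm_batch_args from the newly included gemm_batcher.hpp. That header is not among the hunks shown here, so the following is only a sketch of the shape such a helper could take, inferred from the call site above; nbatches_of and slice_batch are hypothetical names, not MIGraphX APIs:

// Sketch only, assuming gemm_batch_args returns a callable that hands the
// continuation one 2-D matrix view per batch.
template <class IndexInt, class... Ts>
__device__ auto gemm_batch_args(index idx, IndexInt blocks_per_batch, Ts... xs)
{
    return [=](auto f) {
        // Hypothetical helper: number of GEMMs, e.g. the product of all but
        // the last two output dimensions.
        auto nbatches = nbatches_of(xs...);
        idx.group_stride(nbatches * blocks_per_batch, [&](auto group) {
            auto batch = group / blocks_per_batch;
            // Hypothetical helper: the 2-D view of each tensor for this batch.
            f(slice_batch(xs, batch)...);
        });
    };
}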
@@ -129,6 +129,11 @@ struct index
         return blockDim.x;
     }
 #endif
+    constexpr auto ngroup() const
+    {
+        return nglobal() / max_nlocal();
+    }
+
     template <class N, class Stride>
     static constexpr auto max_stride_iterations(N n, Stride stride)
     {
@@ -172,6 +177,12 @@ struct index
     {
         for_stride(local, n, nlocal(), f);
     }
+
+    template <class F, class N>
+    __device__ void group_stride(N n, F f) const
+    {
+        for_stride(group, n, ngroup(), f);
+    }
 };
 
 inline __device__ __attribute__((const)) index make_index()
...
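ngroup() is the number of workgroups in the launch: with, say, 128 blocks of 256 threads, nglobal() is 128 * 256 = 32768, max_nlocal() is 256, so ngroup() is 128. group_stride then plays the same role across workgroups that local_stride plays across threads within one group. A usage sketch (the batch quantities are illustrative, not from this diff):

// Each workgroup takes every ngroup()-th unit of work, so a fixed-size
// launch covers any number of batches.
auto idx = make_index();
idx.group_stride(nbatches * blocks_per_batch, [&](auto group) {
    auto batch = group / blocks_per_batch; // which GEMM in the batch
    auto blk   = group % blocks_per_batch; // block id within that GEMM
    // ... run block `blk` of GEMM `batch`
});

Because for_stride keeps looping while the work count exceeds the stride, this holds even when nbatches * blocks_per_batch is larger than the number of launched workgroups.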