Commit e734215e authored by Paul's avatar Paul
Browse files

Update ck version

parent e0724664
...@@ -28,4 +28,4 @@ ROCmSoftwarePlatform/half@rocm-5.4.2 ...@@ -28,4 +28,4 @@ ROCmSoftwarePlatform/half@rocm-5.4.2
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@f89f3440b2c98f05ab61c1290750f3bf5bff7e6c -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCmSoftwarePlatform/composable_kernel@506798ded0673b7473c0e22e32fb8ad1b6c29489 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
#ifndef MIGRAPHX_GUARD_KERNELS_GEMM_BATCHER_HPP
#define MIGRAPHX_GUARD_KERNELS_GEMM_BATCHER_HPP
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/index.hpp>
namespace migraphx {
template <class Tensor>
constexpr auto gemm_get_batches()
{
constexpr auto lens = get_shape_c<Tensor>{}.lens;
constexpr auto strides = get_shape_c<Tensor>{}.strides;
constexpr auto new_lens = sequence(
lens.size() - _c<2>, [&](auto... is) { return make_const_array(_c<lens[is]>...); });
constexpr auto new_strides = sequence(
strides.size() - _c<2>, [&](auto... is) { return make_const_array(_c<strides[is]>...); });
return make_shape(new_lens, new_strides);
}
template <class Tensor>
constexpr auto gemm_get_matrix()
{
constexpr auto lens = get_shape_c<Tensor>{}.lens;
constexpr auto strides = get_shape_c<Tensor>{}.strides;
constexpr auto m = lens.size() - _c<2>;
constexpr auto n = lens.size() - _c<1>;
constexpr auto new_lens = make_const_array(_c<lens[m]>, _c<lens[n]>);
constexpr auto new_strides = make_const_array(_c<strides[m]>, _c<strides[n]>);
return make_shape(new_lens, new_strides);
}
template <class Tensor, class T>
constexpr auto gemm_batch_slice(Tensor t, T i)
{
constexpr auto batch = gemm_get_batches<Tensor>();
constexpr auto matrix = gemm_get_matrix<Tensor>();
return make_tensor_view(t.data() + batch.index(i), matrix);
}
template <class BlocksPerBatch, class T, class... Ts>
constexpr auto gemm_batch_args(index idx, BlocksPerBatch bpb, T x, Ts... xs)
{
return [=](auto f) {
// All tensors should have the same rank
static_assert(
(true and ... and (get_shape_c<T>{}.lens.size() == get_shape_c<Ts>{}.lens.size())));
if constexpr(get_shape_c<T>{}.lens.size() > 2)
{
// Get the first batch since all batches should have the same number of elements
constexpr auto batch = gemm_get_batches<T>();
static_assert(
(true and ... and (batch.elements() == gemm_get_batches<Ts>().elements())));
idx.group_stride(bpb * batch.elements(), [&](auto gidx) {
const auto batch_idx = gidx / bpb;
f(gemm_batch_slice(x, batch_idx), gemm_batch_slice(xs, batch_idx)...);
});
}
else
{
f(x, xs...);
}
};
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_GEMM_BATCHER_HPP
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment