Commit 711aaed9 authored by Paul's avatar Paul
Browse files

Compile fixes

parent d8a2ed68
...@@ -7,25 +7,25 @@ ...@@ -7,25 +7,25 @@
namespace migraphx { namespace migraphx {
template<class Shape> template<class Tensor>
constexpr auto gemm_get_batches() constexpr auto gemm_get_batches()
{ {
constexpr auto lens = Shape{}.lens; constexpr auto lens = get_shape_c<Tensor>{}.lens;
constexpr auto strides = Shape{}.strides; constexpr auto strides = get_shape_c<Tensor>{}.strides;
constexpr auto new_lens = sequence(lens.size() - _c<2>, [](auto... is) { constexpr auto new_lens = sequence(lens.size() - _c<2>, [&](auto... is) {
return make_const_array(_c<lens[is]>...); return make_const_array(_c<lens[is]>...);
}); });
constexpr auto new_strides = sequence(strides.size() - _c<2>, [](auto... is) { constexpr auto new_strides = sequence(strides.size() - _c<2>, [&](auto... is) {
return make_const_array(_c<strides[is]>...); return make_const_array(_c<strides[is]>...);
}); });
return make_shape(new_lens, new_strides); return make_shape(new_lens, new_strides);
} }
template<class Shape> template<class Tensor>
constexpr auto gemm_get_matrix() constexpr auto gemm_get_matrix()
{ {
constexpr auto lens = Shape{}.lens; constexpr auto lens = get_shape_c<Tensor>{}.lens;
constexpr auto strides = Shape{}.strides; constexpr auto strides = get_shape_c<Tensor>{}.strides;
constexpr auto m = lens.size() - _c<2>; constexpr auto m = lens.size() - _c<2>;
constexpr auto n = lens.size() - _c<1>; constexpr auto n = lens.size() - _c<1>;
constexpr auto new_lens = make_const_array(_c<lens[m]>, _c<lens[n]>); constexpr auto new_lens = make_const_array(_c<lens[m]>, _c<lens[n]>);
...@@ -38,7 +38,7 @@ constexpr auto gemm_batch_slice(Tensor t, T i) ...@@ -38,7 +38,7 @@ constexpr auto gemm_batch_slice(Tensor t, T i)
{ {
constexpr auto batch = gemm_get_batches<Tensor>(); constexpr auto batch = gemm_get_batches<Tensor>();
constexpr auto matrix = gemm_get_matrix<Tensor>(); constexpr auto matrix = gemm_get_matrix<Tensor>();
return make_tensor_view(t.data() + matrix.index(i), matrix); return make_tensor_view(t.data() + batch.index(i), matrix);
} }
template<class BlocksPerBatch, class T, class... Ts> template<class BlocksPerBatch, class T, class... Ts>
...@@ -53,7 +53,7 @@ constexpr auto gemm_batch_args(index idx, BlocksPerBatch bpb, T x, Ts... xs) ...@@ -53,7 +53,7 @@ constexpr auto gemm_batch_args(index idx, BlocksPerBatch bpb, T x, Ts... xs)
constexpr auto batch = gemm_get_batches<T>(); constexpr auto batch = gemm_get_batches<T>();
static_assert((true and ... and (batch.elements() == gemm_get_batches<Ts>().elements()))); static_assert((true and ... and (batch.elements() == gemm_get_batches<Ts>().elements())));
idx.group_stride(bpb * batch.elements(), [&](auto gidx) { idx.group_stride(bpb * batch.elements(), [&](auto gidx) {
constexpr auto batch_idx = gidx / bpb; const auto batch_idx = gidx / bpb;
f(gemm_batch_slice(x, batch_idx), gemm_batch_slice(xs, batch_idx)...); f(gemm_batch_slice(x, batch_idx), gemm_batch_slice(xs, batch_idx)...);
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment