Commit d43d9e1f authored by Paul's avatar Paul
Browse files

Format

parent 1c80f924
......@@ -34,16 +34,15 @@
namespace migraphx {
template<class Dims>
template <class Dims>
constexpr auto ck_transposeb_dims(Dims dims)
{
return unpack(dims, [](auto k, auto n) {
return make_const_array(n, k);
});
return unpack(dims, [](auto k, auto n) { return make_const_array(n, k); });
}
template<class Tensor>
using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>{}.lens), ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
template <class Tensor>
using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>{}.lens),
ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
template <class G, class E, class A, class B, class... Ds>
__device__ void ck_gemm(E e, A a, B b, Ds... ds)
......@@ -51,7 +50,8 @@ __device__ void ck_gemm(E e, A a, B b, Ds... ds)
constexpr const G gemm{};
constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>());
constexpr const auto b_grid_desc_n_k = gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<ck_transposeb<B>>());
constexpr const auto b_grid_desc_n_k =
gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<ck_transposeb<B>>());
constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>());
constexpr const auto ds_grid_desc_m_n =
ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment