Commit 1c80f924 authored by Paul's avatar Paul
Browse files

Fix bug in b tensor

parent f5d3f160
...@@ -278,6 +278,12 @@ constexpr auto make_const_array(T x, Ts... xs) ...@@ -278,6 +278,12 @@ constexpr auto make_const_array(T x, Ts... xs)
return integral_const_array<typename T::value_type, x, xs...>{}; return integral_const_array<typename T::value_type, x, xs...>{};
} }
template <class T, T... Xs, class F>
constexpr auto unpack(integral_const_array<T, Xs...>, F f)
{
return f(_c<Xs>...);
}
template <class T, T... Xs, class F> template <class T, T... Xs, class F>
constexpr auto transform(integral_const_array<T, Xs...>, F f) constexpr auto transform(integral_const_array<T, Xs...>, F f)
{ {
......
...@@ -34,13 +34,24 @@ ...@@ -34,13 +34,24 @@
namespace migraphx { namespace migraphx {
template<class Dims>
constexpr auto ck_transposeb_dims(Dims dims)
{
return unpack(dims, [](auto k, auto n) {
return make_const_array(n, k);
});
}
template<class Tensor>
using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>{}.lens), ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
template <class G, class E, class A, class B, class... Ds> template <class G, class E, class A, class B, class... Ds>
__device__ void ck_gemm(E e, A a, B b, Ds... ds) __device__ void ck_gemm(E e, A a, B b, Ds... ds)
{ {
constexpr const G gemm{}; constexpr const G gemm{};
constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>()); constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>());
constexpr const auto b_grid_desc_n_k = gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<B>()); constexpr const auto b_grid_desc_n_k = gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<ck_transposeb<B>>());
constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>()); constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>());
constexpr const auto ds_grid_desc_m_n = constexpr const auto ds_grid_desc_m_n =
ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...); ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...);
...@@ -60,7 +71,7 @@ __device__ void ck_gemm(E e, A a, B b, Ds... ds) ...@@ -60,7 +71,7 @@ __device__ void ck_gemm(E e, A a, B b, Ds... ds)
constexpr auto e_grid_desc_mblock_mperblock_nblock_nperblock = constexpr auto e_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
MIGRAPHX_CK_STATIC_ASSERT(G::CheckValidity( MIGRAPHX_CK_STATIC_ASSERT(GridwiseGemm::CheckValidity(
a_grid_desc_m_k, b_grid_desc_n_k, ds_grid_desc_m_n, e_grid_desc_m_n, block_2_etile_map)); a_grid_desc_m_k, b_grid_desc_n_k, ds_grid_desc_m_n, e_grid_desc_m_n, block_2_etile_map));
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment