Commit 17daf766 authored by Jing Zhang's avatar Jing Zhang
Browse files

debugging

parent 95710403
...@@ -232,11 +232,11 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -232,11 +232,11 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
decltype(a_k0_m_k1_global_desc), decltype(a_k0_m_k1_global_desc),
decltype(a_k0_m_k1_block_desc), decltype(a_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>, Sequence<2, 0, 1>,
2, //ABlockTransferSrcVectorDim, 2, // ABlockTransferSrcVectorDim,
2, 2,
1, //ABlockTransferSrcScalarPerVector, 1, // ABlockTransferSrcScalarPerVector,
1, //ABlockTransferDstScalarPerVector_KPack, 1, // ABlockTransferDstScalarPerVector_KPack,
1, 1,
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
...@@ -259,11 +259,11 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -259,11 +259,11 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
decltype(b_k0_n_k1_global_desc), decltype(b_k0_n_k1_global_desc),
decltype(b_k0_n_k1_block_desc), decltype(b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>, Sequence<2, 0, 1>,
1, //BBlockTransferSrcVectorDim, 1, // BBlockTransferSrcVectorDim,
2, 2,
1, //BBlockTransferSrcScalarPerVector, 1, // BBlockTransferSrcScalarPerVector,
1, //BBlockTransferDstScalarPerVector_KPack, 1, // BBlockTransferDstScalarPerVector_KPack,
1, 1,
1, 1,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
...@@ -285,6 +285,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -285,6 +285,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
NPerBlock % (NPerWave * NRepeat) == 0, NPerBlock % (NPerWave * NRepeat) == 0,
"wrong!"); "wrong!");
static_assert(KPack == 1, "");
constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor( constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
a_k0_m_k1_block_desc, a_k0_m_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}), make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
......
...@@ -83,23 +83,23 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -83,23 +83,23 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerWave = 64; constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64; constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 4; constexpr index_t GemmKPack = 1;
constexpr index_t MRepeat = 1; constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1; constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, GemmKPack>; using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 2, GemmKPack>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 1; constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, GemmKPack>; using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<2, 4, GemmKPack>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<8, 32, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment