Commit 9bdad55b authored by Jing Zhang

debugging

parent 7084b152
@@ -111,12 +111,11 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
     const auto GemmM0 = GemmM / Number<GemmM1>{};
     const auto GemmN0 = GemmN / Number<GemmN1>{};
-    const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
+    const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor(
         out_gemmm_gemmn_global_desc,
-        make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
-                   make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
+        make_tuple(make_unmerge_transform(make_tuple(4, 2, 4)), make_pass_through_transform(N)),
         make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
+        make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
     // out_gemm_block_cluster_desc
     const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
@@ -141,23 +140,23 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
     // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
     // tensor hack for NKHW format
-    constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
+    constexpr auto out_m0_m1_m2_n_global_iterator_hacks =
         make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                               Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 1, 0, 0>{},
-                              Sequence<0, 0, 1, 0, 0>{}),
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{}),
                    make_tuple(Sequence<0, 0, 0, 0, 0>{},
                               Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 2, 0, 0>{},
-                              Sequence<0, 0, 2, 0, 0>{}));
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{}));
     return make_tuple(wei_gemmk_gemmm_global_desc,
                       in_gemmk_gemmn_global_desc,
-                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
+                      out_m0_m1_m2_n_global_desc,
                       out_gemm_block_cluster_desc,
                       wei_gemmk_gemmm_global_iterator_hacks,
                       in_gemmk_gemmn_global_iterator_hacks,
-                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
+                      out_m0_m1_m2_n_global_iterator_hacks,
                       wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
                       in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
 }
...
@@ -11,35 +11,24 @@ namespace ck {
 template <index_t BlockSize,
           typename FloatA,
           typename FloatB,
-          typename FloatC,
           class ABlockDesc,
           class BBlockDesc,
-          index_t GemmMPerWave,
-          index_t GemmNPerWave,
-          index_t GemmKPerWave,
-          index_t GemmMWaves,
-          index_t GemmNWaves,
-          index_t GemmDataPerReadA, // \todo unused parameter, remove
-          index_t GemmDataPerReadB // \todo unused parameter, remove
-          >
+          index_t MPerWave,
+          index_t NPerWave,
+          index_t KPerWave,
+          index_t MWaves,
+          index_t NWaves>
 struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
 {
-    struct MatrixIndex
-    {
-        index_t row;
-        index_t col;
-    };
+    using CIndex = MultiIndex<2>;
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};
-    static constexpr auto XdlopsGemm =
-        XdlopsGemm_t<float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{};
-    index_t mMyWaveOffsetA;
-    index_t mMyWaveOffsetB;
+    static constexpr auto XdlopsGemm = XdlopsGemm_t<float, MPerWave, NPerWave, KPerWave>{};
     static constexpr index_t WaveSize = 64;
@@ -55,7 +44,45 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
         return XdlopsGemm.GetOutputLayout().GetBlkSize();
     }
+    __device__ static auto CalculateAThreadOriginDataIndex()
+    {
+        const index_t thread_id = get_thread_local_1d_id();
+        const index_t waveId = thread_id / WaveSize;
+        const index_t laneId = thread_id % WaveSize;
+        const index_t waveId_m = waveId / NWaves;
+        const index_t waveId_n = waveId % NWaves;
+        return make_tuple(0, waveId_m * MPerWave + laneId);
+    }
+    __device__ static auto CalculateBThreadOriginDataIndex()
+    {
+        const index_t thread_id = get_thread_local_1d_id();
+        const index_t waveId = thread_id / WaveSize;
+        const index_t laneId = thread_id % WaveSize;
+        const index_t waveId_m = waveId / NWaves;
+        const index_t waveId_n = waveId % NWaves;
+        return make_tuple(0, waveId_n * NPerWave + laneId);
+    }
+    template <index_t AStride = MPerWave, index_t BStride = NPerWave>
+    __device__ static CIndex CalculateCThreadOriginDataIndex(index_t blk_i)
+    {
+        const index_t waveId = get_thread_local_1d_id() / WaveSize;
+        const auto thread_mtx_on_blk = XdlopsGemm.GetBeginOfThreadBlk(blk_i);
+        const index_t row = (waveId / NWaves) * AStride + thread_mtx_on_blk.row;
+        const index_t col = (waveId % NWaves) * BStride + thread_mtx_on_blk.col;
+        return CIndex{row, col};
+    }
     __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1()
+        : a_thread_copy_{CalculateAThreadOriginDataIndex()},
+          b_thread_copy_{CalculateBThreadOriginDataIndex()}
     {
         static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");
@@ -66,18 +93,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
         constexpr index_t M = ABlockDesc{}.GetLength(I1); // A is transposed
         constexpr index_t N = BBlockDesc{}.GetLength(I1);
-        static_assert(GemmMPerWave * GemmMWaves == M, "GemmMWaves * GemmMPerWave != M");
-        static_assert(GemmNPerWave * GemmNWaves == N, "GemmNWaves * GemmNPerWave != N");
-        static_assert(BlockSize == GemmMWaves * GemmNWaves * WaveSize,
-                      "BlockSize != GemmMWaves * GemmNWaves * WaveSize\n");
-        const index_t waveId = get_thread_local_1d_id() / WaveSize;
-        const index_t waveId_m = waveId / GemmNWaves;
-        const index_t waveId_n = waveId % GemmNWaves;
-        mMyWaveOffsetA = waveId_m * GemmMPerWave;
-        mMyWaveOffsetB = waveId_n * GemmNPerWave;
+        static_assert(MPerWave * MWaves == M, "GemmMWaves * MPerWave != M");
+        static_assert(NPerWave * NWaves == N, "GemmNWaves * NPerWave != N");
+        static_assert(BlockSize == MWaves * NWaves * WaveSize,
+                      "BlockSize != MWaves * NWaves * WaveSize\n");
     }
     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
@@ -90,73 +110,41 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
         auto b_thread_buf =
             make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());
-#if 0
-        constexpr auto threadwise_gemm =
-            ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
-                                                FloatB,
-                                                FloatC,
-                                                decltype(a_thread_desc_),
-                                                decltype(b_thread_desc_),
-                                                CThreadDesc,
-                                                Sequence<GemmKPerWave>,
-                                                Sequence<M0_, M1PerThread>,
-                                                Sequence<N0_, N1PerThread>>{};
-        constexpr index_t K = ABlockDesc{}.GetLength(I0);
-        static_for<0, K, GemmKPerWave>{}([&](auto k) {
+        constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
+        static_for<0, KPerBlock, KPerWave>{}([&](auto k) {
             a_thread_copy_.Run(ABlockDesc{},
-                               make_tuple(k, I0, I0),
+                               make_tuple(k, I0),
                                a_block_buf,
                                a_thread_desc_,
-                               make_tuple(I0, I0, I0),
+                               make_tuple(I0, I0),
                                a_thread_buf);
             b_thread_copy_.Run(BBlockDesc{},
-                               make_tuple(k, I0, I0),
+                               make_tuple(k, I0),
                                b_block_buf,
                                b_thread_desc_,
-                               make_tuple(I0, I0, I0),
+                               make_tuple(I0, I0),
                                b_thread_buf);
-            threadwise_gemm.Run(a_thread_buf,
-                                make_tuple(I0, I0, I0),
-                                b_thread_buf,
-                                make_tuple(I0, I0, I0),
-                                c_thread_buf,
-                                make_tuple(I0, I0, I0, I0));
+            XdlopsGemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf);
         });
-#endif
-    }
-    template <index_t AStride = GemmMPerWave, index_t BStride = GemmNPerWave>
-    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t i)
-    {
-        const index_t waveId = get_thread_local_1d_id() / WaveSize;
-        const auto thread_mtx_on_blk = XdlopsGemm.GetBeginOfThreadBlk(i);
-        const index_t col = (waveId % GemmNWaves) * BStride + thread_mtx_on_blk.col;
-        const index_t row = (waveId / GemmNWaves) * AStride + thread_mtx_on_blk.row;
-        return MatrixIndex{row, col};
-    }
     }
     private:
     // A[K, M]
-    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<GemmKPerWave>{}, Number<1>{}));
+    static constexpr auto a_thread_desc_ =
+        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<KPerWave>{}, Number<1>{}));
     // B[K, N]
-    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<GemmKPerWave>{}, Number<1>{}));
+    static constexpr auto b_thread_desc_ =
+        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<KPerWave>{}, Number<1>{}));
     using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
             FloatA,
             ABlockDesc,
             decltype(a_thread_desc_),
-            Sequence<GemmKPerWave, 1>,
+            Sequence<KPerWave, 1>,
             Sequence<0, 1>,
             1,
             1,
@@ -166,14 +154,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
             FloatB,
             BBlockDesc,
             decltype(b_thread_desc_),
-            Sequence<GemmKPerWave, 1>,
+            Sequence<KPerWave, 1>,
             Sequence<0, 1>,
             1,
             1,
             1>;
-    // AThreadCopy a_thread_copy_;
-    // BThreadCopy b_thread_copy_;
+    AThreadCopy a_thread_copy_;
+    BThreadCopy b_thread_copy_;
 };
 } // namespace ck
...
@@ -32,7 +32,7 @@ __global__ void
     FloatC* __restrict__ p_c_global,
     const AGlobalDesc a_k_m_global_desc,
     const BGlobalDesc b_k_n_global_desc,
-    const CGlobalDesc c_m0_m1_n0_n1_global_desc,
+    const CGlobalDesc c_m0_m1_m2_n_global_desc,
     const CBlockClusterDesc c_block_cluster_desc)
 {
     GridwiseGemm::Run(p_a_global,
@@ -40,7 +40,7 @@ __global__ void
                       p_c_global,
                       a_k_m_global_desc,
                       b_k_n_global_desc,
-                      c_m0_m1_n0_n1_global_desc,
+                      c_m0_m1_m2_n_global_desc,
                       c_block_cluster_desc,
                       integral_constant<bool, HasMainKBlockLoop>{},
                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
@@ -68,7 +68,7 @@ __global__ void
     FloatC* __restrict__ p_c_global,
     const void __CONSTANT__* p_a_k_m_global_desc,
     const void __CONSTANT__* p_b_k_n_global_desc,
-    const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
+    const void __CONSTANT__* p_c_m0_m1_m2_n_global_desc,
     const void __CONSTANT__* p_c_block_cluster_desc)
 {
     // first cast void __CONSTANT__ void* to void*
@@ -78,8 +78,8 @@ __global__ void
         *reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k_m_global_desc);
     const auto b_k_n_global_desc =
         *reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k_n_global_desc);
-    const auto c_m0_m1_n0_n1_global_desc =
-        *reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_n0_n1_global_desc);
+    const auto c_m0_m1_m2_n_global_desc =
+        *reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_m2_n_global_desc);
     const auto c_block_cluster_desc =
         *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
@@ -89,7 +89,7 @@ __global__ void
                       p_c_global,
                       a_k_m_global_desc,
                       b_k_n_global_desc,
-                      c_m0_m1_n0_n1_global_desc,
+                      c_m0_m1_m2_n_global_desc,
                       c_block_cluster_desc,
                       integral_constant<bool, HasMainKBlockLoop>{},
                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
@@ -174,7 +174,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                                FloatC* __restrict__ p_c_global,
                                const AGlobalDesc& a_k_m_global_desc,
                                const BGlobalDesc& b_k_n_global_desc,
-                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                               const CGlobalDesc& c_m0_m1_m2_n_global_desc,
                                const CBlockClusterDesc& c_block_cluster_desc,
                                FloatAB* __restrict__ p_shared_block,
                                integral_constant<bool, HasMainKBlockLoop>,
@@ -190,7 +190,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
             p_b_global, b_k_n_global_desc.GetElementSpaceSize());
         auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
-            p_c_global, c_m0_m1_n0_n1_global_desc.GetElementSpaceSize());
+            p_c_global, c_m0_m1_m2_n_global_desc.GetElementSpaceSize());
         const auto K = a_k_m_global_desc.GetLength(I0);
         const auto M = a_k_m_global_desc.GetLength(I1);
@@ -309,23 +309,20 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
-        constexpr auto c_m0_m1_n0_n1_thread_desc =
-            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
-                Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
+        // constexpr auto c_m0_m1_n0_n1_thread_desc =
+        //     make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+        //         Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
         const auto blockwise_gemm = BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
             FloatAB,
            FloatAB,
-            FloatAcc,
            decltype(a_k_m_block_desc),
            decltype(b_k_n_block_desc),
            64, // MPerWave,
            64, // NPerWave,
-            KPerBlock,
-            2, // MWaves,
-            2, // NWaves,
-            1, // GemmDataPerReadM,
-            1 // GemmDataPerReadN
+            1, // KPerWave,
+            1, // MWaves,
+            1 // NWaves,
            >{};
         // LDS allocation for A and B: be careful of alignment
@@ -339,13 +336,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
         // register allocation for output
-        auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
-            c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
-        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
-                                           decltype(c_m0_m1_n0_n1_thread_desc),
-                                           Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
-            .Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
+        // auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
+        //     c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
+        // ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
+        //                                    decltype(c_m0_m1_n0_n1_thread_desc),
+        //                                    Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
+        //     .Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
+        vector_type<float, 64> c_thread_buf;
         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
@@ -474,43 +473,70 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
         }
-#if 0
         // output: register to global memory
         {
-            constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
-            constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};
-            // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
-            constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
-            const auto c_thread_data_idx_on_block =
-                blockwise_gemm.CalculateCThreadOriginDataIndex(get_thread_local_1d_id());
-            ThreadwiseDynamicTensorSliceTransfer_v1r3<
-                FloatAcc,
-                FloatC,
-                decltype(c_m0_m1_n0_n1_thread_desc),
-                decltype(c_m0_m1_n0_n1_global_desc),
-                Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
-                CThreadTransferSrcDstAccessOrder,
-                CThreadTransferSrcDstVectorDim,
-                CThreadTransferDstScalarPerVector,
-                CGlobalMemoryDataOperation,
-                1,
-                true>{
-                c_m0_m1_n0_n1_global_desc,
-                make_multi_index(m_block_data_idx_on_global / M1 + c_thread_data_idx_on_block[I0],
-                                 c_thread_data_idx_on_block[I1],
-                                 n_block_data_idx_on_global / N1 + c_thread_data_idx_on_block[I2],
-                                 c_thread_data_idx_on_block[I3])}
-                .Run(c_m0_m1_n0_n1_thread_desc,
-                     make_tuple(I0, I0, I0, I0),
-                     c_thread_buf,
-                     c_m0_m1_n0_n1_global_desc,
-                     c_global_buf,
-                     c_m0_m1_n0_n1_global_tensor_iterator_hacks);
+            StaticBuffer<AddressSpace::Vgpr, float, 64> c_thread_buf_;
+            static_for<0, 64, 1>{}(
+                [&](auto i) { c_thread_buf_(i) = c_thread_buf.template AsType<float>()[i]; });
+            constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
+            constexpr index_t K0 = OutputLayout.M1();
+            constexpr index_t K1 = OutputLayout.N1();
+            constexpr index_t K2 = OutputLayout.M0();
+            static_assert(K0 == 4 && K1 == 2 && K2 == 4, "");
+            constexpr auto c_m0_m1_m2_n_thread_desc =
+                make_dynamic_naive_tensor_descriptor_packed_v2(
+                    make_tuple(Number<K0>{}, Number<1>{}, Number<K2>{}, Number<1>{}));
+            constexpr index_t BlkSize = OutputLayout.GetBlkSize();
+            constexpr index_t NumBlks = OutputLayout.GetNumBlks();
+            static_assert(BlkSize == 16 && NumBlks == 4, "");
+            // force unrolling the output loop to get ride of scratches
+            for(index_t i = 0; i < NumBlks; ++i)
+            {
+                // calculate origin of thread output tensor on global memory
+                // blockwise GEMM c matrix starting index
+                const auto c_thread_mtx_on_block =
+                    blockwise_gemm.CalculateCThreadOriginDataIndex(i);
+                const index_t k_thread_data_on_global =
+                    m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
+                const index_t b_thread_data_on_global =
+                    n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
+                constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
+                ThreadwiseDynamicTensorSliceTransfer_v1r3<
+                    FloatAcc,
+                    FloatC,
+                    decltype(c_m0_m1_m2_n_thread_desc),
+                    decltype(c_m0_m1_m2_n_global_desc),
+                    Sequence<K0, 1, K2, 1>,
+                    Sequence<0, 1, 2, 3>, // CThreadTransferSrcDstAccessOrder,
+                    3,                    // CThreadTransferSrcDstVectorDim,
+                    1,                    // CThreadTransferDstScalarPerVector,
+                    CGlobalMemoryDataOperation,
+                    1,
+                    true>{c_m0_m1_m2_n_global_desc,
+                          make_multi_index(k_thread_data_on_global / (K2 * K1),
+                                           k_thread_data_on_global % (K2 * K1) / K2,
+                                           k_thread_data_on_global % K2,
+                                           b_thread_data_on_global)}
+                    .Run(c_m0_m1_m2_n_thread_desc,
+                         make_tuple(I0, I0, I0, I0),
+                         c_thread_buf_,
+                         c_m0_m1_m2_n_global_desc,
+                         c_global_buf,
+                         c_m0_m1_n0_n1_global_tensor_iterator_hacks);
+            }
         }
-#endif
     }
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
@@ -519,7 +545,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                                FloatC* __restrict__ p_c_global,
                                const AGlobalDesc& a_k_m_global_desc,
                                const BGlobalDesc& b_k_n_global_desc,
-                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                               const CGlobalDesc& c_m0_m1_m2_n_global_desc,
                                const CBlockClusterDesc& c_block_cluster_desc,
                                integral_constant<bool, HasMainKBlockLoop>,
                                integral_constant<bool, HasDoubleTailKBlockLoop>)
@@ -533,7 +559,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                    p_c_global,
                    a_k_m_global_desc,
                    b_k_n_global_desc,
-                   c_m0_m1_n0_n1_global_desc,
+                   c_m0_m1_m2_n_global_desc,
                    c_block_cluster_desc,
                    p_shared_block,
                    integral_constant<bool, HasMainKBlockLoop>{},
...
@@ -50,20 +50,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
     static constexpr index_t cycles = 64;
     static constexpr index_t k_base = 1;
-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t AStride,
-              index_t BStride,
-              class FloatA,
-              class FloatB,
-              class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        const auto p_a = reinterpret_cast<const float*>(a);
-        const auto p_b = reinterpret_cast<const float*>(b);
-        return intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, AStride, BStride>::run(
-            p_a, p_b, reg_c);
+        return intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops>::run(a, b, reg_c);
     }
 };
@@ -557,11 +547,7 @@ struct xdlops_info
     static constexpr auto OutputVecType = OutputVecType_{};
 };
-template <class data_type,
-          index_t GemmMPerWave,
-          index_t GemmNPerWave,
-          index_t GemmDataPerReadA,
-          index_t GemmDataPerReadB>
+template <class data_type, index_t MPerWave, index_t NPerWave, index_t KPerWave>
 struct XdlopsGemm_t
 {
     struct MatrixIndex
@@ -585,8 +571,6 @@ struct XdlopsGemm_t
                       MPerXdlops == 64,
                   "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
-    static_assert(GemmDataPerReadA == 1 && GemmDataPerReadB == 1, "GemmDataPerReadA/B != 1");
     static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk");
     static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m,
                   "m != num_input_blks * num_regs_blk");
@@ -604,187 +588,17 @@ struct XdlopsGemm_t
         return MPerXdlops * NPerXdlops / mfma_type.wave_size;
     }
-#if CK_USE_AMD_XDLOPS_EMULATE
-    // emulate xdlops
-    template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
-    __device__ FloatC XdlopsEmulate(const FloatA* const __restrict__ p_a_wave,
-                                    const FloatB* const __restrict__ p_b_wave,
-                                    FloatC p_c_thread) const
-    {
-        const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
-        const index_t blk_id = laneId / mfma_type.num_threads_blk;
-        const index_t blk_td = laneId % mfma_type.num_threads_blk;
-        // K reduction
-        static_if<IsKReduction>{}([&](auto) {
-            for(index_t k = 0; k < K; k += mfma_type.num_input_blks)
-            {
-                for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
-                {
-                    index_t a_off = (k + n) * M;
-                    index_t b_off = (k + n) * N;
-                    index_t c_off = 0;
-                    for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
-                    {
-                        index_t aindex = m % mfma_type.group_size + blk_id * mfma_type.group_size +
-                                         m / mfma_type.group_size *
-                                             (mfma_type.group_size * mfma_type.num_input_blks);
-                        index_t bindex = blk_td;
-                        p_c_thread.n[m + c_off] += inner_product_with_conversion<float>{}(
-                            p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
-                    }
-                }
-            }
-        })
-            .Else([&](auto) {
-                static_if<IsABroadcast>{}([&](auto) {
-                    for(index_t m_i = 0; m_i < MRepeats; ++m_i)
-                    {
-                        for(index_t n_i = 0; n_i < NRepeats; ++n_i)
-                        {
-                            // ABroadcast
-                            for(index_t k = 0; k < K; ++k)
-                            {
-                                for(index_t b = 0; b < MPerXdlops / mfma_type.m; ++b)
-                                {
-                                    for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
-                                    {
-                                        index_t a_off = k * M + b * mfma_type.m + MPerXdlops * m_i;
-                                        index_t b_off = k * N + n * mfma_type.num_threads_blk +
-                                                        NPerXdlops * n_i;
-                                        index_t c_off =
-                                            n * mfma_type.num_regs_blk +
-                                            b * mfma_type.num_regs_xdlops +
-                                            (NRepeats * m_i + n_i) * GetRegSizePerXdlops();
-                                        for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
-                                        {
-                                            index_t aindex = m % mfma_type.group_size +
-                                                             blk_id * mfma_type.group_size +
-                                                             m / mfma_type.group_size *
-                                                                 (mfma_type.group_size *
-                                                                  mfma_type.num_input_blks);
-                                            index_t bindex = blk_td;
-                                            p_c_thread.n[m + c_off] +=
-                                                inner_product_with_conversion<float>{}(
-                                                    p_a_wave[aindex + a_off],
-                                                    p_b_wave[bindex + b_off]);
-                                        }
-                                    }
-                                }
-                            }
-                        }
-                    }
-                })
-                    .Else([&](auto) {
-                        // BBroadcast
-                        for(index_t k = 0; k < K; ++k)
-                        {
-                            for(index_t b = 0; b < NPerXdlops / mfma_type.n; ++b)
-                            {
-                                for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
-                                {
-                                    index_t a_off = k * M + n * mfma_type.m;
-                                    index_t b_off = k * N + b * mfma_type.n;
-                                    index_t c_off =
-                                        n * mfma_type.num_regs_blk + b * mfma_type.num_regs_xdlops;
-                                    for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
-                                    {
-                                        index_t aindex =
-                                            m % mfma_type.group_size +
-                                            blk_id * mfma_type.group_size +
-                                            m / mfma_type.group_size *
-                                                (mfma_type.group_size * mfma_type.num_input_blks);
-                                        index_t bindex = blk_td;
-                                        p_c_thread.n[m + c_off] +=
-                                            inner_product_with_conversion<float>{}(
-                                                p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
-                                    }
-                                }
-                            }
-                        }
-                    });
-            });
-        return p_c_thread;
-    }
-#endif
-    template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
-    __device__ FloatC Run(const FloatA* const __restrict__ p_a_wave,
-                          const FloatB* const __restrict__ p_b_wave,
-                          FloatC p_c_thread) const
+    template <class FloatA, class FloatB, class FloatC>
+    __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
     {
-        static_assert(is_same<FloatA, FloatB>::value, "FloatA != FloatB");
         static_assert(is_same<data_type, float>::value || is_same<data_type, half_t>::value ||
                           is_same<data_type, ushort>::value,
                       "base data_type must be float, half, ushort!");
-#if CK_USE_AMD_XDLOPS_EMULATE
-        p_c_thread = XdlopsEmulate<M, N, K>(p_a_wave, p_b_wave, p_c_thread);
-#else
-        constexpr index_t KPACT = sizeof(FloatA) / sizeof(data_type);
-        static_assert(KPACT % mfma_type.k_base == 0, "wrong! KPACT is not supported by mfma");
-        constexpr index_t KRepeats = KPACT / mfma_type.k_base;
-        static_assert(!IsKReduction || K % mfma_type.num_input_blks == 0,
-                      "K cannot divided by mfma_type.num_input_blks!");
-        constexpr index_t KPerThread = IsKReduction ? K / mfma_type.num_input_blks : K;
-        static_assert(!IsKReduction || (MRepeats == 1 && NRepeats == 1),
-                      "KReduction does not support M/N Repeats!");
-        FloatA a[KPerThread * MRepeats];
-        FloatB b[KPerThread * NRepeats];
-        auto pa = reinterpret_cast<const data_type*>(&a);
-        auto pb = reinterpret_cast<const data_type*>(&b);
-        constexpr index_t AStride = KPerThread * KRepeats;
-        constexpr index_t BStride = KPerThread * KRepeats;
-        const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
-        static_if<!IsKReduction>{}([&](auto) {
-            for(index_t m_i = 0; m_i < MRepeats; ++m_i)
-                for(index_t k_i = 0; k_i < KPerThread; ++k_i)
-                    a[k_i + m_i * KPerThread] = p_a_wave[k_i * M + laneId + MPerXdlops * m_i];
-            for(index_t n_i = 0; n_i < NRepeats; ++n_i)
-                for(index_t k_i = 0; k_i < KPerThread; ++k_i)
-                    b[k_i + n_i * KPerThread] = p_b_wave[k_i * N + laneId + NPerXdlops * n_i];
-        })
-            .Else([&](auto) {
-                const index_t blk_id = laneId / mfma_type.num_threads_blk;
-                const index_t blk_td = laneId % mfma_type.num_threads_blk;
-                for(index_t k_i = 0; k_i < KPerThread; ++k_i)
-                {
-                    a[k_i] = p_a_wave[(k_i * mfma_type.num_input_blks + blk_id) * M + blk_td];
-                    b[k_i] = p_b_wave[(k_i * mfma_type.num_input_blks + blk_id) * N + blk_td];
-                }
-            });
-#if CK_WORKAROUND_SWDEV_229564
-#pragma unroll
-#endif
-        for(index_t k_i = 0; k_i < KPerThread * KRepeats; ++k_i)
-        {
-            p_c_thread =
-                mfma_type
-                    .template run<MPerXdlops * MRepeats, NPerXdlops * NRepeats, AStride, BStride>(
-                        &pa[k_i * mfma_type.k_base], &pb[k_i * mfma_type.k_base], p_c_thread);
-        }
-#endif
-        return p_c_thread;
+        static_for<0, KPerWave, mfma_type.k_base>{}([&](auto k_i) {
+            mfma_type.template run<MPerXdlops, NPerXdlops>(
+                p_a_wave[Number<k_i>{}], p_b_wave[Number<k_i>{}], p_c_thread);
+        });
     }
     __device__ static MatrixIndex GetBeginOfThreadBlk(index_t i)
@@ -821,8 +635,8 @@ struct XdlopsGemm_t
     }
     template <class data_type_ = data_type,
-              index_t MPerWave_ = GemmMPerWave,
-              index_t NPerWave_ = GemmNPerWave>
+              index_t MPerWave_ = MPerWave,
+              index_t NPerWave_ = NPerWave>
     static constexpr auto GetXdlopsInfo();
     template <>
...
@@ -198,78 +198,79 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
 extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
     ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
-template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
+template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x1f32;
-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x1f32<128, 64, AStride, BStride>
-{
-    __device__ static c_vec32_4_t::VecType
-    run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
-        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
-        reg_c.s.z =
-            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
-        reg_c.s.w =
-            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
-        return reg_c;
-    }
-};
-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x1f32<64, 128, AStride, BStride>
-{
-    __device__ static c_vec32_4_t::VecType
-    run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
-        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
-        reg_c.s.z =
-            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
-        reg_c.s.w =
-            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
-        return reg_c;
-    }
-};
-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x1f32<64, 64, AStride, BStride>
-{
-    __device__ static c_vec32_2_t::VecType
-    run(const float* reg_a, const float* reg_b, c_vec32_2_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
-        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
-        return reg_c;
-    }
-};
-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x1f32<64, 32, AStride, BStride>
-{
-    __device__ static c_vec32_1_t::VecType
-    run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
-        return reg_c;
-    }
-};
-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x1f32<32, 64, AStride, BStride>
-{
-    __device__ static c_vec32_1_t::VecType
-    run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
-        return reg_c;
-    }
-};
+// template <index_t AStride, index_t BStride>
+// struct intrin_mfma_f32_32x32x1f32<128, 64, AStride, BStride>
+//{
+//__device__ static c_vec32_4_t::VecType
+// run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
+//{
+// reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
+// reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
+// reg_c.s.z =
+// llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
+// reg_c.s.w =
+// llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
+// return reg_c;
+//}
+//};
+// template <index_t AStride, index_t BStride>
+// struct intrin_mfma_f32_32x32x1f32<64, 128, AStride, BStride>
+//{
+//__device__ static c_vec32_4_t::VecType
+// run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
+//{
+// reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
+// reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
+// reg_c.s.z =
+// llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
+// reg_c.s.w =
+// llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
+// return reg_c;
+//}
+//};
+template <>
+struct intrin_mfma_f32_32x32x1f32<64, 64>
+{
+    __device__ static void
+    run(const float& reg_a, const float& reg_b, vector_type<float, 64>& reg_c)
+    {
+        reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
+            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
+        reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
+            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
+    }
+};
+// template <index_t AStride, index_t BStride>
+// struct intrin_mfma_f32_32x32x1f32<64, 32, AStride, BStride>
+//{
+//__device__ static c_vec32_1_t::VecType
+// run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c)
+//{
+// reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
+// return reg_c;
+//}
+//};
+// template <index_t AStride, index_t BStride>
+// struct intrin_mfma_f32_32x32x1f32<32, 64, AStride, BStride>
+//{
+//__device__ static c_vec32_1_t::VecType
+// run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c)
+//{
+// reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
+// return reg_c;
+//}
+//};
 __device__ c_vec16_1_t::VecType
 intrin_mfma_f32_32x32x2f32(const float* reg_a, const float* reg_b, c_vec16_1_t::VecType reg_c)
...
@@ -77,12 +77,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
     const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
 #endif
-    // cdata = 64, BlockSize = 256, 128x128x8
     // b thread copy 4x1
-    constexpr index_t BlockSize = 256;
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t BlockSize = 64;
+    constexpr index_t GemmMPerBlock = 64;
+    constexpr index_t GemmNPerBlock = 64;
     constexpr index_t GemmKPerBlock = 8;
     constexpr index_t GemmMPerThread = 4;
@@ -91,17 +90,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
     constexpr index_t GemmMLevel0Cluster = 2;
     constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 2;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
+    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
+    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
+    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
     constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
+    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
+    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
...
@@ -25,11 +25,11 @@ int main(int argc, char* argv[])
     using namespace ck;
 #if 1
-    constexpr index_t N = 8;
+    constexpr index_t N = 4;
     constexpr index_t C = 16;
     constexpr index_t HI = 4;
     constexpr index_t WI = 4;
-    constexpr index_t K = 128;
+    constexpr index_t K = 64;
     constexpr index_t Y = 1;
     constexpr index_t X = 1;
@@ -64,8 +64,8 @@ int main(int argc, char* argv[])
     using ConvStrides = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;
     using LeftPads = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
 #elif 0
     constexpr index_t N = 1;
     constexpr index_t C = 16;
@@ -688,7 +688,7 @@ int main(int argc, char* argv[])
 #elif 0
     in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
     wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-#elif 0
+#elif 1
     in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
     wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
 #elif 1
...