Commit df6dd915 authored by Jing Zhang

formatting

parent e9f05865
@@ -158,7 +158,5 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw
    }
};
} // namespace ck
#endif
@@ -26,14 +26,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
        index_t col;
    };

    // static constexpr XdlopsGemm_t XdlopsGemm = XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave,
    // GemmDataPerReadA, GemmDataPerReadB>{};

    index_t mMyWaveOffsetA;
    index_t mMyWaveOffsetB;

    static constexpr index_t WaveSize = 64;

    __device__ constexpr auto GetOutputLayout() const
    {
        return XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
            .GetOutputLayout();
    }

    __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops()
    {
@@ -67,7 +72,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
        constexpr index_t N = BlockMatrixB::NCol();
        constexpr index_t K = BlockMatrixA::NRow();

        XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
            .template Run<M, N, K>(
                &p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
    }
@@ -76,7 +82,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
        const index_t waveId = get_thread_local_1d_id() / WaveSize;

        const auto thread_mtx_on_blk =
            XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
                .GetBeginOfThreadBlk(i);

        const index_t col = waveId % GemmNWaves * GemmNPerWave + thread_mtx_on_blk.col;
@@ -94,14 +102,16 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
    __device__ void XdlopsMatrixCSetZero() const
    {
        constexpr auto thread_mtx_size = GemmMPerWave * GemmNPerWave / WaveSize;
        XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
            .SetZeroXdlopsRegs(Number<thread_mtx_size>{});
    }

    template <class FloatC>
    __device__ void XdlopsMatrixCRead(FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto thread_mtx_size = GemmMPerWave * GemmNPerWave / WaveSize;
        XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
            .ReadXdlopsRegs(Number<thread_mtx_size>{}, p_c_thread);
    }
};
...
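Aside (not part of the commit): the XdlopsMatrixCSetZero/XdlopsMatrixCRead hunks above size the per-thread accumulator as GemmMPerWave * GemmNPerWave / WaveSize. A standalone sketch of that arithmetic follows, using plain int in place of index_t and a local WaveSize constant; the function name ThreadMtxSize is hypothetical.

// Standalone sketch of the per-lane accumulator count used above.
constexpr int WaveSize = 64;

constexpr int ThreadMtxSize(int m_per_wave, int n_per_wave)
{
    // one M x N wave tile of accumulator values, spread evenly over the 64 lanes of a wave
    return m_per_wave * n_per_wave / WaveSize;
}

static_assert(ThreadMtxSize(64, 64) == 64, "64x64 wave tile -> 64 accumulator values per lane");
static_assert(ThreadMtxSize(32, 64) == 32, "32x64 wave tile -> 32 accumulator values per lane");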
@@ -117,7 +117,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
        // TODO: threadwise copy is still being tweaked
        if(has_optimized_address_calculation)
        {
            mThreadwiseStore.Run_optimized_dst_address_calculation(p_thread_buffer,
                                                                   p_block_dst);
        }
        else
        {
...
@@ -497,174 +497,171 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
    }
};

template <class data_type, index_t MPerWave, index_t NPerWave>
__device__ constexpr auto GetMFMAInfo();

template <>
__device__ constexpr auto GetMFMAInfo<float, 32, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x1xf32>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<float, 64, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x1xf32>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<float, 64, 32>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x1xf32>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<float, 32, 32>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x2xf32>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<float, 16, 16>()
{
    return mfma_info<mfma_instr::mfma_f32_16x16x4xf32>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<float, 16, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_16x16x1xf32>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<float, 64, 16>()
{
    return mfma_info<mfma_instr::mfma_f32_16x16x1xf32>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<float, 8, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_4x4x1xf32>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<float, 4, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_4x4x1xf32>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<half, 64, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<half, 64, 32>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<half, 32, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<half, 32, 32>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x8f16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<half, 16, 16>()
{
    return mfma_info<mfma_instr::mfma_f32_16x16x16f16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<half, 16, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_16x16x4f16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<half, 64, 16>()
{
    return mfma_info<mfma_instr::mfma_f32_16x16x4f16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<half, 4, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_4x4x4f16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<half, 8, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_4x4x4f16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<ushort, 64, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x2bf16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<ushort, 64, 32>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x2bf16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<ushort, 32, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x2bf16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<ushort, 32, 32>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x4bf16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<ushort, 16, 16>()
{
    return mfma_info<mfma_instr::mfma_f32_16x16x8bf16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<ushort, 16, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_16x16x2bf16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<ushort, 64, 16>()
{
    return mfma_info<mfma_instr::mfma_f32_16x16x2bf16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<ushort, 4, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_4x4x2bf16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<ushort, 8, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_4x4x2bf16>{};
}

template <class data_type,
          index_t MPerWave,
          index_t NPerWave,
@@ -685,7 +682,10 @@ struct XdlopsGemm_t
    __device__ static constexpr index_t M0() { return M0_; }
    __device__ static constexpr index_t N1() { return N1_; }
    __device__ static constexpr index_t N0() { return N0_; }

    __device__ static constexpr index_t GetBlkSize()
    {
        return GetMFMAInfo<data_type, MPerWave, NPerWave>().num_regs_blk;
    }

    __device__ static constexpr index_t GetNumBlks()
    {
@@ -726,7 +726,6 @@ struct XdlopsGemm_t
        return mfma_type.num_output_blks == 1 && mfma_type.num_input_blks != 1;
    }

#if CK_USE_AMD_XDLOPS_EMULATE
    // emulate xdlops
    template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
@@ -843,7 +842,8 @@ struct XdlopsGemm_t
            for(index_t k = 0; k < K; ++k)
            {
                mfma_type.run(
                    Number<MPerWave>{}, Number<NPerWave>{}, p_a_wave, p_b_wave, p_c_thread);
            }
        }).Else([&](auto) {
@@ -852,7 +852,8 @@ struct XdlopsGemm_t
            for(index_t k = 0; k < K; k += mfma_type.num_input_blks)
            {
                mfma_type.run(
                    Number<MPerWave>{}, Number<NPerWave>{}, p_a_wave, p_b_wave, p_c_thread);
            }
        });
@@ -898,7 +899,7 @@ struct XdlopsGemm_t
    __device__ void SetZeroXdlopsRegs(Number<Size>) const
    {
#if !CK_USE_AMD_XDLOPS_EMULATE
        // gcnasm_accvgpr_zero<Size>();
#endif
    }
@@ -907,8 +908,8 @@ struct XdlopsGemm_t
    {
#if !CK_USE_AMD_XDLOPS_EMULATE
        constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
        // gcnasm_nop<mfma_type.cycles>();
        // gcnasm_accvgpr_read<Size>(p_c_thread);
#else
        (void)p_c_thread;
#endif
...
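Aside (not part of the commit): the GetMFMAInfo overload set above is a compile-time dispatch from a (data type, wave-tile shape) triple to an instruction descriptor. A standalone sketch of the same pattern follows; the names Instr, Info and Pick are hypothetical, only the technique (full specialization on the triple, each returning its own descriptor type) mirrors the diff.

// Standalone sketch of the dispatch pattern used by GetMFMAInfo above.
enum class Instr
{
    f32_32x32x1,
    f32_32x32x2
};

template <Instr I>
struct Info
{
    static constexpr Instr instr = I;
};

// primary template is declared only; every supported combination is a full specialization
template <class T, int MPerWave, int NPerWave>
constexpr auto Pick();

template <>
constexpr auto Pick<float, 64, 64>()
{
    return Info<Instr::f32_32x32x1>{};
}

template <>
constexpr auto Pick<float, 32, 32>()
{
    return Info<Instr::f32_32x32x2>{};
}

static_assert(Pick<float, 64, 64>().instr == Instr::f32_32x32x1,
              "the wave-tile shape selects the instruction at compile time");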
@@ -7,7 +7,8 @@ template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x1f32(const float&, const float&, float32_t*);

template <>
__device__ void
gcnasm_mfma_f32_32x32x1f32<64, 64>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
    auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
    for(index_t i = 0; i < 32; i++)
@@ -17,7 +18,8 @@ __device__ void gcnasm_mfma_f32_32x32x1f32<64, 64>(const float& reg_a, const flo
}

template <>
__device__ void
gcnasm_mfma_f32_32x32x1f32<32, 64>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
    auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
    for(index_t i = 0; i < 16; i++)
@@ -27,7 +29,8 @@ __device__ void gcnasm_mfma_f32_32x32x1f32<32, 64>(const float& reg_a, const flo
}

template <>
__device__ void
gcnasm_mfma_f32_32x32x1f32<64, 32>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
    auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
    for(index_t i = 0; i < 16; i++)
@@ -53,12 +56,14 @@ template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_16x16x1f32(const float&, const float&, float16_t*);

template <>
__device__ void
gcnasm_mfma_f32_16x16x1f32<16, 64>(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_16x16x1f32<64, 16>(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
}
@@ -66,66 +71,77 @@ template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_4x4x1f32(const float& reg_a, const float& reg_b, float4_t* reg_c);

template <>
__device__ void
gcnasm_mfma_f32_4x4x1f32<4, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_4x4x1f32<8, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
}

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x4f16(const half4_t&, const half4_t&, float32_t*);

template <>
__device__ void
gcnasm_mfma_f32_32x32x4f16<64, 64>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_32x32x4f16<32, 64>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_32x32x4f16<64, 32>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}

__device__ void
gcnasm_mfma_f32_32x32x8f16(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}

__device__ void
gcnasm_mfma_f32_16x16x16f16(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}

template <index_t MPerWave, index_t NPerWave>
__device__ void
gcnasm_mfma_f32_16x16x4f16(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c);

template <>
__device__ void
gcnasm_mfma_f32_16x16x4f16<16, 64>(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_16x16x4f16<64, 16>(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}

template <index_t MPerWave, index_t NPerWave>
__device__ void
gcnasm_mfma_f32_4x4x4f16(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c);

template <>
__device__ void
gcnasm_mfma_f32_4x4x4f16<4, 64>(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_4x4x4f16<8, 64>(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}
@@ -133,54 +149,69 @@ template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x2bf16(const ushort2_t&, const ushort2_t&, float32_t*);

template <>
__device__ void gcnasm_mfma_f32_32x32x2bf16<64, 64>(const ushort2_t& reg_a,
                                                     const ushort2_t& reg_b,
                                                     float32_t* reg_c)
{
}

template <>
__device__ void gcnasm_mfma_f32_32x32x2bf16<32, 64>(const ushort2_t& reg_a,
                                                     const ushort2_t& reg_b,
                                                     float32_t* reg_c)
{
}

template <>
__device__ void gcnasm_mfma_f32_32x32x2bf16<64, 32>(const ushort2_t& reg_a,
                                                     const ushort2_t& reg_b,
                                                     float32_t* reg_c)
{
}

__device__ void
gcnasm_mfma_f32_32x32x4bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c)
{
}

__device__ void
gcnasm_mfma_f32_16x16x8bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}

template <index_t MPerWave, index_t NPerWave>
__device__ void
gcnasm_mfma_f32_16x16x2bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c);

template <>
__device__ void gcnasm_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t& reg_a,
                                                     const ushort2_t& reg_b,
                                                     float16_t* reg_c)
{
}

template <>
__device__ void gcnasm_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t& reg_a,
                                                     const ushort2_t& reg_b,
                                                     float16_t* reg_c)
{
}

template <index_t MPerWave, index_t NPerWave>
__device__ void
gcnasm_mfma_f32_4x4x2bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c);

template <>
__device__ void
gcnasm_mfma_f32_4x4x2bf16<4, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_4x4x2bf16<8, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}

// clang-format on
}
#endif
@@ -31,5 +31,4 @@
#include "amd_xdlops_emulate.hpp"
#endif

#endif
@@ -858,7 +858,6 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
            static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
            static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
            static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
    }

    cudaDeviceSynchronize();
...
@@ -1048,7 +1048,6 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
            static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
            static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
            static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
    }

    cudaDeviceSynchronize();
...
@@ -85,7 +85,8 @@ void device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(InDesc,
    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    constexpr auto gridwise_conv =
        GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw<
            GridSize,
            BlockSize,
            T,
@@ -161,7 +162,6 @@ void device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(InDesc,
            static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
            static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
            static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
    }

    cudaDeviceSynchronize();
...
@@ -20,26 +20,495 @@
//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"

int main(int argc, char* argv[])
{
using namespace ck;

#if 0
// 1x1, 17x17
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 1536;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 73x73
constexpr index_t N = 128;
constexpr index_t C = 160;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t K = 64;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 35x35
constexpr index_t N = 128;
constexpr index_t C = 96;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 3x3, 71x71
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 71;
constexpr index_t WI = 71;
constexpr index_t K = 192;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 7x1, 17x17
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 320;
constexpr index_t Y = 7;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 0
// 1x7, 17x17
constexpr index_t N = 128;
constexpr index_t C = 224;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 224;
constexpr index_t Y = 1;
constexpr index_t X = 7;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#elif 1
// 3x3, 299x299 stride=2
constexpr index_t N = 128;
constexpr index_t C = 3;
constexpr index_t HI = 299;
constexpr index_t WI = 299;
constexpr index_t K = 32;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 147x147
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr index_t N = 128;
constexpr index_t C = 32;
constexpr index_t HI = 147;
constexpr index_t WI = 147;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 3x3, 149x149
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr index_t N = 128;
constexpr index_t C = 32;
constexpr index_t HI = 149;
constexpr index_t WI = 149;
constexpr index_t K = 32;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 17x17, stride 2
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 192;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 35x35
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 96;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 35x35, stride 2
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 384;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x3, 8x8
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 448;
constexpr index_t Y = 1;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 1>;
using RightPads = Sequence<0, 1>;
#elif 0
// 3x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 448;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 0>;
using RightPads = Sequence<1, 0>;
#elif 0
// 3x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 448;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 0>;
using RightPads = Sequence<1, 0>;
#elif 1
// 3x3, 147x147
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 147;
constexpr index_t WI = 147;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 7x1, 73x73
// v44@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 1
// 3x3, 73x73
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14, stride 2
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 2048;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14, stride 2
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 28x28
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1
// 3x3, 14x14
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1
// 1x1, 56x56, stride 2
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 7x7, 230x230 stride=2
constexpr index_t N = 128;
constexpr index_t C = 3;
constexpr index_t HI = 230;
constexpr index_t WI = 230;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 7;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 28x28, stride = 2
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 1024;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 28x28, stride 2
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 7x7
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 2048;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 7x7
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1
// 1x1, 56x56
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 64;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
// 3x3, 56x56
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#endif
auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
@@ -133,6 +602,18 @@ int main(int argc, char* argv[])
LeftPads{},
RightPads{},
nrepeat);
#elif 0
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
#elif 1
device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
...
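Aside (not part of the commit): each configuration block in main() above fixes one convolution problem, and the output spatial size follows from the usual convolution formula. A standalone check for the enabled 3x3, 299x299, stride-2 case follows, using plain int in place of index_t; the helper name OutputSize is hypothetical.

// Output spatial extent implied by a configuration block above.
constexpr int OutputSize(int in, int filter, int stride, int dilation, int left_pad, int right_pad)
{
    // standard convolution output-size formula
    return (in + left_pad + right_pad - dilation * (filter - 1) - 1) / stride + 1;
}

static_assert(OutputSize(299, 3, 2, 1, 0, 0) == 149, "HO/WO for the 3x3, 299x299, stride-2 case");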