Commit df6dd915 authored by Jing Zhang's avatar Jing Zhang

formatting

parent e9f05865
@@ -158,7 +158,5 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw
}
};
} // namespace ck
#endif
@@ -26,14 +26,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
index_t col;
};
//static constexpr XdlopsGemm_t XdlopsGemm = XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{};
// static constexpr XdlopsGemm_t XdlopsGemm = XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave,
// GemmDataPerReadA, GemmDataPerReadB>{};
index_t mMyWaveOffsetA;
index_t mMyWaveOffsetB;
static constexpr index_t WaveSize = 64;
__device__ constexpr auto GetOutputLayout() const { return XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}.GetOutputLayout(); }
__device__ constexpr auto GetOutputLayout() const
{
return XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
.GetOutputLayout();
}
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops()
{
@@ -67,7 +72,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
constexpr index_t N = BlockMatrixB::NCol();
constexpr index_t K = BlockMatrixA::NRow();
XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}.template Run<M, N, K>(
XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
.template Run<M, N, K>(
&p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
}
@@ -76,7 +82,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
const index_t waveId = get_thread_local_1d_id() / WaveSize;
const auto thread_mtx_on_blk = XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}.GetBeginOfThreadBlk(i);
const auto thread_mtx_on_blk =
XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
.GetBeginOfThreadBlk(i);
const index_t col = waveId % GemmNWaves * GemmNPerWave + thread_mtx_on_blk.col;
@@ -94,14 +102,16 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
__device__ void XdlopsMatrixCSetZero() const
{
constexpr auto thread_mtx_size = GemmMPerWave * GemmNPerWave / WaveSize;
XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}.SetZeroXdlopsRegs(Number<thread_mtx_size>{});
XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
.SetZeroXdlopsRegs(Number<thread_mtx_size>{});
}
template <class FloatC>
__device__ void XdlopsMatrixCRead(FloatC* __restrict__ p_c_thread) const
{
constexpr auto thread_mtx_size = GemmMPerWave * GemmNPerWave / WaveSize;
XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}.ReadXdlopsRegs(Number<thread_mtx_size>{}, p_c_thread);
XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
.ReadXdlopsRegs(Number<thread_mtx_size>{}, p_c_thread);
}
};
......
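A note on the pattern being re-wrapped in this blockwise GEMM: the struct keeps no XdlopsGemm_t member (the cached static member is commented out near the top of the hunk); each method constructs a temporary XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{} and immediately delegates to it, and the per-lane accumulator count is GemmMPerWave * GemmNPerWave / WaveSize (for a 64x64 wave tile, 64 * 64 / 64 = 64 values). A minimal host-side sketch of that delegate shape, under simplified assumptions (XdlopsGemmSketch, BlockwiseGemmSketch and the K-major operand layout are illustrative stand-ins, not the ck definitions):

#include <cstdio>

// Simplified stand-ins for the real ck types (assumptions, not the ck API).
template <class Float, int MPerWave, int NPerWave>
struct XdlopsGemmSketch
{
    // Accumulate a tiny M x N x K product into p_c; the real Run() issues MFMA ops.
    template <int M, int N, int K>
    void Run(const Float* p_a, const Float* p_b, Float* p_c) const
    {
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                for(int k = 0; k < K; ++k)
                    p_c[m * N + n] += p_a[k * M + m] * p_b[k * N + n];
    }
};

template <class Float, int MPerWave, int NPerWave>
struct BlockwiseGemmSketch
{
    int mMyWaveOffsetA = 0;
    int mMyWaveOffsetB = 0;

    template <int M, int N, int K>
    void Run(const Float* p_a_block, const Float* p_b_block, Float* p_c_thread) const
    {
        // Same shape as the ck code: build a temporary helper and delegate.
        XdlopsGemmSketch<Float, MPerWave, NPerWave>{}.template Run<M, N, K>(
            &p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
    }
};

int main()
{
    float a[2 * 2] = {1, 2, 3, 4}; // A stored K x M ("TransA"), so C = A^T * B
    float b[2 * 2] = {1, 0, 0, 1}; // B stored K x N (identity here)
    float c[2 * 2] = {};
    BlockwiseGemmSketch<float, 64, 64>{}.Run<2, 2, 2>(a, b, c);
    std::printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]); // prints 1 3 2 4
}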
@@ -117,7 +117,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
// TODO: threadwise copy is still being tweaked
if(has_optimized_address_calculation)
{
mThreadwiseStore.Run_optimized_dst_address_calculation(p_thread_buffer, p_block_dst);
mThreadwiseStore.Run_optimized_dst_address_calculation(p_thread_buffer,
p_block_dst);
}
else
{
......
@@ -497,174 +497,171 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
}
};
template <class data_type,
index_t MPerWave,
index_t NPerWave>
__device__ constexpr auto GetMFMAInfo();
template <class data_type, index_t MPerWave, index_t NPerWave>
__device__ constexpr auto GetMFMAInfo();
template <>
template <>
__device__ constexpr auto GetMFMAInfo<float, 32, 64>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x1xf32>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<float, 64, 64>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x1xf32>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<float, 64, 32>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x1xf32>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<float, 32, 32>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x2xf32>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<float, 16, 16>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x4xf32>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<float, 16, 64>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x1xf32>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<float, 64, 16>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x1xf32>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<float, 8, 64>()
{
return mfma_info<mfma_instr::mfma_f32_4x4x1xf32>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<float, 4, 64>()
{
return mfma_info<mfma_instr::mfma_f32_4x4x1xf32>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<half, 64, 64>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<half, 64, 32>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<half, 32, 64>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<half, 32, 32>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x8f16>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<half, 16, 16>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x16f16>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<half, 16, 64>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x4f16>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<half, 64, 16>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x4f16>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<half, 4, 64>()
{
return mfma_info<mfma_instr::mfma_f32_4x4x4f16>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<half, 8, 64>()
{
return mfma_info<mfma_instr::mfma_f32_4x4x4f16>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 64, 64>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x2bf16>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 64, 32>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x2bf16>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 32, 64>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x2bf16>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 32, 32>()
{
return mfma_info<mfma_instr::mfma_f32_32x32x4bf16>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 16, 16>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x8bf16>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 16, 64>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x2bf16>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 64, 16>()
{
return mfma_info<mfma_instr::mfma_f32_16x16x2bf16>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 4, 64>()
{
return mfma_info<mfma_instr::mfma_f32_4x4x2bf16>{};
}
template <>
template <>
__device__ constexpr auto GetMFMAInfo<ushort, 8, 64>()
{
return mfma_info<mfma_instr::mfma_f32_4x4x2bf16>{};
}
template <class data_type,
index_t MPerWave,
index_t NPerWave,
@@ -685,7 +682,10 @@ struct XdlopsGemm_t
__device__ static constexpr index_t M0() { return M0_; }
__device__ static constexpr index_t N1() { return N1_; }
__device__ static constexpr index_t N0() { return N0_; }
__device__ static constexpr index_t GetBlkSize() { return GetMFMAInfo<data_type, MPerWave, NPerWave>().num_regs_blk; }
__device__ static constexpr index_t GetBlkSize()
{
return GetMFMAInfo<data_type, MPerWave, NPerWave>().num_regs_blk;
}
__device__ static constexpr index_t GetNumBlks()
{
@@ -726,7 +726,6 @@ struct XdlopsGemm_t
return mfma_type.num_output_blks == 1 && mfma_type.num_input_blks != 1;
}
#if CK_USE_AMD_XDLOPS_EMULATE
// emulate xdlops
template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
@@ -843,7 +842,8 @@ struct XdlopsGemm_t
for(index_t k = 0; k < K; ++k)
{
mfma_type.run(Number<MPerWave>{}, Number<NPerWave>{}, p_a_wave, p_b_wave, p_c_thread);
mfma_type.run(
Number<MPerWave>{}, Number<NPerWave>{}, p_a_wave, p_b_wave, p_c_thread);
}
}).Else([&](auto) {
@@ -852,7 +852,8 @@ struct XdlopsGemm_t
for(index_t k = 0; k < K; k += mfma_type.num_input_blks)
{
mfma_type.run(Number<MPerWave>{}, Number<NPerWave>{}, p_a_wave, p_b_wave, p_c_thread);
mfma_type.run(
Number<MPerWave>{}, Number<NPerWave>{}, p_a_wave, p_b_wave, p_c_thread);
}
});
@@ -898,7 +899,7 @@ struct XdlopsGemm_t
__device__ void SetZeroXdlopsRegs(Number<Size>) const
{
#if !CK_USE_AMD_XDLOPS_EMULATE
//gcnasm_accvgpr_zero<Size>();
// gcnasm_accvgpr_zero<Size>();
#endif
}
@@ -907,8 +908,8 @@ struct XdlopsGemm_t
{
#if !CK_USE_AMD_XDLOPS_EMULATE
constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
//gcnasm_nop<mfma_type.cycles>();
//gcnasm_accvgpr_read<Size>(p_c_thread);
// gcnasm_nop<mfma_type.cycles>();
// gcnasm_accvgpr_read<Size>(p_c_thread);
#else
(void)p_c_thread;
#endif
......
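The long run of specializations above is a compile-time dispatch table: GetMFMAInfo<data_type, MPerWave, NPerWave>() is a constexpr function template with one full specialization per supported (data type, wave tile) combination, each returning an mfma_info<...> traits object that names the MFMA instruction and its register layout, which XdlopsGemm_t then reads in constexpr context (as GetBlkSize does above). A trimmed, host-compilable sketch of that mechanism (the *Sketch names and the placeholder num_regs_blk value are illustrative, not the real ck table):

#include <cstdio>

// Illustrative enum/traits pair modelled on the code above, not the ck definitions.
enum class mfma_instr
{
    mfma_f32_32x32x1xf32,
    mfma_f32_32x32x2xf32,
    mfma_f32_16x16x4xf32
};

template <mfma_instr instr>
struct mfma_info_sketch
{
    static constexpr mfma_instr value  = instr;
    static constexpr int num_regs_blk  = 16; // placeholder; the real table differs per instruction
};

// Primary template is only declared; each supported combination gets a full specialization.
template <class data_type, int MPerWave, int NPerWave>
constexpr auto GetMFMAInfoSketch();

template <>
constexpr auto GetMFMAInfoSketch<float, 64, 64>()
{
    return mfma_info_sketch<mfma_instr::mfma_f32_32x32x1xf32>{};
}

template <>
constexpr auto GetMFMAInfoSketch<float, 32, 32>()
{
    return mfma_info_sketch<mfma_instr::mfma_f32_32x32x2xf32>{};
}

int main()
{
    // The selected traits are a compile-time constant, so callers can read fields
    // such as num_regs_blk in constexpr context.
    constexpr auto info = GetMFMAInfoSketch<float, 64, 64>();
    static_assert(info.value == mfma_instr::mfma_f32_32x32x1xf32, "dispatch picked 32x32x1 f32");
    std::printf("num_regs_blk = %d\n", info.num_regs_blk);
}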
@@ -7,7 +7,8 @@ template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x1f32(const float&, const float&, float32_t*);
template <>
__device__ void gcnasm_mfma_f32_32x32x1f32<64, 64>(const float& reg_a, const float& reg_b, float32_t* reg_c)
__device__ void
gcnasm_mfma_f32_32x32x1f32<64, 64>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
for(index_t i = 0; i < 32; i++)
@@ -17,7 +18,8 @@ __device__ void gcnasm_mfma_f32_32x32x1f32<64, 64>(const float& reg_a, const flo
}
template <>
__device__ void gcnasm_mfma_f32_32x32x1f32<32, 64>(const float& reg_a, const float& reg_b, float32_t* reg_c)
__device__ void
gcnasm_mfma_f32_32x32x1f32<32, 64>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
for(index_t i = 0; i < 16; i++)
@@ -27,7 +29,8 @@ __device__ void gcnasm_mfma_f32_32x32x1f32<32, 64>(const float& reg_a, const flo
}
template <>
__device__ void gcnasm_mfma_f32_32x32x1f32<64, 32>(const float& reg_a, const float& reg_b, float32_t* reg_c)
__device__ void
gcnasm_mfma_f32_32x32x1f32<64, 32>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
for(index_t i = 0; i < 16; i++)
@@ -53,12 +56,14 @@ template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_16x16x1f32(const float&, const float&, float16_t*);
template <>
__device__ void gcnasm_mfma_f32_16x16x1f32<16, 64>(const float& reg_a, const float& reg_b, float16_t* reg_c)
__device__ void
gcnasm_mfma_f32_16x16x1f32<16, 64>(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_16x16x1f32<64, 16>(const float& reg_a, const float& reg_b, float16_t* reg_c)
__device__ void
gcnasm_mfma_f32_16x16x1f32<64, 16>(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
}
@@ -66,66 +71,77 @@ template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_4x4x1f32(const float& reg_a, const float& reg_b, float4_t* reg_c);
template <>
__device__ void gcnasm_mfma_f32_4x4x1f32<4, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c)
__device__ void
gcnasm_mfma_f32_4x4x1f32<4, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_4x4x1f32<8, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c)
__device__ void
gcnasm_mfma_f32_4x4x1f32<8, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x4f16(const half4_t&,
const half4_t&,
float32_t*);
__device__ void gcnasm_mfma_f32_32x32x4f16(const half4_t&, const half4_t&, float32_t*);
template <>
__device__ void gcnasm_mfma_f32_32x32x4f16<64, 64>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
__device__ void
gcnasm_mfma_f32_32x32x4f16<64, 64>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_32x32x4f16<32, 64>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
__device__ void
gcnasm_mfma_f32_32x32x4f16<32, 64>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_32x32x4f16<64, 32>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
__device__ void
gcnasm_mfma_f32_32x32x4f16<64, 32>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}
__device__ void gcnasm_mfma_f32_32x32x8f16(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
__device__ void
gcnasm_mfma_f32_32x32x8f16(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}
__device__ void gcnasm_mfma_f32_16x16x16f16(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
__device__ void
gcnasm_mfma_f32_16x16x16f16(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_16x16x4f16(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c);
__device__ void
gcnasm_mfma_f32_16x16x4f16(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c);
template <>
__device__ void gcnasm_mfma_f32_16x16x4f16<16, 64>(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
__device__ void
gcnasm_mfma_f32_16x16x4f16<16, 64>(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_16x16x4f16<64, 16>(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
__device__ void
gcnasm_mfma_f32_16x16x4f16<64, 16>(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_4x4x4f16(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c);
__device__ void
gcnasm_mfma_f32_4x4x4f16(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c);
template <>
__device__ void gcnasm_mfma_f32_4x4x4f16<4, 64>(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
__device__ void
gcnasm_mfma_f32_4x4x4f16<4, 64>(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_4x4x4f16<8, 64>(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
__device__ void
gcnasm_mfma_f32_4x4x4f16<8, 64>(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}
@@ -133,54 +149,69 @@ template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x2bf16(const ushort2_t&, const ushort2_t&, float32_t*);
template <>
__device__ void gcnasm_mfma_f32_32x32x2bf16<64, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float32_t* reg_c)
__device__ void gcnasm_mfma_f32_32x32x2bf16<64, 64>(const ushort2_t& reg_a,
const ushort2_t& reg_b,
float32_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_32x32x2bf16<32, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float32_t* reg_c)
__device__ void gcnasm_mfma_f32_32x32x2bf16<32, 64>(const ushort2_t& reg_a,
const ushort2_t& reg_b,
float32_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_32x32x2bf16<64, 32>(const ushort2_t& reg_a, const ushort2_t& reg_b, float32_t* reg_c)
__device__ void gcnasm_mfma_f32_32x32x2bf16<64, 32>(const ushort2_t& reg_a,
const ushort2_t& reg_b,
float32_t* reg_c)
{
}
__device__ void gcnasm_mfma_f32_32x32x4bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c)
__device__ void
gcnasm_mfma_f32_32x32x4bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c)
{
}
__device__ void gcnasm_mfma_f32_16x16x8bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
__device__ void
gcnasm_mfma_f32_16x16x8bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_16x16x2bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c);
__device__ void
gcnasm_mfma_f32_16x16x2bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c);
template <>
__device__ void gcnasm_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c)
__device__ void gcnasm_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t& reg_a,
const ushort2_t& reg_b,
float16_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c)
__device__ void gcnasm_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t& reg_a,
const ushort2_t& reg_b,
float16_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_4x4x2bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c);
__device__ void
gcnasm_mfma_f32_4x4x2bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c);
template <>
__device__ void gcnasm_mfma_f32_4x4x2bf16<4, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
__device__ void
gcnasm_mfma_f32_4x4x2bf16<4, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_4x4x2bf16<8, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
__device__ void
gcnasm_mfma_f32_4x4x2bf16<8, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}
// clang-format on
}
#endif
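For readers outside the repo: float32_t, float16_t, float4_t, half4_t and ushort2_t in these wrappers are short vector register types (in ck they are typically built with clang's ext_vector_type extension; that detail is an assumption here, the authoritative typedefs live in the repo's vector-type header), and each gcnasm_mfma_* wrapper is specialized on the (MPerWave, NPerWave) wave tile it serves, with the real bodies issuing the corresponding MFMA intrinsic or inline asm. A clang-compilable stand-in showing only the API shape, with the hardware op replaced by a plain loop:

#include <cstdio>

// Illustrative stand-ins only (clang extension type; not the real ck typedefs).
// The real wrappers call __builtin_amdgcn_mfma_* intrinsics or emit inline asm;
// here the body is a plain loop so the API shape can be compiled on the host.
typedef float float4_t __attribute__((ext_vector_type(4)));

// One wrapper per MFMA instruction, specialized on the wave tile it serves.
template <int MPerWave, int NPerWave>
void mfma_f32_4x4x1f32_sketch(const float& reg_a, const float& reg_b, float4_t* reg_c);

template <>
void mfma_f32_4x4x1f32_sketch<4, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
    // Emulation-style stand-in: add one rank-1 contribution into the four
    // accumulator lanes (NOT the hardware lane mapping, just a placeholder).
    for(int i = 0; i < 4; ++i)
        (*reg_c)[i] += reg_a * reg_b;
}

int main()
{
    float4_t acc = {0, 0, 0, 0};
    mfma_f32_4x4x1f32_sketch<4, 64>(2.0f, 3.0f, &acc);
    std::printf("%g %g %g %g\n", acc[0], acc[1], acc[2], acc[3]); // prints 6 6 6 6
}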
@@ -31,5 +31,4 @@
#include "amd_xdlops_emulate.hpp"
#endif
#endif
@@ -858,7 +858,6 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
cudaDeviceSynchronize();
......
@@ -1048,7 +1048,6 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
cudaDeviceSynchronize();
......
@@ -85,7 +85,8 @@ void device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(InDesc,
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw<
constexpr auto gridwise_conv =
GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw<
GridSize,
BlockSize,
T,
@@ -161,7 +162,6 @@ void device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(InDesc,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
cudaDeviceSynchronize();
......
@@ -20,26 +20,495 @@
//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
int main(int argc, char* argv[])
{
using namespace ck;
#if 0
// 1x1, 17x17
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 1536;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 73x73
constexpr index_t N = 128;
constexpr index_t C = 160;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t K = 64;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 35x35
constexpr index_t N = 128;
constexpr index_t C = 96;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 3x3, 71x71
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 71;
constexpr index_t WI = 71;
constexpr index_t K = 192;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 7x1, 17x17
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 320;
constexpr index_t Y = 7;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 0
// 1x7, 17x17
constexpr index_t N = 128;
constexpr index_t C = 224;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 224;
constexpr index_t Y = 1;
constexpr index_t X = 7;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#elif 1
// 3x3, 299x299 stride=2
constexpr index_t N = 128;
constexpr index_t C = 3;
constexpr index_t HI = 299;
constexpr index_t WI = 299;
constexpr index_t K = 32;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 147x147
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr index_t N = 128;
constexpr index_t C = 32;
constexpr index_t HI = 147;
constexpr index_t WI = 147;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 3x3, 149x149
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr index_t N = 128;
constexpr index_t C = 32;
constexpr index_t HI = 149;
constexpr index_t WI = 149;
constexpr index_t K = 32;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 17x17, stride 2
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 192;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 35x35
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 96;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 35x35, stride 2
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 384;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x3, 8x8
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 448;
constexpr index_t Y = 1;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 1>;
using RightPads = Sequence<0, 1>;
#elif 0
// 3x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 448;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 0>;
using RightPads = Sequence<1, 0>;
#elif 0
// 3x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 448;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 0>;
using RightPads = Sequence<1, 0>;
#elif 1
// 3x3, 147x147
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 147;
constexpr index_t WI = 147;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 7x1, 73x73
// v44@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 1
// 3x3, 73x73
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14, stride 2
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 2048;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14
constexpr index_t N = 64;
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14, stride 2
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 28x28
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1
// 3x3, 14x14
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1
// 1x1, 56x56, stride 2
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 7x7, 230x230 stride=2
constexpr index_t N = 128;
constexpr index_t C = 3;
constexpr index_t HI = 230;
constexpr index_t WI = 230;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 7;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 28x28, stride = 2
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 1024;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 28x28, stride 2
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 7x7
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 2048;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 7x7
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1
// 1x1, 56x56
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 64;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
// 3x3, 56x56
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#endif
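Only the first branch in this #if/#elif chain that evaluates to 1 is compiled, so the active problem above is the 3x3, 299x299, stride-2, no-padding layer. For reference, each preset implies the output spatial size Ho = (HI + LeftPad + RightPad - (Y - 1) * dilation - 1) / stride + 1, and likewise for Wo; the helper below is an illustrative check of that arithmetic, not a ck function:

#include <cstdio>

// Illustrative helper (not part of the repo): output length of one spatial dimension.
constexpr int conv_out_size(int in, int pad_l, int pad_r, int filter, int dilation, int stride)
{
    return (in + pad_l + pad_r - (filter - 1) * dilation - 1) / stride + 1;
}

int main()
{
    // Active preset: N=128, C=3, HI=WI=299, K=32, Y=X=3, stride 2, no padding.
    constexpr int Ho = conv_out_size(299, 0, 0, 3, 1, 2); // (299 - 2 - 1) / 2 + 1 = 149
    constexpr int Wo = conv_out_size(299, 0, 0, 3, 1, 2);
    static_assert(Ho == 149 && Wo == 149, "3x3 stride-2 on 299x299 gives 149x149");
    std::printf("output: %d x %d\n", Ho, Wo);
}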
auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
@@ -133,6 +602,18 @@ int main(int argc, char* argv[])
LeftPads{},
RightPads{},
nrepeat);
#elif 0
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
#elif 1
device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
......