Unverified Commit 47b3e10b authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Merge branch 'develop' into conv_dlops/quantization

parents c6683eea c10a6e82
......@@ -760,7 +760,15 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">";
// clang-format on
......
......@@ -640,7 +640,16 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock
<< K0PerBlock << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1
<< ">";
// clang-format on
......
......@@ -1003,7 +1003,15 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization)
<< getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">";
return str.str();
......
......@@ -1203,7 +1203,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization)
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
<< K1
<< ">";
// clang-format on
......
......@@ -1231,7 +1231,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization)
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
<< K1 << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< ">";
// clang-format on
......
......@@ -1092,7 +1092,15 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization)
<< getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">";
// clang-format on
......
......@@ -838,7 +838,19 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization)
<< getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< ">";
// clang-format on
......
......@@ -939,7 +939,15 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization)
<< getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">";
// clang-format on
......
......@@ -688,6 +688,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">";
// clang-format on
......
......@@ -25,7 +25,7 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
// delete them.
// amd_assembly_wmma_f32_16x16x16_f16_w32(
// reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
#if defined(__gfx11__)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
......@@ -46,7 +46,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx11__)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<float8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
......@@ -71,7 +71,7 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
#if defined(__gfx11__)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
#else
......@@ -95,7 +95,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
#if defined(__gfx11__)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
......@@ -117,7 +117,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx11__)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<int32x8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
neg_a,
......@@ -145,7 +145,7 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx11__)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
#else
......@@ -166,7 +166,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx11__)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
......@@ -191,7 +191,7 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
#if defined(__gfx11__)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
#else
......@@ -215,7 +215,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
#if defined(__gfx11__)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
......@@ -237,7 +237,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx11__)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
neg_a,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment