Commit d44f6660 authored by aska-0096

deprecate inline asm wmma

parent 823c8801
@@ -39,7 +39,7 @@ using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 static constexpr auto ConvSpec =
-    ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
...
@@ -56,24 +56,24 @@ using DeviceConvFwdInstance =
         64, // MPerBlock
         64, // NPerBlock
         64, // KPerBlock
-        4, // K1
+        8, // K1
         16, // MPerWMMA
         16, // NPerWMMA
         4, // MRepeat
         1, // NRepeat
-        S<4, 8, 4>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
+        S<4, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
         S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
         S<1, 0, 2>, // ABlockTransferSrcAccessOrder
         2, // ABlockTransferSrcVectorDim
-        1, // ABlockTransferSrcScalarPerVector
-        1, // ABlockTransferDstScalarPerVector_AK1
+        8, // ABlockTransferSrcScalarPerVector
+        8, // ABlockTransferDstScalarPerVector_AK1
         true, // ABlockLdsExtraM
-        S<4, 8, 4>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
+        S<4, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
         S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
         S<1, 0, 2>, // BBlockTransferSrcAccessOrder
         2, // BBlockTransferSrcVectorDim
-        1, // BBlockTransferSrcScalarPerVector
-        1, // BBlockTransferDstScalarPerVector_BK1
+        8, // BBlockTransferSrcScalarPerVector
+        8, // BBlockTransferDstScalarPerVector_BK1
         true, // BBlockLdsExtraN
         1,
         1,
@@ -278,9 +278,9 @@ bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[])
     switch(conv_param.num_dim_spatial_)
     {
-    case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param);
+    // case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param);
     case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param);
-    case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param);
+    // case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param);
     }
     return false;
...
@@ -67,7 +67,8 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-using DeviceGemmInstance =
+using DeviceMHAFactory =
+    std::tuple<
     ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
         NumDimG,
         NumDimM,
@@ -99,7 +100,8 @@ using DeviceGemmInstance =
         128, // MPerBlock
         64, // LPerBlock
         64, // KPerBlock
-        8, // K1
+        8, // AK1
+        8, // BK1
         // Gemm 1
         64, // NPerBlock
         64, // LTilePerBlock
@@ -136,8 +138,8 @@ using DeviceGemmInstance =
         2, // CShuffleNWmmaPerWavePerShuffle
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>; // MaskingSpecialization
+        MaskingSpec> // MaskingSpecialization
+    >;
 // Ref Gemm0: fp16 in, fp32 out
 using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                 B0DataType,
...
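In this example the single DeviceGemmInstance alias becomes a DeviceMHAFactory std::tuple, so the driver can collect several tuned WMMA instances and exercise each one. A minimal C++17 sketch of walking such a tuple of instance types is shown below; for_each_instance and the callback are illustrative names assumed here, not code from this commit.

// Minimal sketch (assumed helper, not part of this diff): visit every instance
// type held in a std::tuple such as DeviceMHAFactory.
#include <cstddef>
#include <tuple>
#include <utility>

template <typename Tuple, typename Fn, std::size_t... Is>
void for_each_instance_impl(Fn&& fn, std::index_sequence<Is...>)
{
    // Fold over the index pack: default-construct each tuple element type
    // and hand it to the callback.
    (fn(std::tuple_element_t<Is, Tuple>{}), ...);
}

template <typename Tuple, typename Fn>
void for_each_instance(Fn&& fn)
{
    for_each_instance_impl<Tuple>(std::forward<Fn>(fn),
                                  std::make_index_sequence<std::tuple_size_v<Tuple>>{});
}

// Hypothetical usage: for_each_instance<DeviceMHAFactory>([](auto op) { /* build argument, run, verify */ });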
@@ -70,7 +70,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
 // clang-format off
 // #define CK_MHA_USE_WAVE_1
 // #define CK_MHA_USE_WAVE_2
-#define CK_MHA_USE_WAVE_4
+// #define CK_MHA_USE_WAVE_4
 #define CK_MHA_USE_WAVE_8
 using DeviceMHAFactory =
     std::tuple<
...
@@ -175,8 +175,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
     static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
     // If true, LDS is used unconditionally
-    static constexpr auto AEnableLds_manu = false;
-    static constexpr auto BEnableLds_manu = false;
+    static constexpr auto AEnableLds_manu = true;
+    static constexpr auto BEnableLds_manu = true;
     static constexpr auto AEnableLds =
         AEnableLds_auto || AEnableLds_manu || (NumGemmKPrefetchStage > 1);
...
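With AEnableLds_manu and BEnableLds_manu switched to true, both GEMM operands are now staged through LDS regardless of the wave-count heuristic. A small standalone sketch of the boolean combination (simplified names, not the CK source) is:

// Minimal sketch (assumption: simplified names, not the CK implementation):
// the per-operand LDS decision ORs the automatic heuristic, the manual
// override, and the multi-stage K-prefetch requirement.
#include <cstdio>

constexpr bool enable_lds(bool auto_heuristic, bool manual_override, int k_prefetch_stages)
{
    return auto_heuristic || manual_override || (k_prefetch_stages > 1);
}

int main()
{
    // After this change the manual override is true, so LDS is used even when
    // the heuristic alone would have skipped it.
    std::printf("%d\n", enable_lds(false, true, 1)); // prints 1
    return 0;
}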
@@ -355,17 +355,5 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
         c3);
 }
-// Ranged input operand
-__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
-{
-#if defined(__gfx11__)
-    asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
-#else
-    ignore = a;
-    ignore = b;
-    ignore = c;
-#endif
-}
 } // namespace ck
 #endif
@@ -21,10 +21,6 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
     template <class FloatC>
     __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
     {
-        // * Inline assembly need to elimate the duplicated data load, compiler won't help you
-        // delete them.
-        // amd_assembly_wmma_f32_16x16x16_f16_w32(
-        //     reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
 #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
         reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
             reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
...
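The hand-written v_wmma_f32_16x16x16_f16 inline assembly is dropped in favour of the compiler builtin used above: the builtin returns the updated accumulator, so the compiler can keep operands in registers instead of re-loading them around an opaque asm block. A self-contained HIP sketch of calling the builtin follows; the vector typedefs and function name are local assumptions, not CK's half16_t/float8_t definitions.

// Minimal HIP/Clang sketch (assumed local types, not CK's): one 16x16x16 fp16
// WMMA accumulation on gfx11 in wave32 mode.
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
typedef float float8 __attribute__((ext_vector_type(8)));

__device__ void wmma_f32_16x16x16_f16(half16 a, half16 b, float8& acc)
{
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
    // The builtin consumes and returns the accumulator; with no inline asm the
    // compiler is free to schedule and reuse the registers holding a, b, acc.
    acc = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a, b, acc);
#else
    (void)a;
    (void)b;
    (void)acc;
#endif
}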