Commit 036c5234 authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2

parents 22995e9a 7843a8a7
...@@ -514,7 +514,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -514,7 +514,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{ {
std::cout << "arg.a_grid_desc_ak0_m_ak1_{" std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
...@@ -529,7 +529,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -529,7 +529,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
......
...@@ -299,7 +299,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout, ...@@ -299,7 +299,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
...@@ -312,7 +312,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout, ...@@ -312,7 +312,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
......
...@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
// check device // check device
if(ck::is_navi3_supported()) if(ck::is_gfx11_supported())
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>)) if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{ {
......
...@@ -197,6 +197,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -197,6 +197,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
K0PerBlock, K0PerBlock,
ConvBackwardWeightSpecialization>{}; ConvBackwardWeightSpecialization>{};
static constexpr index_t MaxScalarPerVectorFP32 = 4;
static constexpr index_t WorkspaceInOutScalarPerVector =
is_same_v<AccDataType, float>
? math::min(CBlockTransferScalarPerVector_NWaveNPerXdl, MaxScalarPerVectorFP32)
: CBlockTransferScalarPerVector_NWaveNPerXdl;
// Bytes per 32 lds bank: 32 * 4 bytes // Bytes per 32 lds bank: 32 * 4 bytes
static constexpr auto BankLength = 128; static constexpr auto BankLength = 128;
static constexpr auto ElePerBank = BankLength / sizeof(ADataType); static constexpr auto ElePerBank = BankLength / sizeof(ADataType);
...@@ -297,7 +303,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -297,7 +303,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
ADataType, ADataType,
BDataType, BDataType,
AccDataType, AccDataType,
EDataType, AccDataType,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
...@@ -337,7 +343,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -337,7 +343,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
BBlockLdsN1Padding, BBlockLdsN1Padding,
CShuffleMXdlPerWavePerShuffle, CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl, WorkspaceInOutScalarPerVector,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true, true,
true, true,
...@@ -349,7 +355,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -349,7 +355,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
static constexpr auto MakeElementwiseInputSequence() static constexpr auto MakeElementwiseInputSequence()
{ {
return generate_sequence_v2( return generate_sequence_v2(
[&](auto) constexpr { return Number<CBlockTransferScalarPerVector_NWaveNPerXdl>{}; }, [&](auto) constexpr { return Number<WorkspaceInOutScalarPerVector>{}; },
Number<NumDTensor + 1>{}); Number<NumDTensor + 1>{});
} }
...@@ -499,7 +505,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -499,7 +505,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N<NDimSpatial>({}, {})); using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N<NDimSpatial>({}, {}));
using CDGridDesc_M_N = decltype(concat_tuple(Tuple<CGridDesc_M_N>{}, DsGridDesc_M_N{})); using CDGridDesc_M_N = decltype(concat_tuple(Tuple<CGridDesc_M_N>{}, DsGridDesc_M_N{}));
using DsGridPointerTuple = decltype(GetDsGridPointerTuple()); using DsGridPointerTuple = decltype(GetDsGridPointerTuple());
using CDDataTypes = decltype(concat_tuple(Tuple<const EDataType*>{}, DsGridPointerTuple{})); using CDDataTypes = decltype(concat_tuple(Tuple<const AccDataType*>{}, DsGridPointerTuple{}));
using EGridDesc_M_N = CGridDesc_M_N; using EGridDesc_M_N = CGridDesc_M_N;
static constexpr index_t ClusterLengthMPerBlock = static constexpr index_t ClusterLengthMPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
...@@ -659,7 +665,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -659,7 +665,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
std::size_t GetWorkspaceSizeBytes() const std::size_t GetWorkspaceSizeBytes() const
{ {
return sizeof(EDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_; return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
} }
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
...@@ -738,7 +744,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -738,7 +744,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
auto launch_gemm_kernel = [&](auto has_main_k_block_loop) { auto launch_gemm_kernel = [&](auto has_main_k_block_loop) {
EDataType* p_c_grid = type_convert<EDataType*>(arg.p_workspace_); AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
const index_t grid_size = const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_; arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_;
...@@ -753,7 +759,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -753,7 +759,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
GridwiseGemm, GridwiseGemm,
ADataType, ADataType,
BDataType, BDataType,
EDataType, AccDataType,
OutElementwiseOperation, OutElementwiseOperation,
InElementwiseOperation, InElementwiseOperation,
element_wise::PassThrough, element_wise::PassThrough,
...@@ -786,7 +792,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -786,7 +792,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
}; };
auto launch_elementwise_kernel = [&]() { auto launch_elementwise_kernel = [&]() {
const EDataType* p_c_grid = type_convert<const EDataType*>(arg.p_workspace_); const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
const index_t grid_size = const index_t grid_size =
arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) *
arg.Conv_G_; arg.Conv_G_;
...@@ -907,7 +913,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -907,7 +913,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
} }
// vector store C matrix into global memory // vector store C matrix into global memory
if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0 &&
arg.Conv_C_ % WorkspaceInOutScalarPerVector == 0))
{ {
return false; return false;
} }
......
...@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle ...@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
// check device // check device
if(ck::is_navi3_supported()) if(ck::is_gfx11_supported())
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>)) if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{ {
......
...@@ -666,7 +666,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -666,7 +666,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// check device // check device
if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported())) ck::is_gfx103_supported() || ck::is_gfx11_supported()))
{ {
return false; return false;
} }
......
...@@ -601,8 +601,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -601,8 +601,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
namespace ctc = tensor_layout::convolution; namespace ctc = tensor_layout::convolution;
// check device // check device
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() || if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
ck::is_navi3_supported())) ck::is_gfx11_supported()))
{ {
return false; return false;
} }
......
...@@ -581,7 +581,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -581,7 +581,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
namespace ctc = tensor_layout::convolution; namespace ctc = tensor_layout::convolution;
// check device // check device
if(ck::is_navi3_supported()) if(ck::is_gfx11_supported())
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>)) if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment