Commit 9b3c4ac4 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 1d784873 7843a8a7
......@@ -443,7 +443,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>))
......
......@@ -514,7 +514,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
......@@ -529,7 +529,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
......
......@@ -299,7 +299,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
......@@ -312,7 +312,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
......
......@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
// check device
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
......
......@@ -197,6 +197,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
K0PerBlock,
ConvBackwardWeightSpecialization>{};
static constexpr index_t MaxScalarPerVectorFP32 = 4;
static constexpr index_t WorkspaceInOutScalarPerVector =
is_same_v<AccDataType, float>
? math::min(CBlockTransferScalarPerVector_NWaveNPerXdl, MaxScalarPerVectorFP32)
: CBlockTransferScalarPerVector_NWaveNPerXdl;
// Bytes per 32 lds bank: 32 * 4 bytes
static constexpr auto BankLength = 128;
static constexpr auto ElePerBank = BankLength / sizeof(ADataType);
......@@ -297,7 +303,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
ADataType,
BDataType,
AccDataType,
EDataType,
AccDataType,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
......@@ -337,7 +343,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
BBlockLdsN1Padding,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl,
WorkspaceInOutScalarPerVector,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true,
true,
......@@ -349,7 +355,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
static constexpr auto MakeElementwiseInputSequence()
{
return generate_sequence_v2(
[&](auto) constexpr { return Number<CBlockTransferScalarPerVector_NWaveNPerXdl>{}; },
[&](auto) constexpr { return Number<WorkspaceInOutScalarPerVector>{}; },
Number<NumDTensor + 1>{});
}
......@@ -499,7 +505,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N<NDimSpatial>({}, {}));
using CDGridDesc_M_N = decltype(concat_tuple(Tuple<CGridDesc_M_N>{}, DsGridDesc_M_N{}));
using DsGridPointerTuple = decltype(GetDsGridPointerTuple());
using CDDataTypes = decltype(concat_tuple(Tuple<const EDataType*>{}, DsGridPointerTuple{}));
using CDDataTypes = decltype(concat_tuple(Tuple<const AccDataType*>{}, DsGridPointerTuple{}));
using EGridDesc_M_N = CGridDesc_M_N;
static constexpr index_t ClusterLengthMPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
......@@ -659,7 +665,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
std::size_t GetWorkspaceSizeBytes() const
{
return sizeof(EDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
}
const ADataType* p_a_grid_;
......@@ -738,7 +744,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
auto launch_gemm_kernel = [&](auto has_main_k_block_loop) {
EDataType* p_c_grid = type_convert<EDataType*>(arg.p_workspace_);
AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_;
......@@ -753,7 +759,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
GridwiseGemm,
ADataType,
BDataType,
EDataType,
AccDataType,
OutElementwiseOperation,
InElementwiseOperation,
element_wise::PassThrough,
......@@ -786,7 +792,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
};
auto launch_elementwise_kernel = [&]() {
const EDataType* p_c_grid = type_convert<const EDataType*>(arg.p_workspace_);
const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
const index_t grid_size =
arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) *
arg.Conv_G_;
......@@ -907,7 +913,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
}
// vector store C matrix into global memory
if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0 &&
arg.Conv_C_ % WorkspaceInOutScalarPerVector == 0))
{
return false;
}
......
......@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
// check device
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
......
......@@ -666,7 +666,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// check device
if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported()))
ck::is_gfx103_supported() || ck::is_gfx11_supported()))
{
return false;
}
......
......@@ -601,8 +601,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
namespace ctc = tensor_layout::convolution;
// check device
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
ck::is_navi3_supported()))
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
ck::is_gfx11_supported()))
{
return false;
}
......
......@@ -581,7 +581,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
namespace ctc = tensor_layout::convolution;
// check device
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
......
......@@ -553,24 +553,29 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}"
<< std::endl;
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1)
<< ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2)
<< "}" << std::endl;
std::cout << ", arg.b_grid_desc_k0_n_k1_{"
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}"
<< std::endl;
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1)
<< ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2)
<< "}" << std::endl;
std::cout << ", arg.e_grid_desc_m_n_{ "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
#endif
}
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
......@@ -668,7 +673,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
}
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported())
ck::is_gfx103_supported() || ck::is_gfx11_supported())
{
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
......
......@@ -467,7 +467,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
gemm_kernel_args_[i].block_start_ = block_start;
gemm_kernel_args_[i].block_end_ = block_end;
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
index_t tiles = (block_end - block_start) / K_BATCH;
std::cout << "block_start: " << block_start << "\n"
<< "block_end: " << block_end << "\n"
......@@ -478,7 +479,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
<< "K0Padded: " << karg.K0Padded << std::endl
<< "KBatch: " << karg.k_batch << std::endl
<< "grid_size_: " << karg.KPadded << std::endl;
#endif
}
}
}
......@@ -493,12 +494,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
arg.karg_.p_c_grid = p_workspace + offset;
index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
offset += tiles * MPerBlock * NPerBlock;
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "block_start: " << arg.block_start_ << "\n"
<< "block_end: " << arg.block_end_ << "\n"
<< "tiles: " << tiles << "\n"
<< "offset: " << offset << std::endl;
#endif
}
}
}
......@@ -816,11 +818,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
arg.skipped_group_count_) != arg.group_count_)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "The group count is not equal to sum of skipped groups "
"and kernel args size!"
<< std::endl;
#endif // DEBUG_LOG
}
return false;
}
......@@ -832,11 +835,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg);
if(not group_arg_valid)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "[" << __func__ << "] group id: " << i
<< " has invalid GridwiseGemm settings!" << std::endl;
gemm_arg.Print();
#endif // DEBUG_LOG
}
}
supported = supported && group_arg_valid;
}
......
......@@ -375,7 +375,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
std::vector<const void*>& /* p_Bs */,
std::vector<std::array<const void*, NumDTensor>>& /* p_Ds */,
std::vector<void*>& /* p_Es */,
std::vector<GemmDesc>& gemm_descs,
const std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
......@@ -620,11 +620,13 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
GridwiseGemm::template CheckTensorTransfersValidity<ALayout, BLayout, ELayout>(
M, N, K)))
{
#if DEBUG_LOG
std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << "," << K
<< "] are not supported by current template parameters!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
#endif
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << ","
<< K << "] are not supported by current template parameters!"
<< " In " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__;
}
supported = false;
}
}
......@@ -641,7 +643,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc> gemm_descs,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_elementwise_op,
BElementwiseOperation b_elementwise_op,
CDEElementwiseOperation cde_elementwise_op)
......
......@@ -514,7 +514,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
<< ", "
......@@ -535,7 +536,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
#endif
}
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_,
arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_,
......
......@@ -529,11 +529,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
arg.skipped_group_count_) != arg.group_count_)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "The group count is not equal to sum of skipped groups "
"and kernel args size!"
<< std::endl;
#endif // DEBUG_LOG
}
return false;
}
......@@ -544,11 +545,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
if(not group_arg_valid)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "[" << __func__ << "] group id: " << i
<< " has invalid GridwiseGemm settings!" << std::endl;
a.Print();
#endif // DEBUG_LOG
}
}
supported = supported && group_arg_valid;
}
......
......@@ -596,7 +596,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
static bool IsSupportedArgument(const RawArg& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......@@ -958,7 +958,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
#if 0
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......
......@@ -594,7 +594,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
static bool IsSupportedArgument(const RawArg& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......@@ -950,7 +950,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
#if 0
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......
......@@ -260,7 +260,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlo
};
// Grouped Rows of column-vectors WGP mapping
// Optimized for MI300-like multipe-die chip
// Optimized for gfx94x-like multipe-die chip
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
......
......@@ -935,12 +935,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
if(!(karg.M % MPerBlock == 0))
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG
}
return false;
}
}
......@@ -952,12 +952,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
if(!(karg.N % NPerBlock == 0))
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG
}
return false;
}
}
......@@ -971,12 +971,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto K_t = karg.KBatch * KPerBlock;
if(!(karg.K % K_t == 0))
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<< karg.K << " " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
}
return false;
}
}
......@@ -995,13 +995,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "Arg K (" << karg.K
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
}
return false;
}
}
......@@ -1009,13 +1009,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
}
return false;
}
}
......@@ -1024,13 +1024,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
}
return false;
}
}
......@@ -1038,13 +1038,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "Arg K (" << karg.K
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
}
return false;
}
}
......@@ -1053,14 +1053,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
return false;
}
}
......@@ -1068,25 +1069,26 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
return false;
}
}
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
#if DEBUG_LOG
std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
}
if(karg.KBatch > 1)
{
return false;
......
......@@ -1113,12 +1113,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
if(!(karg.M % MPerBlock == 0))
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG
}
return false;
}
}
......@@ -1130,12 +1130,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
if(!(karg.N % NPerBlock == 0))
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG
}
return false;
}
}
......@@ -1149,12 +1149,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto K_t = karg.KBatch * KPerBlock;
if(!(karg.K % K_t == 0))
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<< karg.K << " " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
}
return false;
}
}
......@@ -1173,13 +1173,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "Arg K (" << karg.K
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
}
return false;
}
}
......@@ -1187,13 +1187,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
}
return false;
}
}
......@@ -1202,13 +1202,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
}
return false;
}
}
......@@ -1216,13 +1216,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "Arg K (" << karg.K
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
}
return false;
}
}
......@@ -1231,14 +1231,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
return false;
}
}
......@@ -1246,14 +1247,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
return false;
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment