"vscode:/vscode.git/clone" did not exist on "7b01dbee0f878f0d6a54da3566401d8441a48233"
Commit 9b3c4ac4 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 1d784873 7843a8a7
...@@ -443,7 +443,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -443,7 +443,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::is_navi3_supported()) if(ck::is_gfx11_supported())
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> || if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>)) is_same_v<AccDataType, int32_t>))
......
...@@ -514,7 +514,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -514,7 +514,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{ {
std::cout << "arg.a_grid_desc_ak0_m_ak1_{" std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
...@@ -529,7 +529,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -529,7 +529,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
......
...@@ -299,7 +299,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout, ...@@ -299,7 +299,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
...@@ -312,7 +312,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout, ...@@ -312,7 +312,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
......
...@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
// check device // check device
if(ck::is_navi3_supported()) if(ck::is_gfx11_supported())
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>)) if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{ {
......
...@@ -197,6 +197,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -197,6 +197,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
K0PerBlock, K0PerBlock,
ConvBackwardWeightSpecialization>{}; ConvBackwardWeightSpecialization>{};
static constexpr index_t MaxScalarPerVectorFP32 = 4;
static constexpr index_t WorkspaceInOutScalarPerVector =
is_same_v<AccDataType, float>
? math::min(CBlockTransferScalarPerVector_NWaveNPerXdl, MaxScalarPerVectorFP32)
: CBlockTransferScalarPerVector_NWaveNPerXdl;
// Bytes per 32 lds bank: 32 * 4 bytes // Bytes per 32 lds bank: 32 * 4 bytes
static constexpr auto BankLength = 128; static constexpr auto BankLength = 128;
static constexpr auto ElePerBank = BankLength / sizeof(ADataType); static constexpr auto ElePerBank = BankLength / sizeof(ADataType);
...@@ -297,7 +303,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -297,7 +303,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
ADataType, ADataType,
BDataType, BDataType,
AccDataType, AccDataType,
EDataType, AccDataType,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
...@@ -337,7 +343,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -337,7 +343,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
BBlockLdsN1Padding, BBlockLdsN1Padding,
CShuffleMXdlPerWavePerShuffle, CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl, WorkspaceInOutScalarPerVector,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true, true,
true, true,
...@@ -349,7 +355,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -349,7 +355,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
static constexpr auto MakeElementwiseInputSequence() static constexpr auto MakeElementwiseInputSequence()
{ {
return generate_sequence_v2( return generate_sequence_v2(
[&](auto) constexpr { return Number<CBlockTransferScalarPerVector_NWaveNPerXdl>{}; }, [&](auto) constexpr { return Number<WorkspaceInOutScalarPerVector>{}; },
Number<NumDTensor + 1>{}); Number<NumDTensor + 1>{});
} }
...@@ -499,7 +505,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -499,7 +505,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N<NDimSpatial>({}, {})); using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N<NDimSpatial>({}, {}));
using CDGridDesc_M_N = decltype(concat_tuple(Tuple<CGridDesc_M_N>{}, DsGridDesc_M_N{})); using CDGridDesc_M_N = decltype(concat_tuple(Tuple<CGridDesc_M_N>{}, DsGridDesc_M_N{}));
using DsGridPointerTuple = decltype(GetDsGridPointerTuple()); using DsGridPointerTuple = decltype(GetDsGridPointerTuple());
using CDDataTypes = decltype(concat_tuple(Tuple<const EDataType*>{}, DsGridPointerTuple{})); using CDDataTypes = decltype(concat_tuple(Tuple<const AccDataType*>{}, DsGridPointerTuple{}));
using EGridDesc_M_N = CGridDesc_M_N; using EGridDesc_M_N = CGridDesc_M_N;
static constexpr index_t ClusterLengthMPerBlock = static constexpr index_t ClusterLengthMPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
...@@ -659,7 +665,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -659,7 +665,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
std::size_t GetWorkspaceSizeBytes() const std::size_t GetWorkspaceSizeBytes() const
{ {
return sizeof(EDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_; return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
} }
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
...@@ -738,7 +744,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -738,7 +744,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
auto launch_gemm_kernel = [&](auto has_main_k_block_loop) { auto launch_gemm_kernel = [&](auto has_main_k_block_loop) {
EDataType* p_c_grid = type_convert<EDataType*>(arg.p_workspace_); AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
const index_t grid_size = const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_; arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_;
...@@ -753,7 +759,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -753,7 +759,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
GridwiseGemm, GridwiseGemm,
ADataType, ADataType,
BDataType, BDataType,
EDataType, AccDataType,
OutElementwiseOperation, OutElementwiseOperation,
InElementwiseOperation, InElementwiseOperation,
element_wise::PassThrough, element_wise::PassThrough,
...@@ -786,7 +792,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -786,7 +792,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
}; };
auto launch_elementwise_kernel = [&]() { auto launch_elementwise_kernel = [&]() {
const EDataType* p_c_grid = type_convert<const EDataType*>(arg.p_workspace_); const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
const index_t grid_size = const index_t grid_size =
arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) *
arg.Conv_G_; arg.Conv_G_;
...@@ -907,7 +913,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -907,7 +913,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
} }
// vector store C matrix into global memory // vector store C matrix into global memory
if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0 &&
arg.Conv_C_ % WorkspaceInOutScalarPerVector == 0))
{ {
return false; return false;
} }
......
...@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle ...@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
// check device // check device
if(ck::is_navi3_supported()) if(ck::is_gfx11_supported())
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>)) if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{ {
......
...@@ -666,7 +666,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -666,7 +666,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// check device // check device
if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported())) ck::is_gfx103_supported() || ck::is_gfx11_supported()))
{ {
return false; return false;
} }
......
...@@ -601,8 +601,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -601,8 +601,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
namespace ctc = tensor_layout::convolution; namespace ctc = tensor_layout::convolution;
// check device // check device
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() || if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
ck::is_navi3_supported())) ck::is_gfx11_supported()))
{ {
return false; return false;
} }
......
...@@ -581,7 +581,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -581,7 +581,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
namespace ctc = tensor_layout::convolution; namespace ctc = tensor_layout::convolution;
// check device // check device
if(ck::is_navi3_supported()) if(ck::is_gfx11_supported())
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>)) if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{ {
......
...@@ -553,24 +553,29 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout, ...@@ -553,24 +553,29 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{" {
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", " std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0)
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << ", "
<< std::endl; << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1)
<< ", "
std::cout << ", arg.b_grid_desc_k0_n_k1_{" << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2)
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", " << "}" << std::endl;
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}" std::cout << ", arg.b_grid_desc_k0_n_k1_{"
<< std::endl; << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", "
std::cout << ", arg.e_grid_desc_m_n_{ " << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1)
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", " << ", "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}" << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2)
<< std::endl; << "}" << std::endl;
#endif
std::cout << ", arg.e_grid_desc_m_n_{ "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
}
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_, arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
...@@ -668,7 +673,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout, ...@@ -668,7 +673,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
} }
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported()) ck::is_gfx103_supported() || ck::is_gfx11_supported())
{ {
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{ {
......
...@@ -467,18 +467,19 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -467,18 +467,19 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
gemm_kernel_args_[i].block_start_ = block_start; gemm_kernel_args_[i].block_start_ = block_start;
gemm_kernel_args_[i].block_end_ = block_end; gemm_kernel_args_[i].block_end_ = block_end;
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
index_t tiles = (block_end - block_start) / K_BATCH; {
std::cout << "block_start: " << block_start << "\n" index_t tiles = (block_end - block_start) / K_BATCH;
<< "block_end: " << block_end << "\n" std::cout << "block_start: " << block_start << "\n"
<< "tiles: " << tiles << std::endl << "block_end: " << block_end << "\n"
<< std::endl; << "tiles: " << tiles << std::endl
<< std::endl;
std::cout << "KPadded: " << karg.KPadded << std::endl
<< "K0Padded: " << karg.K0Padded << std::endl std::cout << "KPadded: " << karg.KPadded << std::endl
<< "KBatch: " << karg.k_batch << std::endl << "K0Padded: " << karg.K0Padded << std::endl
<< "grid_size_: " << karg.KPadded << std::endl; << "KBatch: " << karg.k_batch << std::endl
#endif << "grid_size_: " << karg.KPadded << std::endl;
}
} }
} }
...@@ -493,12 +494,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -493,12 +494,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
arg.karg_.p_c_grid = p_workspace + offset; arg.karg_.p_c_grid = p_workspace + offset;
index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch; index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
offset += tiles * MPerBlock * NPerBlock; offset += tiles * MPerBlock * NPerBlock;
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "block_start: " << arg.block_start_ << "\n" {
<< "block_end: " << arg.block_end_ << "\n" std::cout << "block_start: " << arg.block_start_ << "\n"
<< "tiles: " << tiles << "\n" << "block_end: " << arg.block_end_ << "\n"
<< "offset: " << offset << std::endl; << "tiles: " << tiles << "\n"
#endif << "offset: " << offset << std::endl;
}
} }
} }
...@@ -816,11 +818,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -816,11 +818,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) + if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
arg.skipped_group_count_) != arg.group_count_) arg.skipped_group_count_) != arg.group_count_)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "The group count is not equal to sum of skipped groups " {
"and kernel args size!" std::cout << "The group count is not equal to sum of skipped groups "
<< std::endl; "and kernel args size!"
#endif // DEBUG_LOG << std::endl;
}
return false; return false;
} }
...@@ -832,11 +835,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -832,11 +835,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg); bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg);
if(not group_arg_valid) if(not group_arg_valid)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "[" << __func__ << "] group id: " << i {
<< " has invalid GridwiseGemm settings!" << std::endl; std::cout << "[" << __func__ << "] group id: " << i
gemm_arg.Print(); << " has invalid GridwiseGemm settings!" << std::endl;
#endif // DEBUG_LOG gemm_arg.Print();
}
} }
supported = supported && group_arg_valid; supported = supported && group_arg_valid;
} }
......
...@@ -375,7 +375,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop ...@@ -375,7 +375,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
std::vector<const void*>& /* p_Bs */, std::vector<const void*>& /* p_Bs */,
std::vector<std::array<const void*, NumDTensor>>& /* p_Ds */, std::vector<std::array<const void*, NumDTensor>>& /* p_Ds */,
std::vector<void*>& /* p_Es */, std::vector<void*>& /* p_Es */,
std::vector<GemmDesc>& gemm_descs, const std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op, CDEElementwiseOperation cde_element_op,
...@@ -620,11 +620,13 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop ...@@ -620,11 +620,13 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
GridwiseGemm::template CheckTensorTransfersValidity<ALayout, BLayout, ELayout>( GridwiseGemm::template CheckTensorTransfersValidity<ALayout, BLayout, ELayout>(
M, N, K))) M, N, K)))
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << "," << K {
<< "] are not supported by current template parameters!" std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << ","
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; << K << "] are not supported by current template parameters!"
#endif << " In " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__;
}
supported = false; supported = false;
} }
} }
...@@ -641,7 +643,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop ...@@ -641,7 +643,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
std::vector<const void*>& p_Bs, std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds, std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es, std::vector<void*>& p_Es,
std::vector<GemmDesc> gemm_descs, std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_elementwise_op, AElementwiseOperation a_elementwise_op,
BElementwiseOperation b_elementwise_op, BElementwiseOperation b_elementwise_op,
CDEElementwiseOperation cde_elementwise_op) CDEElementwiseOperation cde_elementwise_op)
......
...@@ -514,28 +514,29 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -514,28 +514,29 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{" {
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
<< ", " << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< ", " << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I1)
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2) << ", "
<< "}"; << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2)
<< "}";
std::cout << ", arg.b_grid_desc_bk0_n_bk1_{"
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I0) std::cout << ", arg.b_grid_desc_bk0_n_bk1_{"
<< ", " << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I0)
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< ", " << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I1)
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I2) << ", "
<< "}"; << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I2)
<< "}";
std::cout << ", arg.e_grid_desc_m_n_{ "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", " std::cout << ", arg.e_grid_desc_m_n_{ "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}" << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
<< std::endl; << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
#endif << std::endl;
}
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_, if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_,
arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_, arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_,
......
...@@ -529,11 +529,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -529,11 +529,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) + if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
arg.skipped_group_count_) != arg.group_count_) arg.skipped_group_count_) != arg.group_count_)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "The group count is not equal to sum of skipped groups " {
"and kernel args size!" std::cout << "The group count is not equal to sum of skipped groups "
<< std::endl; "and kernel args size!"
#endif // DEBUG_LOG << std::endl;
}
return false; return false;
} }
...@@ -544,11 +545,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -544,11 +545,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
bool group_arg_valid = GridwiseGemm::CheckValidity(a); bool group_arg_valid = GridwiseGemm::CheckValidity(a);
if(not group_arg_valid) if(not group_arg_valid)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "[" << __func__ << "] group id: " << i {
<< " has invalid GridwiseGemm settings!" << std::endl; std::cout << "[" << __func__ << "] group id: " << i
a.Print(); << " has invalid GridwiseGemm settings!" << std::endl;
#endif // DEBUG_LOG a.Print();
}
} }
supported = supported && group_arg_valid; supported = supported && group_arg_valid;
} }
......
...@@ -596,7 +596,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma ...@@ -596,7 +596,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
static bool IsSupportedArgument(const RawArg& arg) static bool IsSupportedArgument(const RawArg& arg)
{ {
if(ck::is_navi3_supported()) if(ck::is_gfx11_supported())
{ {
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>)) if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{ {
...@@ -958,7 +958,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma ...@@ -958,7 +958,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
#if 0 #if 0
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::is_navi3_supported()) if(ck::is_gfx11_supported())
{ {
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>)) if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{ {
......
...@@ -594,7 +594,7 @@ struct DeviceMultiQueryAttentionForward_Wmma ...@@ -594,7 +594,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
static bool IsSupportedArgument(const RawArg& arg) static bool IsSupportedArgument(const RawArg& arg)
{ {
if(ck::is_navi3_supported()) if(ck::is_gfx11_supported())
{ {
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>)) if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{ {
...@@ -950,7 +950,7 @@ struct DeviceMultiQueryAttentionForward_Wmma ...@@ -950,7 +950,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
#if 0 #if 0
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::is_navi3_supported()) if(ck::is_gfx11_supported())
{ {
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>)) if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{ {
......
...@@ -260,7 +260,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlo ...@@ -260,7 +260,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlo
}; };
// Grouped Rows of column-vectors WGP mapping // Grouped Rows of column-vectors WGP mapping
// Optimized for MI300-like multipe-die chip // Optimized for gfx94x-like multipe-die chip
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock> template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
......
...@@ -935,12 +935,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -935,12 +935,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(!(karg.M % MPerBlock == 0)) if(!(karg.M % MPerBlock == 0))
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " {
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
<< std::endl; << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG }
return false; return false;
} }
} }
...@@ -952,12 +952,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -952,12 +952,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(!(karg.N % NPerBlock == 0)) if(!(karg.N % NPerBlock == 0))
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " {
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
<< std::endl; << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG }
return false; return false;
} }
} }
...@@ -971,12 +971,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -971,12 +971,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto K_t = karg.KBatch * KPerBlock; auto K_t = karg.KBatch * KPerBlock;
if(!(karg.K % K_t == 0)) if(!(karg.K % K_t == 0))
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " {
<< karg.K << " " << __FILE__ << ":" << __LINE__ std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<< ", in function: " << __func__ << std::endl; << karg.K << " " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG }
return false; return false;
} }
} }
...@@ -995,13 +995,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -995,13 +995,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.K % ABlockTransferSrcScalarPerVector != 0) if(karg.K % ABlockTransferSrcScalarPerVector != 0)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "Arg K (" << karg.K {
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector (" std::cout << "Arg K (" << karg.K
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< __LINE__ << ", in function: " << __func__ << std::endl; << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG }
return false; return false;
} }
} }
...@@ -1009,13 +1009,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1009,13 +1009,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.M % ABlockTransferSrcScalarPerVector != 0) if(karg.M % ABlockTransferSrcScalarPerVector != 0)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "Arg M (" << karg.M {
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector (" std::cout << "Arg M (" << karg.M
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< __LINE__ << ", in function: " << __func__ << std::endl; << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG }
return false; return false;
} }
} }
...@@ -1024,13 +1024,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1024,13 +1024,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.N % BBlockTransferSrcScalarPerVector != 0) if(karg.N % BBlockTransferSrcScalarPerVector != 0)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "Arg N (" << karg.N {
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector (" std::cout << "Arg N (" << karg.N
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< __LINE__ << ", in function: " << __func__ << std::endl; << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG }
return false; return false;
} }
} }
...@@ -1038,13 +1038,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1038,13 +1038,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.K % BBlockTransferSrcScalarPerVector != 0) if(karg.K % BBlockTransferSrcScalarPerVector != 0)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "Arg K (" << karg.K {
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector (" std::cout << "Arg K (" << karg.K
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< __LINE__ << ", in function: " << __func__ << std::endl; << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG }
return false; return false;
} }
} }
...@@ -1053,14 +1053,15 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1053,14 +1053,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "Arg N (" << karg.N {
<< ") value is not a multiple of " std::cout << "Arg N (" << karg.N
"CShuffleBlockTransferScalarPerVector_NPerBlock (" << ") value is not a multiple of "
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ "CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl; << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
#endif // DEBUG_LOG << std::endl;
}
return false; return false;
} }
} }
...@@ -1068,25 +1069,26 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1068,25 +1069,26 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "Arg M (" << karg.M {
<< ") value is not a multiple of " std::cout << "Arg M (" << karg.M
"CShuffleBlockTransferScalarPerVector_NPerBlock (" << ") value is not a multiple of "
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ "CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl; << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
#endif // DEBUG_LOG << std::endl;
}
return false; return false;
} }
} }
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value) if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ << ":" {
<< __LINE__ << ", in function: " << __func__ << std::endl; std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG }
if(karg.KBatch > 1) if(karg.KBatch > 1)
{ {
return false; return false;
......
...@@ -1113,12 +1113,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1113,12 +1113,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(!(karg.M % MPerBlock == 0)) if(!(karg.M % MPerBlock == 0))
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " {
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
<< std::endl; << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG }
return false; return false;
} }
} }
...@@ -1130,12 +1130,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1130,12 +1130,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(!(karg.N % NPerBlock == 0)) if(!(karg.N % NPerBlock == 0))
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " {
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
<< std::endl; << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG }
return false; return false;
} }
} }
...@@ -1149,12 +1149,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1149,12 +1149,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto K_t = karg.KBatch * KPerBlock; auto K_t = karg.KBatch * KPerBlock;
if(!(karg.K % K_t == 0)) if(!(karg.K % K_t == 0))
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " {
<< karg.K << " " << __FILE__ << ":" << __LINE__ std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<< ", in function: " << __func__ << std::endl; << karg.K << " " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG }
return false; return false;
} }
} }
...@@ -1173,13 +1173,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1173,13 +1173,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.K % ABlockTransferSrcScalarPerVector != 0) if(karg.K % ABlockTransferSrcScalarPerVector != 0)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "Arg K (" << karg.K {
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector (" std::cout << "Arg K (" << karg.K
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< __LINE__ << ", in function: " << __func__ << std::endl; << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG }
return false; return false;
} }
} }
...@@ -1187,13 +1187,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1187,13 +1187,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.M % ABlockTransferSrcScalarPerVector != 0) if(karg.M % ABlockTransferSrcScalarPerVector != 0)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "Arg M (" << karg.M {
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector (" std::cout << "Arg M (" << karg.M
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< __LINE__ << ", in function: " << __func__ << std::endl; << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG }
return false; return false;
} }
} }
...@@ -1202,13 +1202,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1202,13 +1202,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.N % BBlockTransferSrcScalarPerVector != 0) if(karg.N % BBlockTransferSrcScalarPerVector != 0)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "Arg N (" << karg.N {
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector (" std::cout << "Arg N (" << karg.N
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< __LINE__ << ", in function: " << __func__ << std::endl; << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG }
return false; return false;
} }
} }
...@@ -1216,13 +1216,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1216,13 +1216,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.K % BBlockTransferSrcScalarPerVector != 0) if(karg.K % BBlockTransferSrcScalarPerVector != 0)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "Arg K (" << karg.K {
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector (" std::cout << "Arg K (" << karg.K
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< __LINE__ << ", in function: " << __func__ << std::endl; << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG }
return false; return false;
} }
} }
...@@ -1231,14 +1231,15 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1231,14 +1231,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "Arg N (" << karg.N {
<< ") value is not a multiple of " std::cout << "Arg N (" << karg.N
"CShuffleBlockTransferScalarPerVector_NPerBlock (" << ") value is not a multiple of "
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ "CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl; << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
#endif // DEBUG_LOG << std::endl;
}
return false; return false;
} }
} }
...@@ -1246,14 +1247,15 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1246,14 +1247,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "Arg M (" << karg.M {
<< ") value is not a multiple of " std::cout << "Arg M (" << karg.M
"CShuffleBlockTransferScalarPerVector_NPerBlock (" << ") value is not a multiple of "
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ "CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl; << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
#endif // DEBUG_LOG << std::endl;
}
return false; return false;
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment