Commit 036c5234 authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2

parents 22995e9a 7843a8a7
......@@ -658,27 +658,28 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl;
{
std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl;
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0)
<< "}" << std::endl;
std::cout << "arg.reduce_grid_desc_m_{ "
<< arg.reduce_grid_desc_m_.GetLength(I0) << "}" << std::endl;
}
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
......
......@@ -858,7 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static bool IsSupportedArgument(const RawArg& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......@@ -1435,7 +1435,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
#if 0
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......
......@@ -719,9 +719,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
#if DEBUG_LOG
arg.Print();
#endif
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
arg.Print();
}
if(!ck::is_xdl_supported())
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -627,7 +627,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
const bool valid_a_access_dim =
valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1;
if(!(valid_a_vector_size && valid_a_access_dim))
{
return false;
......@@ -637,7 +638,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
const bool valid_b_access_dim =
valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1;
if(!(valid_b_vector_size && valid_b_access_dim))
{
return false;
......@@ -648,7 +650,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const bool valid_d_vector_size =
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector read of Ds is always on N dimension.
const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
const bool valid_d_access_dim =
arg.ds_nz_consecutive_[i] || CDEBlockTransferScalarPerVector_NPerBlock == 1;
if(!(valid_d_vector_size && valid_d_access_dim))
{
valid_ds_access = false;
......@@ -662,7 +665,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const bool valid_e_vector_size =
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector write of E is always on N dimension.
const bool valid_e_access_dim = arg.e_nz_consecutive_;
const bool valid_e_access_dim =
arg.e_nz_consecutive_ || CDEBlockTransferScalarPerVector_NPerBlock == 1;
if(!(valid_e_vector_size && valid_e_access_dim))
{
return false;
......
......@@ -516,26 +516,27 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
<< std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_container_{"
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
<< std::endl;
std::cout << "arg.c_grid_desc_m_n_container_{ "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
<< std::endl;
{
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
<< std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_container_{"
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
<< std::endl;
std::cout << "arg.c_grid_desc_m_n_container_{ "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
<< std::endl;
}
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
......
......@@ -644,7 +644,7 @@ struct
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << DeviceOp{}.GetTypeString() << std::endl;
std::cout << "N " << arg.Conv_N_ << ", "
......@@ -664,9 +664,7 @@ struct
<< arg.input_left_pads_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
<< arg.input_right_pads_[1] << ", " << std::endl;
}
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
......@@ -684,7 +682,6 @@ struct
std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
......
......@@ -614,7 +614,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << DeviceOp{}.GetTypeString() << std::endl;
std::cout << "N " << arg.Conv_N_ << ", "
......@@ -634,9 +634,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
<< arg.input_left_pads_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
<< arg.input_right_pads_[1] << ", " << std::endl;
}
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
......@@ -651,7 +649,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
......
......@@ -579,7 +579,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << DeviceOp{}.GetTypeString() << std::endl;
std::cout << "N " << arg.Conv_N_ << ", "
......@@ -599,9 +599,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
<< arg.input_left_pads_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
<< arg.input_right_pads_[1] << ", " << std::endl;
}
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
......@@ -635,7 +633,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
.GetLength(I5)
<< "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
......
......@@ -431,7 +431,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
......@@ -444,7 +444,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
{
......
......@@ -401,7 +401,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl;
std::cout << "a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
......@@ -415,7 +415,6 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
std::cout << "c_grid_desc_m_n{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
......
......@@ -1272,7 +1272,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
......@@ -1305,7 +1305,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I5)
<< " ) " << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
......@@ -1393,8 +1392,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
static bool IsSupportedArgument(const Argument& arg)
{
// check device
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
ck::is_navi3_supported()))
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
ck::is_gfx11_supported()))
{
return false;
}
......
......@@ -1220,7 +1220,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "arg.a_grid_desc_k0_m_k1{"
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
......@@ -1239,7 +1239,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
<< std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
......
......@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>))
......
......@@ -334,7 +334,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
......@@ -349,7 +349,6 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
......@@ -536,8 +535,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
}
}
if(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
ck::is_navi3_supported())
if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
ck::is_gfx11_supported())
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
......
......@@ -168,7 +168,7 @@ struct DeviceGemmDpp : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& karg)
{
if(ck::is_navi2_supported() || ck::is_navi3_supported())
if(ck::is_gfx103_supported() || ck::is_gfx11_supported())
{
return GridwiseGemm::CheckValidity(karg);
}
......
......@@ -10,110 +10,30 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
template <typename GridwiseGemm,
typename AsPointer,
typename BsPointer,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AsGridDesc_AK0_M_AK1,
typename BsGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_multiple_abd_xdl_cshuffle(
AsPointer p_as_grid,
BsPointer p_bs_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid,
p_bs_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
#else
ignore = p_as_grid;
ignore = p_bs_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = as_grid_desc_ak0_m_ak1;
ignore = bs_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_etile_map;
#endif
}
} // namespace ck
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename CLayout,
typename AsDataType,
typename BsDataType,
typename AccDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
......@@ -132,59 +52,56 @@ template <typename AsLayout,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
CLayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
CElementwiseOperation>
{
using DeviceOp = DeviceGemmMultipleABD_Xdl_CShuffle;
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ComputeDataType = EDataType;
using ALayout = remove_cvref_t<tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<tuple_element_t<0, BsLayout>>;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle<
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
ALayout,
BLayout,
CLayout,
AsDataType,
BsDataType,
ComputeDataType,
AccDataType,
GemmAccDataType,
CShuffleDataType,
DsDataType,
EDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
NumGemmKPrefetchStage,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
......@@ -213,360 +130,476 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer>;
// desc for problem definition
using AsGridDesc_M_K =
remove_cvref_t<decltype(GridwiseGemm::template MakeAsGridDescriptor_M_K<AsLayout, GemmSpec>(
{}, {}, {}))>;
using BsGridDesc_N_K =
remove_cvref_t<decltype(GridwiseGemm::template MakeBsGridDescriptor_N_K<BsLayout, GemmSpec>(
{}, {}, {}))>;
using DsGridDesc_M_N =
remove_cvref_t<decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>(
{}, {}, {}))>;
using EGridDesc_M_N =
decltype(GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(1, 1, 1));
// desc for blockwise copy
using AsGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(
AsGridDesc_M_K{}))>;
using BsGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(
BsGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument
struct Argument : public BaseArgument
{
Argument(std::array<const void*, NumATensor> p_as_grid,
std::array<const void*, NumBTensor> p_bs_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
std::array<index_t, NumATensor> StrideAs,
std::array<index_t, NumBTensor> StrideBs,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: p_as_grid_{},
p_bs_grid_{},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
as_grid_desc_m_k_{},
bs_grid_desc_n_k_{},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
MRaw, NRaw, StrideE)},
as_grid_desc_ak0_m_ak1_{},
bs_grid_desc_bk0_n_bk1_{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
MRaw_{MRaw},
NRaw_{NRaw},
KRaw_{KRaw}
{
// populate pointer, desc for As
static_for<0, NumATensor, 1>{}([&](auto i) {
using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
// A pointer
p_as_grid_(i) = static_cast<const ADataType*>(p_as_grid[i]);
// A desc
as_grid_desc_m_k_(i) =
GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(
MRaw, KRaw, StrideAs[i]);
});
// populate pointer, desc for Bs
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
// B pointer
p_bs_grid_(i) = static_cast<const BDataType*>(p_bs_grid[i]);
// B desc
bs_grid_desc_n_k_(i) =
GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(
NRaw, KRaw, StrideBs[i]);
});
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
// D desc
ds_grid_desc_m_n_(i) =
GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(
MRaw, NRaw, StrideDs[i]);
});
// populate desc for Ds/E
if(GridwiseGemm::CheckValidity(as_grid_desc_m_k_,
bs_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
{
as_grid_desc_ak0_m_ak1_ =
GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_);
bs_grid_desc_bk0_n_bk1_ =
GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
}
}
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
// private:
// pointers
typename GridwiseGemm::AsGridPointer p_as_grid_;
typename GridwiseGemm::BsGridPointer p_bs_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// tensor descriptors for problem definiton
AsGridDesc_M_K as_grid_desc_m_k_;
BsGridDesc_N_K bs_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1_;
BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1_;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// for checking vector load/store
index_t MRaw_;
index_t NRaw_;
index_t KRaw_;
};
using Argument = typename GridwiseGemm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_,
arg.bs_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_gemm_multiple_abd_xdl_cshuffle<
GridwiseGemm,
typename GridwiseGemm::AsGridPointer,
typename GridwiseGemm::BsGridPointer,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AsGridDesc_AK0_M_AK1,
DeviceOp::BsGridDesc_BK0_N_BK1,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::Block2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_as_grid_,
arg.p_bs_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.as_grid_desc_ak0_m_ak1_,
arg.bs_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_);
};
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
const auto K = arg.as_grid_desc_m_k_[I0].GetLength(I1);
float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
const auto Run = [&](const auto& kernel) {
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
// check vector load/store
{
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
};
bool all_valid = true;
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
// check vector load of A
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
#if 0
if(arg.KBatch > 1)
{
all_valid = false;
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
}
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
{
// FIXME: not rigorous
if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
else
#endif
{
all_valid = false;
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
else
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(ABlockTransferSrcScalarPerVector != 1)
#if 0
if(arg.KBatch > 1)
{
all_valid = false;
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
}
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
// check vector laod of B
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
else
#endif
{
all_valid = false;
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
}
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
// FIXME: not rigorous
if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
#if 0
if(arg.KBatch > 1)
{
all_valid = false;
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
#endif
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(BBlockTransferSrcScalarPerVector != 1)
#if 0
if(arg.KBatch > 1)
{
all_valid = false;
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
#endif
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
});
// check vector load of Ds
// only support RowMajor for now
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(!is_same_v<DLayout, Row>)
{
all_valid = false;
}
});
// check vector store of E
// only support RowMajor for now
if constexpr(is_same_v<ELayout, Row>)
{
if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
all_valid = false;
}
}
else
{
all_valid = false;
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
#if 0
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
#endif
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
if(!all_valid)
{
return false;
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
return GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_,
arg.bs_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
......@@ -588,8 +621,27 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
CElementwiseOperation c_element_op)
{
static_for<0, NumATensor, 1>{}([&](auto i) {
using ALayout_ = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
static_assert(is_same<ALayout_, ALayout>::value, "");
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BLayout_ = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
static_assert(is_same<BLayout_, BLayout>::value, "");
});
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout_ = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
static_assert(is_same<DLayout_, CLayout>::value, "");
});
return Argument{p_as,
p_bs,
p_ds,
......@@ -601,29 +653,29 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
StrideBs,
StrideDs,
StrideE,
1,
a_element_op,
b_element_op,
cde_element_op};
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
std::array<const void*, NumBTensor> p_bs,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t MRaw,
index_t NRaw,
index_t KRaw,
std::array<ck::index_t, NumATensor> StrideAs,
std::array<ck::index_t, NumBTensor> StrideBs,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
std::array<const void*, NumBTensor> p_bs,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t MRaw,
index_t NRaw,
index_t KRaw,
std::array<ck::index_t, NumATensor> StrideAs,
std::array<ck::index_t, NumBTensor> StrideBs,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(p_as,
p_bs,
......@@ -636,9 +688,10 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
StrideBs,
StrideDs,
StrideE,
1,
a_element_op,
b_element_op,
cde_element_op);
c_element_op);
}
// polymorphic
......@@ -652,35 +705,41 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemmMultipleABD_Xdl_CShuffle"
str << "DeviceGemmXdlUniversal"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerXDL<<"x"<<NPerXDL << ", "
<< "WaveMap: "
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
// clang-format on
return str.str();
......
......@@ -552,7 +552,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported())
ck::is_gfx103_supported() || ck::is_gfx11_supported())
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
......
......@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
......
......@@ -510,7 +510,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
......@@ -528,7 +528,6 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0)
<< "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
......
......@@ -443,7 +443,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment