Commit a3b86965 authored by aska-0096

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into lds_bypass_spilling
parents bdd0f64e fe96e8fb
@@ -669,7 +669,15 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
             << KPerBlock << ", "
             << AK1 << ", "
             << BK1 << ", "
-            << getGemmSpecializationString(GemmSpec)
+            << getGemmSpecializationString(GemmSpec) << ", "
+            << MPerXDL << ", "
+            << NPerXDL << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle
             << ">";
         // clang-format on
...
@@ -86,120 +86,84 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
     // K1 = Max Vector Access Pixels
     static constexpr auto K1Number = Number<K1>{};

-    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
-    {
-        assert(K % K1 == 0);
-
-        const index_t K0 = K / K1;
-
-        const auto a_grid_desc_m_k = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
-            }
-#ifdef ENABLE_COLMAJOR
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
-            }
-#endif
-        }();
-
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-
-            return transform_tensor_descriptor(
-                a_grid_desc_m_k,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_right_pad_transform(M, PadM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                a_grid_desc_m_k,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_pass_through_transform(M)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-    }
-
-    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
-    {
-        assert(K % K1 == 0);
-
-        const index_t K0 = K / K1;
-
-        const auto b_grid_desc_k_n = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
-            }
-        }();
-
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-
-            return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_right_pad_transform(N, PadN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_pass_through_transform(N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-    }
+    static constexpr auto matrix_padder =
+        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock * K1};
+
+    static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA)
+    {
+        const auto a_grid_desc_mraw_kraw = [&]() {
+            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
+                                                    make_tuple(StrideA, I1));
+            }
+            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
+                                                    make_tuple(I1, StrideA));
+            }
+        }();
+
+        const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
+
+        const auto M = a_grid_desc_m_k.GetLength(I0);
+        const auto K = a_grid_desc_m_k.GetLength(I1);
+
+        assert(K % K1 == 0);
+        const index_t K0 = K / K1;
+
+        return transform_tensor_descriptor(
+            a_grid_desc_m_k,
+            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                       make_pass_through_transform(M)),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    }
+
+    static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
+    {
+        const auto b_grid_desc_nraw_kraw = [&]() {
+            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                    make_tuple(I1, StrideB));
+            }
+            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                    make_tuple(StrideB, I1));
+            }
+        }();
+
+        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
+
+        const auto N = b_grid_desc_n_k.GetLength(I0);
+        const auto K = b_grid_desc_n_k.GetLength(I1);
+
+        assert(K % K1 == 0);
+        const index_t K0 = K / K1;
+
+        return transform_tensor_descriptor(
+            b_grid_desc_n_k,
+            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                       make_pass_through_transform(N)),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    }

     template <typename ELayout_>
-    static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
+    static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
     {
-        const auto e_grid_desc_m_n = [&]() {
+        const auto e_grid_desc_mraw_nraw = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout_>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
+                                                    make_tuple(StrideE, I1));
             }
             else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout_>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
+                                                    make_tuple(I1, StrideE));
             }
         }();

-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-
-            return transform_tensor_descriptor(
-                e_grid_desc_m_n,
-                make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                e_grid_desc_m_n,
-                make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
+        return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
     }

     static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& Ms,
@@ -512,7 +476,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(ck::get_device_name() == "gfx1100")
+        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
+           ck::get_device_name() == "gfx1102")
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
             {
...
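Both the removed MNPadding branches and the MatrixPadder-based replacement round each raw GEMM dimension up to the next multiple of its tile size. Below is a minimal standalone sketch of that arithmetic, mirroring the old PadM/PadN formula; the function names and values are illustrative and not part of the CK API.

#include <cassert>
#include <iostream>

// Round a raw GEMM dimension up to the next multiple of the block/tile size,
// mirroring `PadM = (MPerBlock - M % MPerBlock) % MPerBlock` from the removed code.
long pad_to_multiple(long raw, long per_block)
{
    assert(per_block > 0);
    const long pad = (per_block - raw % per_block) % per_block;
    return raw + pad; // pad == 0 when raw is already a multiple of per_block
}

int main()
{
    // e.g. M = 1000 with MPerBlock = 128 is padded to 1024; 1024 stays 1024.
    std::cout << pad_to_multiple(1000, 128) << "\n"; // 1024
    std::cout << pad_to_multiple(1024, 128) << "\n"; // 1024
}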
@@ -680,6 +680,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
             << KPerBlock << ", "
             << AK1 << ", "
             << BK1 << ", "
+            << MPerXDL << ", "
+            << NPerXDL << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle << ", "
             << getGemmSpecializationString(GemmSpec)
             << ">"
             << " LoopScheduler: "
...
@@ -822,7 +822,15 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
             << NPerBlock << ", "
             << KPerBlock << ", "
             << AK1 << ", "
-            << BK1
+            << BK1 << ", "
+            << MPerXDL << ", "
+            << NPerXDL << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle
             << ">";
         // clang-format on
...
@@ -12,6 +12,7 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_gemm.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -102,12 +103,11 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
             }
-#ifdef ENABLE_COLMAJOR
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
+            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
             {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
+                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
+                                                    make_tuple(I1, StrideA));
             }
-#endif
         }();

         const auto M = a_grid_desc_m_k.GetLength(I0);
@@ -150,7 +150,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                     make_tuple(I1, StrideB));
             }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
+            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
             {
                 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                     make_tuple(StrideB, I1));
@@ -445,7 +445,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(ck::get_device_name() == "gfx1100")
+        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
+           ck::get_device_name() == "gfx1102")
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
             {
...
@@ -77,8 +77,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
     static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
     {
-        assert(K % K1 == 0);
-
         const index_t K0 = K / K1;

         const auto a_grid_desc_m_k = [&]() {
@@ -116,8 +114,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
     static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
     {
-        assert(K % K1 == 0);
-
         const index_t K0 = K / K1;

         const auto b_grid_desc_k_n = [&]() {
@@ -551,7 +547,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
             << MPerXDL << ", "
             << NPerXDL << ", "
             << MXdlPerWave << ", "
-            << NXdlPerWave
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << ABlockTransferDstScalarPerVector_K1 << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferDstScalarPerVector_K1
             << ">"
             << " NumPrefetch: "
             << NumPrefetch << ", "
...
@@ -682,7 +682,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
             << NPerBlock << ", "
             << KPerBlock << ", "
             << AK1 << ", "
-            << BK1
+            << BK1 << ", "
+            << MPerXDL << ", "
+            << NPerXDL << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle
             << ">"
             << " LoopScheduler: "
             << LoopSchedToString[LoopSched] << ", "
...
@@ -760,7 +760,15 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
             << NPerBlock << ", "
             << KPerBlock << ", "
             << AK1 << ", "
-            << BK1
+            << BK1 << ", "
+            << MPerXDL << ", "
+            << NPerXDL << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle
             << ">";
         // clang-format on
...
@@ -640,7 +640,16 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
             << BlockSize << ", "
             << MPerBlock << ", "
             << NPerBlock << ", "
-            << K0PerBlock
+            << K0PerBlock << ", "
+            << K1 << ", "
+            << MPerXDL << ", "
+            << NPerXDL << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << ABlockTransferDstScalarPerVector_K1 << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferDstScalarPerVector_K1
             << ">";
         // clang-format on
...
@@ -1003,7 +1003,15 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
             << KPerBlock << ", "
             << AK1 << ", "
             << BK1 << ", "
-            << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization)
+            << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", "
+            << MPerXDL << ", "
+            << NPerXDL << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle
             << ">";

         return str.str();
...
@@ -1203,7 +1203,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
             << MPerBlock << ", "
             << NPerBlock << ", "
             << K0PerBlock << ", "
-            << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization)
+            << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
+            << K1
             << ">";
         // clang-format on
...
@@ -1231,7 +1231,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
             << MPerBlock << ", "
             << NPerBlock << ", "
             << K0PerBlock << ", "
-            << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization)
+            << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
+            << K1 << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << ABlockTransferDstScalarPerVector_K1 << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferDstScalarPerVector_K1 << ", "
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle << ", "
+            << CBlockTransferScalarPerVector_NWaveNPerXdl
             << ">";
         // clang-format on
...
@@ -1092,7 +1092,15 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
             << MPerBlock << ", "
             << NPerBlock << ", "
             << KPerBlock << ", "
-            << getConvForwardSpecializationString(ConvForwardSpecialization)
+            << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
+            << MPerXDL << ", "
+            << NPerXDL << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle
             << ">";
         // clang-format on
...
@@ -579,7 +579,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
     namespace ctc = tensor_layout::convolution;

     // check device
-    if(get_device_name() == "gfx1100")
+    if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" ||
+       ck::get_device_name() == "gfx1102")
     {
         if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
         {
@@ -837,7 +838,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
             << MPerBlock << ", "
             << NPerBlock << ", "
             << KPerBlock << ", "
-            << getConvForwardSpecializationString(ConvForwardSpecialization)
+            << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
+            << K1 << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferSrcScalarPerVector
             << ">";
         // clang-format on
...
@@ -939,7 +939,15 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
             << MPerBlock << ", "
             << NPerBlock << ", "
             << KPerBlock << ", "
-            << getConvForwardSpecializationString(ConvForwardSpecialization)
+            << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
+            << MPerXDL << ", "
+            << NPerXDL << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle
             << ">";
         // clang-format on
...
@@ -381,6 +381,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
             const index_t N = gemm_descs[i].N_;
             const index_t K = gemm_descs[i].K_;

+            a_mtx_mraw_kraw_.emplace_back(M, K);
+            b_mtx_nraw_kraw_.emplace_back(N, K);
+
             if(M == 0)
             {
                 skipped_group_count_++;
@@ -485,6 +488,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
         CDEElementwiseOperation c_element_op_;

         std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
+        std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
+        std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;

         index_t grid_size_;
     };
@@ -599,7 +604,28 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                 return false;
             }

-        return true;
+        bool supported = true;
+
+        // If we use padding, we do not support vector loads for dimensions that are not
+        // divisible by the vector load size.
+        if constexpr(GemmSpec != GemmSpecialization::Default)
+        {
+            // The [A|B]BlockTransferSrcVectorDim values refer to the blockwise {K0, M, K1}
+            // layout, so we have to map them back to the raw {M, K} or {N, K} layout.
+            const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
+            const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
+
+            for(index_t i = 0; i < arg.group_count_; ++i)
+            {
+                const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
+                const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
+
+                supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
+                supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
+            }
+        }
+
+        return supported;
     }

     // polymorphic
@@ -661,7 +687,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
             << MPerXDL << ", "
             << NPerXDL << ", "
             << MXdlPerWave << ", "
-            << NXdlPerWave
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle << ", "
+            << getGemmSpecializationString(GemmSpec)
             << ">";
         // clang-format on
...
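The new IsSupportedArgument logic above rejects padded grouped-GEMM cases whose raw dimensions cannot be loaded with full vectors. The sketch below restates that divisibility rule in isolation; for simplicity it assumes both A and B are vector-loaded along K, whereas the real code derives the dimension from [A|B]BlockTransferSrcVectorDim. All names and values here are illustrative, not the CK API.

#include <array>
#include <iostream>

// Raw (unpadded) problem sizes of one group.
struct RawGemmShape
{
    long m, n, k;
};

// Returns false if any group's vector-loaded dimension (assumed K here) is not a
// multiple of the scalar-per-vector width, mirroring the check added in the diff.
bool groups_support_vector_load(const std::array<RawGemmShape, 2>& groups,
                                long a_scalar_per_vector,
                                long b_scalar_per_vector)
{
    bool supported = true;
    for(const auto& g : groups)
    {
        supported = supported && (g.k % a_scalar_per_vector == 0);
        supported = supported && (g.k % b_scalar_per_vector == 0);
    }
    return supported;
}

int main()
{
    const std::array<RawGemmShape, 2> groups{{{120, 256, 100}, {64, 64, 98}}};
    std::cout << std::boolalpha;
    std::cout << groups_support_vector_load(groups, 4, 4) << "\n"; // false: 98 % 4 != 0
    std::cout << groups_support_vector_load(groups, 2, 2) << "\n"; // true
}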
 #pragma once
 #include "ck/utility/data_type.hpp"
+// #include "ck/utility/get_id.hpp"

 namespace ck {
 namespace tensor_operation {
@@ -17,18 +18,27 @@ struct Activation_Mul_Clamp
     __host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
     {
-        float x_fp32 = ck::type_convert<float>(x);
-        activationOp_(x_fp32, x_fp32);
-        float y_fp32 = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
+        float y_fp32 = ck::type_convert<float>(x);
+        activationOp_(y_fp32, y_fp32);
+        y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
         y = ck::type_convert<int8_t>(y_fp32);
     }

-    __host__ __device__ constexpr void operator()(float& y, const int32_t& x) const
+    __device__ constexpr void operator()(int32_t& y, const int32_t& x) const
     {
-        // We might type_convert to int8 after lambda in someplace
-        float x_fp32 = ck::type_convert<float>(x);
-        activationOp_(x_fp32, x_fp32);
-        y = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
+        // CAUTION - we might type_convert to int8 in the threadwise copy,
+        // e.g. GridwiseGemmDlMultipleD_km_kn_mn
+        float y_fp32 = ck::type_convert<float>(x);
+        activationOp_(y_fp32, y_fp32);
+        y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
+        y = ck::type_convert<int32_t>(y_fp32);
+    }
+
+    __host__ constexpr void operator()(float& y, const float& x) const
+    {
+        // CAUTION - reference code may stay float in and float out
+        activationOp_(y, x);
+        y = math::clamp(requantScale_ * y, -128.f, 127.f);
     }

     float requantScale_;
@@ -51,6 +61,17 @@ struct Activation_Mul2_Clamp
         y = ck::type_convert<int8_t>(y_fp32);
     }

+    __device__ constexpr void
+    operator()(int32_t& y, const int32_t& x, const float& requantScale) const
+    {
+        // CAUTION - we might type_convert to int8 in the threadwise copy,
+        // e.g. GridwiseGemmDlMultipleD_km_kn_mn
+        float y_fp32 = ck::type_convert<float>(x);
+        activationOp_(y_fp32, y_fp32);
+        y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
+        y = ck::type_convert<int32_t>(y_fp32);
+    }
+
     Activation activationOp_;
 };
@@ -72,6 +93,17 @@ struct Add_Activation_Mul_Clamp
         y = ck::type_convert<int8_t>(y_fp32);
     }

+    __host__ __device__ constexpr void
+    operator()(int32_t& y, const int32_t& x, const int32_t& bias) const
+    {
+        // CAUTION - we might type_convert to int8 in the threadwise copy,
+        // e.g. GridwiseGemmDlMultipleD_km_kn_mn
+        float y_fp32 = ck::type_convert<float>(x + bias);
+        activationOp_(y_fp32, y_fp32);
+        y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
+        y = ck::type_convert<int32_t>(y_fp32);
+    }
+
     float requantScale_;
     Activation activationOp_;
 };
@@ -92,6 +124,17 @@ struct Add_Activation_Mul2_Clamp
         y = ck::type_convert<int8_t>(y_fp32);
     }

+    __host__ __device__ constexpr void
+    operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const
+    {
+        // CAUTION - we might type_convert to int8 in the threadwise copy,
+        // e.g. GridwiseGemmDlMultipleD_km_kn_mn
+        float y_fp32 = ck::type_convert<float>(x + bias);
+        activationOp_(y_fp32, y_fp32);
+        y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
+        y = ck::type_convert<int32_t>(y_fp32);
+    }
+
     Activation activationOp_;
 };
...
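Every int32 overload added above follows the same requantization recipe as the existing int8 path: convert the accumulator to float, apply the activation, multiply by the requantization scale, clamp to the int8 range, and convert back. Below is a minimal host-side sketch of that recipe, assuming a ReLU stand-in for the templated Activation functor; it is illustrative only, not the CK implementation.

#include <algorithm>
#include <cstdint>
#include <iostream>

// Stand-in for the templated Activation functor; ReLU is just an example here.
float relu(float v) { return v > 0.f ? v : 0.f; }

// Requantize an int32 GEMM accumulator to int8: activate, scale, clamp to [-128, 127].
int8_t requantize(int32_t acc, float requant_scale)
{
    float y = static_cast<float>(acc);
    y = relu(y);
    y = std::min(std::max(requant_scale * y, -128.f), 127.f); // math::clamp equivalent
    return static_cast<int8_t>(y);
}

int main()
{
    std::cout << static_cast<int>(requantize(5000, 0.01f)) << "\n";  // 50
    std::cout << static_cast<int>(requantize(50000, 0.01f)) << "\n"; // clamped to 127
    std::cout << static_cast<int>(requantize(-5000, 0.01f)) << "\n"; // ReLU zeroes it -> 0
}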
@@ -18,6 +18,10 @@
 namespace ck {

+/**
+ * @brief Gridwise gemm + softmax + gemm fusion
+ *
+ */
 template <typename FloatAB,
           typename FloatGemmAcc,
           typename FloatCShuffle,
...
@@ -185,8 +185,10 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
         return b_grid_desc_k0_n0_n1_k1;
     }

+    // E desc for destination in blockwise copy
+    template <typename CGridDesc_M_N_>
     __host__ __device__ static constexpr auto
-    MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
+    MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N_& c_grid_desc_m_n)
     {
         const auto M = c_grid_desc_m_n.GetLength(I0);
         const auto N = c_grid_desc_m_n.GetLength(I1);
@@ -238,19 +240,19 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
     using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
     using CGridDesc_M0_M10_M11_N0_N10_N11 =
         decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
-    using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));

     using DsGridPointer = decltype(MakeDsGridPointer());

     template <typename DsGridDesc_M0_M10_M11_N0_N10_N11,
               bool HasMainKBlockLoop,
-              bool HasDoubleTailKBlockLoop>
+              bool HasDoubleTailKBlockLoop,
+              typename Block2CTileMap>
     __device__ static void
     Run(const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
         DsGridPointer p_ds_grid,
         FloatC* __restrict__ p_c_grid,
-        FloatAB* __restrict__ p_shared_block,
+        void* __restrict__ p_shared_block,
         const AElementwiseOperation&,
         const BElementwiseOperation&,
         const CDEElementwiseOperation& cde_element_op,
@@ -399,8 +401,9 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
         constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
             b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);

-        FloatAB* p_a_block_double = p_shared_block;
-        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
+        FloatAB* p_a_block_double = static_cast<FloatAB*>(p_shared_block);
+        FloatAB* p_b_block_double =
+            static_cast<FloatAB*>(p_shared_block) + 2 * a_block_aligned_space_size;

         // register allocation for output
         auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
...
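The switch to a void* shared-memory pointer above keeps the same LDS layout: one flat allocation holds a double-buffered A region followed by a double-buffered B region, recovered with static_cast plus an element offset. Below is a simplified host-side sketch of that carving; the buffer sizes are made up for illustration and the vector stands in for the real LDS allocation.

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    using FloatAB = float;

    // Pretend these came from the block descriptors (illustrative values only).
    const std::size_t a_block_aligned_space_size = 1024; // elements per A buffer
    const std::size_t b_block_aligned_space_size = 2048; // elements per B buffer

    // One flat allocation, handed around as void* like p_shared_block in the kernel.
    std::vector<FloatAB> lds(2 * a_block_aligned_space_size + 2 * b_block_aligned_space_size);
    void* p_shared_block = lds.data();

    // Double-buffered A occupies the front; double-buffered B starts right after it.
    FloatAB* p_a_block_double = static_cast<FloatAB*>(p_shared_block);
    FloatAB* p_b_block_double =
        static_cast<FloatAB*>(p_shared_block) + 2 * a_block_aligned_space_size;

    std::cout << "B region starts " << (p_b_block_double - p_a_block_double)
              << " elements after A\n"; // 2048
}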