Commit 300337cd authored by letaoqin

Merge branch 'develop' into jizhan/reduce_threadwise_multi_d

parents f306d02e 02fa2c29
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <iostream>
#include <vector>
#include <sstream>
#include "device_grouped_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
///
/// @brief Structure representing single GEMM problem arguments.
///
/// A pointer to a vector of these structures is passed to the grouped GEMM entry-point
/// kernel.
///
/// @tparam NumDTensor The number of D input tensors.
///
template <index_t NumDTensor = 0>
struct GroupedGemmTileLoopKernelArguments
{
__host__ __device__
GroupedGemmTileLoopKernelArguments(const void* p_a_grid_,
const void* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
void* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_ds_grid{p_ds_grid_},
p_e_grid{p_e_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideDs{StrideDs_},
StrideE{StrideE_}
{
}
const void* p_a_grid;
const void* p_b_grid;
std::array<const void*, NumDTensor> p_ds_grid;
void* p_e_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
void Print() const
{
std::stringstream str;
for(auto sd : StrideDs)
str << sd << ",";
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SE:" << StrideE << ", "
<< "SDs: {" << str.str() << "}"
<< "}" << std::endl;
}
};
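// A minimal host-side sketch (illustration only, not used elsewhere in this file):
// fill one argument struct per GEMM problem before copying the vector to device
// memory. All parameter values below are hypothetical.
inline auto MakeExampleGroupedGemmTileLoopKernelArguments()
{
    std::vector<GroupedGemmTileLoopKernelArguments<1>> gemm_args;
    const void* p_a = nullptr; // device pointers would come from the application
    const void* p_b = nullptr;
    const void* p_d = nullptr;
    void* p_e       = nullptr;
    gemm_args.push_back({p_a, p_b, {p_d}, p_e,
                         /*M=*/128, /*N=*/256, /*K=*/64,
                         /*StrideA=*/64, /*StrideB=*/64, /*StrideDs=*/{256}, /*StrideE=*/256});
    return gemm_args;
}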
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
//----------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
/// arguments.
///
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0;
//----------------------------------------------------------------------------------------------
/// @brief Gets the device kernel argument size.
///
/// @param[in] p_arg The pointer to the Device op Argument.
///
/// @return The device kernel argument size.
///
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
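A typical host-side calling sequence for the two hooks above, sketched under the
assumption of a HIP program (device_op and argument_ptr are hypothetical; the concrete
DeviceGroupedGemmTileLoop implementation defines how the buffer is filled):

    // Ask the op how many bytes its kernel arguments occupy, allocate device
    // memory of that size, then register the buffer with the op.
    const std::size_t arg_size = device_op.GetDeviceKernelArgSize(argument_ptr.get());
    void* p_dev_args = nullptr;
    hipMalloc(&p_dev_args, arg_size); // HIP runtime allocation
    device_op.SetDeviceKernelArgs(argument_ptr.get(), p_dev_args);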
@@ -829,7 +829,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(ck::is_navi3_supported())
+        if(ck::is_gfx11_supported())
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
             {
...
@@ -648,7 +648,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
-           ck::is_navi2_supported() || ck::is_navi3_supported())
+           ck::is_gfx103_supported() || ck::is_gfx11_supported())
         {
             bool pass = true;
             pass = pass && arg.K_ % K1 == 0;
...
@@ -587,30 +587,31 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
               BatchStrideD1s,
               BatchStrideE1}
         {
-#if DEBUG_LOG
-            std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", "
-                      << a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl;
-            std::cout << "b0_grid_desc_n_k_{" << b0_grid_desc_n_k_.GetLength(I0) << ", "
-                      << b0_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
-            std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0) << ", "
-                      << d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl;
-            std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", "
-                      << b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
-            std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{"
-                      << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I0) << ", "
-                      << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I1) << ", "
-                      << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I2) << ", "
-                      << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I3) << ", "
-                      << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I4) << ", "
-                      << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I5) << ", "
-                      << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I6) << ", "
-                      << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I7) << ", "
-                      << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I8) << ", "
-                      << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I9) << "}"
-                      << std::endl;
-            std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", "
-                      << e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
-#endif
+            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+            {
+                std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", "
+                          << a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl;
+                std::cout << "b0_grid_desc_n_k_{" << b0_grid_desc_n_k_.GetLength(I0) << ", "
+                          << b0_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
+                std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0)
+                          << ", " << d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl;
+                std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", "
+                          << b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
+                std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{"
+                          << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I0) << ", "
+                          << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I1) << ", "
+                          << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I2) << ", "
+                          << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I3) << ", "
+                          << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I4) << ", "
+                          << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I5) << ", "
+                          << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I6) << ", "
+                          << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I7) << ", "
+                          << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I8) << ", "
+                          << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I9) << "}"
+                          << std::endl;
+                std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", "
+                          << e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+            }
 
             static_for<0, NumD0Tensor, 1>{}([&](auto i) {
                 using D0Layout = remove_cvref_t<tuple_element_t<i.value, D0sLayout>>;
...
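The commit replaces the compile-time #if DEBUG_LOG guards with this runtime check
throughout. A simplified stand-in for the environment toggle, assuming only the C
standard library (CK's real ck::EnvIsEnabled/CK_ENV machinery lives in ck/host_utility
and differs in detail):

    #include <cstdlib>
    #include <cstring>

    // True when the variable is set to something other than "" or "0";
    // e.g. running with CK_LOGGING=1 enables the descriptor printouts.
    inline bool env_is_enabled(const char* name)
    {
        const char* v = std::getenv(name);
        return v != nullptr && std::strcmp(v, "") != 0 && std::strcmp(v, "0") != 0;
    }

The practical effect is that logging no longer requires rebuilding with -DDEBUG_LOG.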
@@ -658,27 +658,28 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-#if DEBUG_LOG
-            {
-                std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl;
-
-                std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
-                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
-                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
-                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
-
-                std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
-                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
-                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
-                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
-
-                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
-                          << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
-
-                std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0)
-                          << "}" << std::endl;
-            }
-#endif
+            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+            {
+                {
+                    std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl;
+
+                    std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
+                              << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
+                              << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
+                              << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
+
+                    std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
+                              << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
+                              << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
+                              << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
+
+                    std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
+                              << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+
+                    std::cout << "arg.reduce_grid_desc_m_{ "
+                              << arg.reduce_grid_desc_m_.GetLength(I0) << "}" << std::endl;
+                }
+            }
 
             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                             arg.b_grid_desc_bk0_n_bk1_,
...
@@ -858,7 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
     static bool IsSupportedArgument(const RawArg& arg)
     {
-        if(ck::is_navi3_supported())
+        if(ck::is_gfx11_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
@@ -1435,7 +1435,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
 #if 0
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(ck::is_navi3_supported())
+        if(ck::is_gfx11_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
...
@@ -719,9 +719,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
-#if DEBUG_LOG
-        arg.Print();
-#endif
+        if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+        {
+            arg.Print();
+        }
 
         if(!ck::is_xdl_supported())
         {
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -14,7 +14,7 @@
 #include "ck/tensor_operation/gpu/device/device_cgemm.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
 #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -80,42 +80,41 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
 
-    static constexpr auto MPerThread       = Number<4>{};
+    static constexpr index_t MPerThread =
+        MPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
+    static constexpr index_t NPerThread =
+        NPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
+
     static constexpr auto AScalarPerVector = Number<4>{};
     static constexpr auto BScalarPerVector = Number<4>{};
     static constexpr auto CScalarPerVector = Number<4>{};
 
-    template <typename Desc_M>
-    static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
+    template <typename Desc_M_N>
+    static auto PadDescriptor_M_N(Desc_M_N desc)
     {
-        const auto M            = desc_m.GetLength(I0);
-        const index_t loop_step = gridSize * blockSize * MPerThread;
-        const auto pad          = math::integer_least_multiple(M, loop_step) - M;
-        const auto desc_m_pad =
-            transform_tensor_descriptor(desc_m,
-                                        make_tuple(make_right_pad_transform(M, pad)),
-                                        make_tuple(Sequence<0>{}),
-                                        make_tuple(Sequence<0>{}));
-        return desc_m_pad;
+        const auto M     = desc.GetLength(I0);
+        const auto N     = desc.GetLength(I1);
+        const auto pad_M = math::integer_divide_ceil(M, MPerThread) * MPerThread - M;
+        const auto pad_N = math::integer_divide_ceil(N, NPerThread) * NPerThread - N;
+
+        const auto padded_desc = transform_tensor_descriptor(
+            desc,
+            make_tuple(make_right_pad_transform(M, pad_M), make_right_pad_transform(N, pad_N)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+        return padded_desc;
     }
 
-    static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
-                                 const std::vector<index_t>& strides,
-                                 index_t gridSize,
-                                 index_t blockSize)
+    static auto MakeDescriptor_M_N(const std::vector<index_t>& lengths,
+                                   const std::vector<index_t>& strides)
     {
         auto tupleOfShape  = generate_tuple([&](auto I) { return lengths[I]; }, Number<2>{});
         auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<2>{});
 
         // nd desc - [s0, s1, s2, ...]
         const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
-        const auto desc_m = transform_tensor_descriptor(
-            desc,
-            make_tuple(make_merge_transform(tupleOfShape)),
-            make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<2>{})),
-            make_tuple(Sequence<0>{}));
-
-        return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
+        return PadDescriptor_M_N(desc);
     }
 
     // GridwiseGemm
@@ -166,7 +165,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched>;
 
-    using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
+    using CGridDesc_M_N = decltype(MakeDescriptor_M_N({1, 1}, {1, 1}));
 
     // Argument
     struct Argument : public tensor_operation::device::BaseArgument, public GridwiseGemm::Problem
@@ -195,17 +194,13 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
           p_c_grid_imag{p_c_grid_imag_},
           p_aux_grid{p_workspace}
     {
-        const index_t grid_size = std::get<1>(GridwiseGemm::CalculateGridSize(M_, N_));
-
         if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
         {
-            c_grid_desc_m =
-                DeviceOp::MakeDescriptor_M({M_, N_}, {StrideC_, I1}, grid_size, BlockSize);
+            c_grid_desc_m_n = DeviceOp::MakeDescriptor_M_N({M_, N_}, {StrideC_, I1});
        }
         else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
        {
-            c_grid_desc_m =
-                DeviceOp::MakeDescriptor_M({M_, N_}, {I1, StrideC_}, grid_size, BlockSize);
+            c_grid_desc_m_n = DeviceOp::MakeDescriptor_M_N({M_, N_}, {I1, StrideC_});
        }
 
         p_aux_2_grid = p_workspace + GetCElementSpaceSize(M_, N_, StrideC_);
@@ -220,7 +215,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
         CDataType* p_c_grid_imag;
         CDataType* p_aux_grid;
         CDataType* p_aux_2_grid;
-        CGridDesc_M c_grid_desc_m;
+        CGridDesc_M_N c_grid_desc_m_n;
     };
 
     // Invoker
@@ -248,40 +243,63 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
             using Add      = ck::tensor_operation::element_wise::Add;
             using Subtract = ck::tensor_operation::element_wise::Subtract;
 
-            using GridwiseBinAdd =
-                GridwiseElementwise_1D<Tuple<CGridDesc_M, CGridDesc_M>,
-                                       Tuple<CGridDesc_M>,
-                                       Tuple<const CDataType*, const CDataType*>,
-                                       Tuple<CDataType*>,
-                                       Add,
-                                       MPerThread,
-                                       Sequence<AScalarPerVector, BScalarPerVector>,
-                                       Sequence<CScalarPerVector>>;
-
-            using GridwiseBinSubtract =
-                GridwiseElementwise_1D<Tuple<CGridDesc_M, CGridDesc_M>,
-                                       Tuple<CGridDesc_M>,
-                                       Tuple<const CDataType*, const CDataType*>,
-                                       Tuple<CDataType*>,
-                                       Subtract,
-                                       MPerThread,
-                                       Sequence<AScalarPerVector, BScalarPerVector>,
-                                       Sequence<CScalarPerVector>>;
-
-            const auto add_kernel = kernel_elementwise_1d<GridwiseBinAdd,
-                                                          Tuple<CGridDesc_M, CGridDesc_M>,
-                                                          Tuple<CGridDesc_M>,
-                                                          Tuple<const CDataType*, const CDataType*>,
-                                                          Tuple<CDataType*>,
-                                                          Add>;
-            const auto subtract_kernel =
-                kernel_elementwise_1d<GridwiseBinSubtract,
-                                      Tuple<CGridDesc_M, CGridDesc_M>,
-                                      Tuple<CGridDesc_M>,
-                                      Tuple<const CDataType*, const CDataType*>,
-                                      Tuple<CDataType*>,
-                                      Subtract>;
+            using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
+
+            using GridwiseBinAdd = GridwiseElementwise<Tuple<CGridDesc_M_N, CGridDesc_M_N>,
+                                                       Tuple<CGridDesc_M_N>,
+                                                       Tuple<const CDataType*, const CDataType*>,
+                                                       Tuple<CDataType*>,
+                                                       Block2TileMap,
+                                                       Add,
+                                                       BlockSize,
+                                                       MPerBlock,
+                                                       NPerBlock,
+                                                       MPerThread,
+                                                       NPerThread,
+                                                       Sequence<0, 1>,
+                                                       Sequence<AScalarPerVector, BScalarPerVector>,
+                                                       Sequence<CScalarPerVector>,
+                                                       I1,
+                                                       I1>;
+
+            using GridwiseBinSubtract =
+                GridwiseElementwise<Tuple<CGridDesc_M_N, CGridDesc_M_N>,
+                                    Tuple<CGridDesc_M_N>,
+                                    Tuple<const CDataType*, const CDataType*>,
+                                    Tuple<CDataType*>,
+                                    Block2TileMap,
+                                    Subtract,
+                                    BlockSize,
+                                    MPerBlock,
+                                    NPerBlock,
+                                    MPerThread,
+                                    NPerThread,
+                                    Sequence<0, 1>,
+                                    Sequence<AScalarPerVector, BScalarPerVector>,
+                                    Sequence<CScalarPerVector>,
+                                    I1,
+                                    I1>;
+
+            const index_t M = arg.c_grid_desc_m_n.GetLength(I0);
+            const index_t N = arg.c_grid_desc_m_n.GetLength(I1);
+
+            const auto block_2_tile_map = Block2TileMap(M, N);
+
+            const auto add_kernel = kernel_elementwise<GridwiseBinAdd,
+                                                       Tuple<CGridDesc_M_N, CGridDesc_M_N>,
+                                                       Tuple<CGridDesc_M_N>,
+                                                       Tuple<const CDataType*, const CDataType*>,
+                                                       Tuple<CDataType*>,
+                                                       Block2TileMap,
+                                                       Add>;
+            const auto subtract_kernel =
+                kernel_elementwise<GridwiseBinSubtract,
+                                   Tuple<CGridDesc_M_N, CGridDesc_M_N>,
+                                   Tuple<CGridDesc_M_N>,
+                                   Tuple<const CDataType*, const CDataType*>,
+                                   Tuple<CDataType*>,
+                                   Block2TileMap,
+                                   Subtract>;
 
             if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
@@ -318,11 +336,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                 dim3(gdx, gdy, gdz),
                                 dim3(BlockSize),
                                 0,
-                                make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
-                                make_tuple(arg.c_grid_desc_m),
+                                make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n),
+                                make_tuple(arg.c_grid_desc_m_n),
                                 make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
                                            const_cast<const CDataType*>(arg.p_aux_2_grid)),
                                 make_tuple(arg.p_c_grid_real),
+                                block_2_tile_map,
                                 Subtract{});
 
                 ave_time += launch_and_time_kernel(stream_config,
@@ -352,11 +371,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                 dim3(gdx, gdy, gdz),
                                 dim3(BlockSize),
                                 0,
-                                make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
-                                make_tuple(arg.c_grid_desc_m),
+                                make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n),
+                                make_tuple(arg.c_grid_desc_m_n),
                                 make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
                                            const_cast<const CDataType*>(arg.p_aux_2_grid)),
                                 make_tuple(arg.p_c_grid_imag),
+                                block_2_tile_map,
                                 Add{});
             }
             else
@@ -394,11 +414,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                 dim3(gdx, gdy, gdz),
                                 dim3(BlockSize),
                                 0,
-                                make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
-                                make_tuple(arg.c_grid_desc_m),
+                                make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n),
+                                make_tuple(arg.c_grid_desc_m_n),
                                 make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
                                            const_cast<const CDataType*>(arg.p_aux_2_grid)),
                                 make_tuple(arg.p_c_grid_real),
+                                block_2_tile_map,
                                 Subtract{});
 
                 ave_time += launch_and_time_kernel(stream_config,
@@ -428,11 +449,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                 dim3(gdx, gdy, gdz),
                                 dim3(BlockSize),
                                 0,
-                                make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
-                                make_tuple(arg.c_grid_desc_m),
+                                make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n),
+                                make_tuple(arg.c_grid_desc_m_n),
                                 make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
                                            const_cast<const CDataType*>(arg.p_aux_2_grid)),
                                 make_tuple(arg.p_c_grid_imag),
+                                block_2_tile_map,
                                 Add{});
             }
...
@@ -663,7 +663,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
             const bool valid_a_access_dim_k =
                 ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i];
             const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
-            if(!(valid_a_vector_size && valid_a_access_dim))
+            if(!((valid_a_vector_size && valid_a_access_dim) ||
+                 ABlockTransferSrcScalarPerVector == 1))
             {
                 valid_as_access = false;
             }
@@ -682,7 +683,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
             const bool valid_b_access_dim_k =
                 BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i];
             const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
-            if(!(valid_b_vector_size && valid_b_access_dim))
+            if(!((valid_b_vector_size && valid_b_access_dim) ||
+                 BBlockTransferSrcScalarPerVector == 1))
             {
                 valid_bs_access = false;
             }
@@ -698,7 +700,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
                 arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
             // Vector read of Ds is always on N dimension.
             const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
-            if(!(valid_d_vector_size && valid_d_access_dim))
+            if(!((valid_d_vector_size && valid_d_access_dim) ||
+                 CDEBlockTransferScalarPerVector_NPerBlock == 1))
             {
                 valid_ds_access = false;
             }
@@ -712,7 +715,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
             arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
         // Vector write of E is always on N dimension.
         const bool valid_e_access_dim = arg.e_nz_consecutive_;
-        if(!(valid_e_vector_size && valid_e_access_dim))
+        if(!((valid_e_vector_size && valid_e_access_dim) ||
+             CDEBlockTransferScalarPerVector_NPerBlock == 1))
         {
             return false;
         }
...
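The relaxed predicates in this hunk and the next all share one shape: a vectorized
transfer must be well-formed, unless the configured vector length is 1, in which case
scalar access is always legal. Restated for the A tensor (directly from the diff above):

    // ScalarPerVector == 1 short-circuits the contiguity/dimension requirements,
    // since single-element "vectors" can be read from any layout.
    const bool a_ok = (valid_a_vector_size && valid_a_access_dim) ||
                      (ABlockTransferSrcScalarPerVector == 1);
    if(!a_ok) { valid_as_access = false; }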
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -53,8 +53,7 @@ __global__ void
                e_grid_desc_mblock_mperblock_nblock_nperblock,
            const Block2ETileMap block_2_etile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -627,7 +626,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
             arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
         const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
         const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
-        const bool valid_a_access_dim   = valid_a_access_dim_m || valid_a_access_dim_k;
+        const bool valid_a_access_dim =
+            valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1;
         if(!(valid_a_vector_size && valid_a_access_dim))
         {
             return false;
@@ -637,7 +637,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
             arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
         const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
         const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
-        const bool valid_b_access_dim   = valid_b_access_dim_n || valid_b_access_dim_k;
+        const bool valid_b_access_dim =
+            valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1;
         if(!(valid_b_vector_size && valid_b_access_dim))
         {
             return false;
@@ -648,7 +649,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
             const bool valid_d_vector_size =
                 arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
             // Vector read of Ds is always on N dimension.
-            const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
+            const bool valid_d_access_dim =
+                arg.ds_nz_consecutive_[i] || CDEBlockTransferScalarPerVector_NPerBlock == 1;
             if(!(valid_d_vector_size && valid_d_access_dim))
             {
                 valid_ds_access = false;
@@ -662,7 +664,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
         const bool valid_e_vector_size =
             arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
         // Vector write of E is always on N dimension.
-        const bool valid_e_access_dim = arg.e_nz_consecutive_;
+        const bool valid_e_access_dim =
+            arg.e_nz_consecutive_ || CDEBlockTransferScalarPerVector_NPerBlock == 1;
         if(!(valid_e_vector_size && valid_e_access_dim))
        {
             return false;
...
@@ -516,26 +516,27 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
         float ave_time = 0;
 
         for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
         {
-#if DEBUG_LOG
-            {
-                std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
-                          << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
-                          << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
-                          << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
-                          << std::endl;
-
-                std::cout << "arg.b_grid_desc_k0_n_k1_container_{"
-                          << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
-                          << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
-                          << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
-                          << std::endl;
-
-                std::cout << "arg.c_grid_desc_m_n_container_{ "
-                          << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
-                          << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
-                          << std::endl;
-            }
-#endif
+            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+            {
+                {
+                    std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
+                              << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
+                              << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
+                              << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
+                              << std::endl;
+
+                    std::cout << "arg.b_grid_desc_k0_n_k1_container_{"
+                              << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
+                              << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
+                              << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
+                              << std::endl;
+
+                    std::cout << "arg.c_grid_desc_m_n_container_{ "
+                              << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
+                              << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
+                              << std::endl;
+                }
+            }
 
             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
                                             arg.b_grid_desc_k0_n_k1_container_[i],
...
@@ -644,7 +644,7 @@ struct
     float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
     {
-#if DEBUG_LOG
+        if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
         {
             std::cout << DeviceOp{}.GetTypeString() << std::endl;
             std::cout << "N " << arg.Conv_N_ << ", "
@@ -664,9 +664,7 @@ struct
                       << arg.input_left_pads_[1] << ", " << std::endl;
             std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
                       << arg.input_right_pads_[1] << ", " << std::endl;
-        }
 
-        {
             std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
                       << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
                       << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
@@ -684,7 +682,6 @@ struct
             std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0)
                       << ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
         }
-#endif
 
         if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                         arg.b_grid_desc_k0_n_k1_,
...
@@ -614,7 +614,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
     float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
     {
-#if DEBUG_LOG
+        if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
         {
             std::cout << DeviceOp{}.GetTypeString() << std::endl;
             std::cout << "N " << arg.Conv_N_ << ", "
@@ -634,9 +634,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
                       << arg.input_left_pads_[1] << ", " << std::endl;
             std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
                       << arg.input_right_pads_[1] << ", " << std::endl;
-        }
 
-        {
             std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
                       << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
                       << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
@@ -651,7 +649,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
             std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
                       << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
         }
-#endif
 
         if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                         arg.b_grid_desc_k0_n_k1_,
...
@@ -579,7 +579,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
     float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
     {
-#if DEBUG_LOG
+        if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
         {
             std::cout << DeviceOp{}.GetTypeString() << std::endl;
             std::cout << "N " << arg.Conv_N_ << ", "
@@ -599,9 +599,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
                       << arg.input_left_pads_[1] << ", " << std::endl;
             std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
                       << arg.input_right_pads_[1] << ", " << std::endl;
-        }
 
-        {
             std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
                       << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
                       << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
@@ -635,7 +633,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
                           .GetLength(I5)
                       << "}" << std::endl;
         }
-#endif
 
         if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                         arg.b_grid_desc_k0_n_k1_,
...
@@ -431,7 +431,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
     {
-#if DEBUG_LOG
+        if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
         {
             std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
                       << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
@@ -444,7 +444,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                       << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
         }
-#endif
+
         if(!GridwiseGemm::CheckValidity(
                arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
         {
...
@@ -401,7 +401,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
     float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
     {
-#if DEBUG_LOG
+        if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
         {
             std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl;
             std::cout << "a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
@@ -415,7 +415,6 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
             std::cout << "c_grid_desc_m_n{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                       << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
         }
-#endif
 
         if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                         arg.b_grid_desc_k0_n_k1_,
...
@@ -1272,7 +1272,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
         float ave_time = 0;
 
         for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
         {
-#if DEBUG_LOG
+            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
             {
                 std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
                           << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
@@ -1305,7 +1305,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
                           << arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I5)
                           << " ) " << std::endl;
             }
-#endif
 
             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
                                             arg.b_grid_desc_k0_n_k1_container_[i],
@@ -1393,8 +1392,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
     static bool IsSupportedArgument(const Argument& arg)
     {
         // check device
-        if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
-             ck::is_navi3_supported()))
+        if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
+             ck::is_gfx11_supported()))
         {
             return false;
         }
...
@@ -1220,7 +1220,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
         float ave_time = 0;
 
         for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
         {
-#if DEBUG_LOG
+            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
             {
                 std::cout << "arg.a_grid_desc_k0_m_k1{"
                           << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
@@ -1239,7 +1239,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
                           << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
                           << std::endl;
             }
-#endif
 
             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
                                             arg.b_grid_desc_k0_n_k1_container_[i],
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
index_t NumDim_m,
index_t NumDim_n,
index_t MPerThread,
index_t NPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
NumDim_m + NumDim_n>
{
static constexpr index_t NumDim = NumDim_m + NumDim_n;
static constexpr int NumInput = InDataTypeTuple::Size();
static constexpr int NumOutput = OutDataTypeTuple::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size(),
"Tuple size is inconsistent with the number of in/out!");
static auto GenerateInDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(nullptr);
},
Number<NumInput>{});
};
static auto GenerateOutDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(nullptr);
},
Number<NumOutput>{});
};
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
template <typename Desc_MN>
static auto PadDescriptor_MN_2d(Desc_MN desc_mn,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n)
{
std::ignore = blockSize;
std::ignore = gridSize;
const auto m = desc_mn.GetLength(I0);
const auto n = desc_mn.GetLength(I1);
const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = num_threads_n * NPerThread;
const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m;
const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n;
const auto desc_mn_pad = transform_tensor_descriptor(
desc_mn,
make_tuple(make_right_pad_transform(m, pad_m), make_right_pad_transform(n, pad_n)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return desc_mn_pad;
}
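    // Worked example of the padding arithmetic above (values hypothetical):
    // with m = 1000, n = 30, num_threads_m = 64, MPerThread = 4,
    // num_threads_n = 16, NPerThread = 8:
    //   loop_step_m = 64 * 4 = 256; integer_least_multiple(1000, 256) = 1024; pad_m = 24
    //   loop_step_n = 16 * 8 = 128; integer_least_multiple(30, 128)   = 128;  pad_n = 98
    // The right-pad transforms then expose a 1024 x 128 view over the 1000 x 30 data,
    // so every thread iterates a whole number of MPerThread x NPerThread tiles.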
static auto MakeDescriptor_MN(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& stride,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n)
{
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDim_m, 1>::type();
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDim_m, NumDim_m + NumDim_n, 1>::type();
const auto mLengths = get_container_subset(tupleOfShape, mDimIds);
const auto nLengths = get_container_subset(tupleOfShape, nDimIds);
// merge nd to 2d desc - [s0 * s1 * ...]
if constexpr(NumDim > 2)
{
const auto desc_mn = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadDescriptor_MN_2d(desc_mn, gridSize, blockSize, num_threads_m, num_threads_n);
}
else
return PadDescriptor_MN_2d(desc, gridSize, blockSize, num_threads_m, num_threads_n);
}
template <index_t TupleSize>
static auto GenerateInOutGrid2dDescTuple(Number<TupleSize>)
{
return generate_tuple(
[&](auto) {
if constexpr(NumDim > 2)
{
return MakeDescriptor_MN({1, 1}, {1, 1}, 1, 1, 1, 1);
}
else
{
return MakeDescriptor_MN({1}, {1}, 1, 1, 1, 1);
};
},
Number<TupleSize>{});
};
using OutGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumOutput>{}));
using InGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumInput>{}));
using GridwiseElementwise = GridwiseElementwise_2D<InGrid2dDescTuple,
OutGrid2dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation,
MPerThread,
NPerThread,
InScalarPerVectorSeq,
OutScalarPerVectorSeq>;
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
: lengths_(lengths),
inStridesArray_(inStridesArray),
outStridesArray_(outStridesArray),
elementwise_op_(elementwise_op),
blockSize_(256)
{
static_assert(NumDim_m > 0, "");
static_assert(NumDim_n > 0, "");
in_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(in_dev_buffers[I.value]);
},
Number<NumInput>{});
out_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(out_dev_buffers[I.value]);
},
Number<NumOutput>{});
}
InDataTypePointerTuple in_dev_buffers_;
OutDataTypePointerTuple out_dev_buffers_;
std::array<index_t, NumDim> lengths_;
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
ElementwiseOperation elementwise_op_;
index_t blockSize_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
index_t gridSize = getAvailableComputeUnitCount(stream_config);
index_t num_threads_m = (gridSize * arg.blockSize_) / 16;
index_t num_threads_n = 16;
auto in_grid_2d_desc_tuple = generate_tuple(
[&](auto I) {
return MakeDescriptor_MN(arg.lengths_,
arg.inStridesArray_[I.value],
gridSize,
arg.blockSize_,
num_threads_m,
num_threads_n);
},
Number<NumInput>{});
auto out_grid_2d_desc_tuple = generate_tuple(
[&](auto I) {
return MakeDescriptor_MN(arg.lengths_,
arg.outStridesArray_[I.value],
gridSize,
arg.blockSize_,
num_threads_m,
num_threads_n);
},
Number<NumOutput>{});
const auto kernel = kernel_elementwise_2d<GridwiseElementwise,
InGrid2dDescTuple,
OutGrid2dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(gridSize),
dim3(arg.blockSize_),
0,
in_grid_2d_desc_tuple,
out_grid_2d_desc_tuple,
arg.in_dev_buffers_,
arg.out_dev_buffers_,
arg.elementwise_op_,
num_threads_m,
num_threads_n);
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg == nullptr)
return false;
if(pArg->lengths_.back() % MPerThread != 0)
return false;
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides,
index_t scalarPerVector,
index_t vectorDim) {
if(strides[vectorDim] == 1 &&
(lengths[vectorDim] % scalarPerVector == 0 ||
lengths[vectorDim] % scalarPerVector == lengths[vectorDim]))
{
return true;
}
if(strides[vectorDim] != 1 && scalarPerVector == strides[vectorDim])
{
return true;
}
return false;
};
bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(pArg->lengths_,
pArg->inStridesArray_[I.value],
InScalarPerVectorSeq::At(I),
NumDim_m - 1))
valid = false;
});
static_for<0, NumOutput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(pArg->lengths_,
pArg->outStridesArray_[I.value],
OutScalarPerVectorSeq::At(I),
NumDim - 1))
valid = false;
});
return valid;
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op) override
{
return std::make_unique<Argument>(lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
}; // DeviceElementwise2dImpl
} // namespace device
} // namespace tensor_operation
} // namespace ck
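A usage sketch for DeviceElementwise2dImpl above (illustration only; the extents, the
PassThrough operator choice, and the p_in/p_out device buffers are assumptions, and
PassThrough's header must be included separately):

    // Copies a 1024 x 1024 FP32 tensor through CK's identity elementwise op.
    // Strides are chosen so the input is contiguous along M and the output along N,
    // which is what IsSupportedArgument checks for the two vector dimensions.
    using Pass = ck::tensor_operation::element_wise::PassThrough;
    using DeviceCopy =
        ck::tensor_operation::device::DeviceElementwise2dImpl<ck::Tuple<float>, // inputs
                                                              ck::Tuple<float>, // outputs
                                                              Pass,
                                                              1,  // NumDim_m
                                                              1,  // NumDim_n
                                                              4,  // MPerThread
                                                              4,  // NPerThread
                                                              ck::Sequence<4>,
                                                              ck::Sequence<4>>;

    float launch_copy(const float* p_in, float* p_out)
    {
        const ck::index_t M = 1024, N = 1024;
        DeviceCopy op;
        auto arg = op.MakeArgumentPointer({M, N},
                                          {{{1, M}}}, // input strides: M contiguous
                                          {{{N, 1}}}, // output strides: N contiguous
                                          {p_in},
                                          {p_out},
                                          Pass{});
        if(!op.IsSupportedArgument(arg.get()))
            return -1.f;
        return DeviceCopy::MakeInvoker().Run(arg.get());
    }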
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_3d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
index_t NumDim_m, // choose how to set dims
index_t NumDim_n,
index_t NumDim_k,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
NumDim_m + NumDim_n + NumDim_k>
{
static constexpr index_t NumDim = NumDim_m + NumDim_n + NumDim_k;
static constexpr int NumInput = InDataTypeTuple::Size();
static constexpr int NumOutput = OutDataTypeTuple::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size(),
"Tuple size is inconsistent with the number of in/out!");
static auto GenerateInDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(nullptr);
},
Number<NumInput>{});
}
static auto GenerateOutDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(nullptr);
},
Number<NumOutput>{});
}
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
template <typename Desc_MNK>
static auto PadDescriptor_MNK(Desc_MNK desc_mnk,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n,
index_t num_threads_k)
{
std::ignore = blockSize;
std::ignore = gridSize;
const auto m = desc_mnk.GetLength(I0);
const auto n = desc_mnk.GetLength(I1);
const auto k = desc_mnk.GetLength(I2);
const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = num_threads_n * NPerThread;
const index_t loop_step_k = num_threads_k * KPerThread;
const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m;
const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n;
const auto pad_k = math::integer_least_multiple(k, loop_step_k) - k;
const auto desc_mnk_pad =
transform_tensor_descriptor(desc_mnk,
make_tuple(make_right_pad_transform(m, pad_m),
make_right_pad_transform(n, pad_n),
make_right_pad_transform(k, pad_k)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return desc_mnk_pad;
}
static auto MakeDescriptor_MNK(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& stride,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n,
index_t num_threads_k)
{
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDim_m, 1>::type();
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDim_m, NumDim_m + NumDim_n, 1>::type();
constexpr auto kDimIds =
typename arithmetic_sequence_gen<NumDim_m + NumDim_n, NumDim, 1>::type();
const auto mLengths = get_container_subset(tupleOfShape, mDimIds);
const auto nLengths = get_container_subset(tupleOfShape, nDimIds);
const auto kLengths = get_container_subset(tupleOfShape, kDimIds);
// merge nd to 3d desc - [s0 * s1 * ...]
if constexpr(NumDim > 3)
{
const auto desc_mnk = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(mLengths),
make_merge_transform(nLengths),
make_merge_transform(kLengths)),
make_tuple(mDimIds, nDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return PadDescriptor_MNK(
desc_mnk, gridSize, blockSize, num_threads_m, num_threads_n, num_threads_k);
}
else
return PadDescriptor_MNK(
desc, gridSize, blockSize, num_threads_m, num_threads_n, num_threads_k);
}
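    // Worked example of the merge above (hypothetical extents): with
    // NumDim_m = 2, NumDim_n = 1, NumDim_k = 2 and lengths = {8, 4, 16, 2, 32},
    // mDimIds = {0, 1}, nDimIds = {2}, kDimIds = {3, 4}, so the rank-5 descriptor
    // collapses to a (8*4) x 16 x (2*32) = 32 x 16 x 64 view, which
    // PadDescriptor_MNK then right-pads to multiples of
    // num_threads_{m,n,k} * {M,N,K}PerThread.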
template <index_t TupleSize>
static auto GenerateInOutGrid3dDescTuple(Number<TupleSize>)
{
return generate_tuple(
[&](auto) {
if constexpr(NumDim > 3)
{
return MakeDescriptor_MNK({1, 1, 1}, {1, 1, 1}, 1, 1, 1, 1, 1);
}
else
{
return MakeDescriptor_MNK({1}, {1}, 1, 1, 1, 1, 1);
};
},
Number<TupleSize>{});
}
using OutGrid3dDescTuple = decltype(GenerateInOutGrid3dDescTuple(Number<NumOutput>{}));
using InGrid3dDescTuple = decltype(GenerateInOutGrid3dDescTuple(Number<NumInput>{}));
using GridwiseElementwise = GridwiseElementwise_3D<InGrid3dDescTuple,
OutGrid3dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation,
MPerThread,
NPerThread,
KPerThread,
InScalarPerVectorSeq,
OutScalarPerVectorSeq>;
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
: lengths_(lengths),
inStridesArray_(inStridesArray),
outStridesArray_(outStridesArray),
elementwise_op_(elementwise_op),
blockSize_(256)
{
static_assert(NumDim_m > 0, "");
static_assert(NumDim_n > 0, "");
static_assert(NumDim_k > 0, "");
in_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(in_dev_buffers[I.value]);
},
Number<NumInput>{});
out_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(out_dev_buffers[I.value]);
},
Number<NumOutput>{});
}
InDataTypePointerTuple in_dev_buffers_;
OutDataTypePointerTuple out_dev_buffers_;
std::array<index_t, NumDim> lengths_;
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
ElementwiseOperation elementwise_op_;
index_t blockSize_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
index_t gridSize = getAvailableComputeUnitCount(stream_config) * arg.blockSize_;
index_t num_threads_m = gridSize / (16 * 16);
index_t num_threads_n = 16;
index_t num_threads_k = 16;
auto in_grid_3d_desc_tuple = generate_tuple(
[&](auto I) {
return MakeDescriptor_MNK(arg.lengths_,
arg.inStridesArray_[I.value],
gridSize,
arg.blockSize_,
num_threads_m,
num_threads_n,
num_threads_k);
},
Number<NumInput>{});
auto out_grid_3d_desc_tuple = generate_tuple(
[&](auto I) {
return MakeDescriptor_MNK(arg.lengths_,
arg.outStridesArray_[I.value],
gridSize,
arg.blockSize_,
num_threads_m,
num_threads_n,
num_threads_k);
},
Number<NumOutput>{});
const auto kernel = kernel_elementwise_3d<GridwiseElementwise,
InGrid3dDescTuple,
OutGrid3dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(gridSize),
dim3(arg.blockSize_),
0,
in_grid_3d_desc_tuple,
out_grid_3d_desc_tuple,
arg.in_dev_buffers_,
arg.out_dev_buffers_,
arg.elementwise_op_,
num_threads_m,
num_threads_n,
num_threads_k);
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
if((ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{
return false;
}
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg == nullptr)
return false;
if(pArg->lengths_.back() % MPerThread != 0)
return false;
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides,
index_t scalarPerVector,
index_t vectorDim) {
if(strides[vectorDim] == 1 &&
(lengths[vectorDim] % scalarPerVector == 0 ||
lengths[vectorDim] % scalarPerVector == lengths[vectorDim]))
{
return true;
}
if(strides[vectorDim] >= scalarPerVector)
{
return true;
}
return false;
};
bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
valid = valid && IsScalarPerVectorValid(pArg->lengths_,
pArg->inStridesArray_[I.value],
InScalarPerVectorSeq::At(I),
NumDim_m - 1);
});
static_for<0, NumOutput, 1>{}([&](auto I) {
valid = valid && IsScalarPerVectorValid(pArg->lengths_,
pArg->outStridesArray_[I.value],
OutScalarPerVectorSeq::At(I),
NumDim - 1);
});
return valid;
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op) override
{
return std::make_unique<Argument>(lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
}
}; // DeviceElementwise3dImpl
} // namespace device
} // namespace tensor_operation
} // namespace ck