Commit e599063f authored by illsilin

sync from the public repo

parents 5dbbf5d6 566b6480
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename DsLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename DsDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGroupedConvBwdWeightMultipleD : public BaseOperator
{
static constexpr index_t NumDTensor = DsLayout::Size();
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_in_grid,
void* p_wei_grid,
const void* p_out_grid,
const std::array<const void*, NumDTensor>& p_ds,
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_k_c_xs_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const ck::index_t split_k) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
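For orientation, here is a minimal host-side sketch of driving this interface. `ConcreteBwdWeightOp` is a hypothetical concrete implementation of `DeviceGroupedConvBwdWeightMultipleD` with `NDimSpatial = 2` and an empty `DsLayout` (so `NumDTensor = 0`); the pointers and element-wise ops are likewise assumptions, not names from this header:

```cpp
// Hypothetical usage sketch only; ConcreteBwdWeightOp, the device pointers and
// the element-wise ops are assumed to exist and match the template parameters.
ConcreteBwdWeightOp op;                                   // NDimSpatial = 2
std::array<ck::index_t, 5> in_lengths{}, in_strides{};    // {G, N, C, Hi, Wi}
std::array<ck::index_t, 5> wei_lengths{}, wei_strides{};  // {G, K, C, Y, X}
std::array<ck::index_t, 5> out_lengths{}, out_strides{};  // {G, N, K, Ho, Wo}
std::array<ck::index_t, 2> conv_strides{1, 1}, conv_dilations{1, 1};
std::array<ck::index_t, 2> left_pads{0, 0}, right_pads{0, 0};

auto arg = op.MakeArgumentPointer(p_in, p_wei, p_out,
                                  {},                     // no D tensor pointers
                                  in_lengths, in_strides,
                                  wei_lengths, wei_strides,
                                  out_lengths, out_strides,
                                  {}, {},                 // no D lengths/strides
                                  conv_strides, conv_dilations,
                                  left_pads, right_pads,
                                  in_op, wei_op, out_op,
                                  /*split_k=*/1);
auto invoker = op.MakeInvokerPointer();                   // then run via invoker
```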
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -40,7 +40,8 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
* \tparam AElementwiseOperation A elementwise operation.
* \tparam BElementwiseOperation B elementwise operation.
* \tparam CDEElementwiseOperation CDE elementwise operation.
* \tparam ComputeType Compute data type (default: ADataType, first if tuple passed).
* \tparam AComputeType Compute data type for A tensor (default: ADataType, first if tuple passed).
* \tparam BComputeType Compute data type for B tensor (default: AComputeType).
*/
template <index_t NDimSpatial,
typename ALayout,
......@@ -54,12 +55,13 @@ template <index_t NDimSpatial,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename ComputeType =
typename AComputeType =
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
Number<0>,
ADataType>())> // ComputeType is InputType by default (first
ADataType>()), // AComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
typename BComputeType = AComputeType>
struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
{
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
......
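A quick sketch of the new compute-type defaulting described in the updated doc comment, reusing the default expression from the hunk above. The `ADataType` here is an illustrative choice, and the snippet assumes it sits inside the same `ck` namespaces where `UnpackDataType`, `is_detected` and `is_tuple` are visible:

```cpp
// Illustrative check of the defaulting; ADataType is an example, not from the diff.
using ADataType    = ck::Tuple<ck::half_t, ck::half_t>;   // multi-A: tuple passed
using AComputeType = decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
                                             Number<0>,
                                             ADataType>()); // -> ck::half_t
static_assert(std::is_same_v<AComputeType, ck::half_t>,
              "first tuple element is unpacked; BComputeType then defaults to it");
```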
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
struct GemmMultiABDDesc
{
ck::index_t M_, N_, K_;
std::vector<ck::index_t> stride_As_;
std::vector<ck::index_t> stride_Bs_;
std::vector<ck::index_t> stride_Ds_;
ck::index_t stride_C_;
};
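As a concrete illustration (all values arbitrary), one group's descriptor for a single-A, single-B, single-D problem could be filled like so:

```cpp
// Illustrative values: M x N x K = 256 x 128 x 64, row-major A/C/D,
// column-major B; one tensor each for A, B and D.
ck::tensor_operation::device::GemmMultiABDDesc desc{
    256, 128, 64,   // M_, N_, K_
    {64},           // stride_As_ (leading dimension of the single A)
    {64},           // stride_Bs_
    {128},          // stride_Ds_
    128};           // stride_C_
```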
/*
* \brief Grouped Gemm Multi ABD
*
* C = a_op(A, A1...) * b_op(B, B1...)
* E = cde_op(C, D0, D1, ...)
*
* \tparam AsLayout A layouts (tuple).
* \tparam BsLayout B layouts (tuple).
* \tparam DsLayout Ds layouts (tuple).
* \tparam ELayout Output layout.
* \tparam AsDataType A data types (tuple).
* \tparam BsDataType B data types (tuple).
* \tparam DsDataType D data types (tuple).
* \tparam EDataType Output data type.
* \tparam AElementwiseOperation A elementwise operation.
* \tparam BElementwiseOperation B elementwise operation.
* \tparam CDEElementwiseOperation CDE elementwise operation.
*/
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGroupedGemmMultiABD : public BaseOperator
{
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
static_assert(AsLayout::Size() == AsDataType::Size(), "wrong! inconsistent NumATensor");
static_assert(BsLayout::Size() == BsDataType::Size(), "wrong! inconsistent NumBTensor");
static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor");
/*
* \brief Make argument pointer for grouped gemm multi abd.
*
* \param p_as Pointers to the A tensors, one array per group.
* \param p_bs Pointers to the B tensors, one array per group.
* \param p_ds Pointers to the D tensors, one array per group.
* \param p_e Pointers to the E tensors, one per group.
* \param gemm_desc Gemm descriptors for each group.
* \param a_element_op A elementwise operation object.
* \param b_element_op B elementwise operation object.
* \param cde_element_op CDE elementwise operation object.
* \return Pointer to the argument.
*/
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<std::array<const void*, NumATensor>>& p_as,
std::vector<std::array<const void*, NumBTensor>>& p_bs,
std::vector<std::array<const void*, NumDTensor>>& p_ds,
std::vector<void*>& p_e,
std::vector<GemmMultiABDDesc>& gemm_desc,
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
CDEElementwiseOperation cde_element_op = CDEElementwiseOperation{}) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual void SetElementwiseOps(BaseArgument* p_arg,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) const = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
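A hedged sketch of calling the interface above for two groups. `ConcreteGroupedGemmOp`, the device pointers and `desc0`/`desc1` (one `GemmMultiABDDesc` per group) are assumptions; `NumATensor = NumBTensor = NumDTensor = 1`:

```cpp
// Hypothetical host-side setup; none of these buffer names come from the header.
std::vector<std::array<const void*, 1>> p_as{{p_a0}, {p_a1}};
std::vector<std::array<const void*, 1>> p_bs{{p_b0}, {p_b1}};
std::vector<std::array<const void*, 1>> p_ds{{p_d0}, {p_d1}};
std::vector<void*>                      p_e{p_e0, p_e1};
std::vector<ck::tensor_operation::device::GemmMultiABDDesc> descs{desc0, desc1};

ConcreteGroupedGemmOp op;
auto arg     = op.MakeArgumentPointer(p_as, p_bs, p_ds, p_e, descs);
auto invoker = op.MakeInvokerPointer();
// invoker->Run(arg.get(), StreamConfig{}); // launch, per the usual CK pattern
```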
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <array>
#include "device_grouped_gemm_multi_abd.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
struct GroupedGemmMultiABDKernelArgument
{
std::array<const void*, NumATensor> p_as_grid;
std::array<const void*, NumBTensor> p_bs_grid;
std::array<const void*, NumDTensor> p_ds_grid;
void* p_e_grid;
index_t M;
index_t N;
index_t K;
std::array<index_t, NumATensor> StrideAs;
std::array<index_t, NumBTensor> StrideBs;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
};
/*
* \brief Grouped Gemm Multi ABD Fixed NK
*
* C = a_op(A, A1...) * b_op(B, B1...)
* E = cde_op(C, D0, D1, ...)
*
* \tparam AsLayout A layouts (tuple).
* \tparam BsLayout B layouts (tuple).
* \tparam DsLayout Ds layouts (tuple).
* \tparam ELayout Output layout.
* \tparam AsDataType A data types (tuple).
* \tparam BsDataType B data types (tuple).
* \tparam DsDataType D data types (tuple).
* \tparam EDataType Output data type.
* \tparam AElementwiseOperation A elementwise operation.
* \tparam BElementwiseOperation B elementwise operation.
* \tparam CElementwiseOperation CDE elementwise operation.
*/
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemmMultiABDFixedNK : DeviceGroupedGemmMultiABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0;
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
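The fixed-NK interface expects the caller to stage the per-group kernel arguments in device memory. A minimal sketch of that staging under HIP, assuming `op` is a concrete `DeviceGroupedGemmMultiABDFixedNK` instance and `arg` came from `MakeArgumentPointer`:

```cpp
#include <hip/hip_runtime.h>
#include <vector>

// Host-side per-group arguments; NumATensor = NumBTensor = 1, NumDTensor = 0.
using KernelArg = ck::tensor_operation::device::GroupedGemmMultiABDKernelArgument<1, 1, 0>;
std::vector<KernelArg> host_args; // one entry per group (assumed filled)

// The op reports the size it expects; copy the host payload to the device.
const size_t arg_bytes = op.GetDeviceKernelArgSize(arg.get());
void* dev_args = nullptr;
hipMalloc(&dev_args, arg_bytes);
hipMemcpy(dev_args, host_args.data(), arg_bytes, hipMemcpyHostToDevice);

op.SetDeviceKernelArgs(arg.get(), dev_args);
op.SetKBatch(arg.get(), /*k_batch=*/1); // fixed-NK also exposes split-K control
```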
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <iostream>
#include <vector>
#include <sstream>
#include "device_grouped_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
///
/// @brief Structure representing a single GEMM problem's arguments.
///
/// A pointer to an array of these structures is passed to the grouped GEMM
/// entry-point kernel.
///
/// @tparam NumDTensor The number of D input tensors.
///
template <index_t NumDTensor = 0>
struct GroupedGemmMultipleDKernelArguments
{
__host__ __device__
GroupedGemmMultipleDKernelArguments(const void* p_a_grid_,
const void* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
void* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_ds_grid{p_ds_grid_},
p_e_grid{p_e_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideDs{StrideDs_},
StrideE{StrideE_}
{
}
const void* p_a_grid;
const void* p_b_grid;
std::array<const void*, NumDTensor> p_ds_grid;
void* p_e_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
void Print() const
{
std::stringstream str;
for(auto sd : StrideDs)
str << sd << ",";
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SE:" << StrideE << ", "
<< "SDs: {" << str.str() << "}"
<< "}" << std::endl;
}
};
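For illustration, constructing one group's argument block (pointers assumed to be valid device buffers; values arbitrary):

```cpp
// Illustrative: a 128 x 128 x 64 problem with one D tensor (NumDTensor = 1).
ck::tensor_operation::device::GroupedGemmMultipleDKernelArguments<1> ka{
    p_a, p_b, {p_d}, p_e,   // device pointers (assumed)
    128, 128, 64,           // M, N, K
    64, 64,                 // StrideA, StrideB
    {128},                  // StrideDs
    128};                   // StrideE
ka.Print(); // -> arg {M:128, N:128, K:64, SA:64, SB:64, SE:128, SDs: {128,}}
```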
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGroupedGemmMultipleDSplitK : public DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
//----------------------------------------------------------------------------------------------
/// @brief Sets the k batch size.
///
/// @param p_arg Pointer to the Argument we're going to change.
/// @param[in] kbatch The kbatch value.
///
virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
//----------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
/// arguments.
///
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0;
//----------------------------------------------------------------------------------------------
/// @brief Gets the device kernel argument size.
///
/// @param[in] p_arg The pointer to the Device op Argument.
///
/// @return The device kernel argument size.
///
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
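Putting the three pure virtuals together, a hedged sketch of a split-K launch sequence (`op`, `arg` and the HIP calls follow the same assumptions as the earlier sketches):

```cpp
// Hypothetical driver for a concrete DeviceGroupedGemmMultipleDSplitK op.
op.SetKBatchSize(arg.get(), /*kbatch=*/4); // accumulate K in 4 partial slices

void* dev_args = nullptr;
hipMalloc(&dev_args, op.GetDeviceKernelArgSize(arg.get()));
// ... hipMemcpy the per-group GroupedGemmMultipleDKernelArguments here ...
op.SetDeviceKernelArgs(arg.get(), dev_args);

auto invoker = op.MakeInvokerPointer();
// invoker->Run(arg.get(), StreamConfig{});
```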
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <iostream>
#include <vector>
#include <sstream>
#include "device_grouped_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
///
/// @brief Structure representing a single GEMM problem's arguments.
///
/// A pointer to an array of these structures is passed to the grouped GEMM
/// entry-point kernel.
///
/// @tparam NumDTensor The number of D input tensors.
///
template <index_t NumDTensor = 0>
struct GroupedGemmTileLoopKernelArguments
{
__host__ __device__
GroupedGemmTileLoopKernelArguments(const void* p_a_grid_,
const void* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
void* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_ds_grid{p_ds_grid_},
p_e_grid{p_e_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideDs{StrideDs_},
StrideE{StrideE_}
{
}
const void* p_a_grid;
const void* p_b_grid;
std::array<const void*, NumDTensor> p_ds_grid;
void* p_e_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
void Print() const
{
std::stringstream str;
for(auto sd : StrideDs)
str << sd << ",";
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SE:" << StrideE << ", "
<< "SDs: {" << str.str() << "}"
<< "}" << std::endl;
}
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
//----------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
/// arguments.
///
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0;
//----------------------------------------------------------------------------------------------
/// @brief Gets the device kernel argument size.
///
/// @param[in] p_arg The pointer to the Device op Argument.
///
/// @return The device kernel argument size.
///
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -834,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported() || ck::is_navi4_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
......
......@@ -649,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported() || ck::is_navi4_supported())
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
bool pass = true;
pass = pass && arg.K_ % K1 == 0;
......
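These hunks rename the architecture predicates from marketing names (navi2/3/4) to ISA names (gfx103/11/12). The helper bodies are not part of this diff; a plausible shape, assuming they keep matching on `ck::get_device_name()`, is:

```cpp
// Sketch only; the real predicate lives in ck/host_utility/device_prop.hpp
// and the exact list of gfx11 targets is an assumption here.
inline bool is_gfx11_supported()
{
    const std::string name = ck::get_device_name(); // e.g. "gfx1100"
    return name == "gfx1100" || name == "gfx1101" || name == "gfx1102" ||
           name == "gfx1103";
}
```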
......@@ -587,13 +587,14 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
BatchStrideD1s,
BatchStrideE1}
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", "
<< a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl;
std::cout << "b0_grid_desc_n_k_{" << b0_grid_desc_n_k_.GetLength(I0) << ", "
<< b0_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0) << ", "
<< d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl;
std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0)
<< ", " << d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl;
std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", "
<< b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{"
......@@ -610,7 +611,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<< std::endl;
std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", "
<< e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
#endif
}
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
using D0Layout = remove_cvref_t<tuple_element_t<i.value, D0sLayout>>;
......
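This and the following hunks replace compile-time `#if DEBUG_LOG` gating with a runtime check of the CK_LOGGING environment variable, so diagnostics can be enabled without rebuilding. The recurring pattern (sketch; enable with e.g. `CK_LOGGING=1` in the environment):

```cpp
// Runtime-gated logging as introduced by this commit; previously the same
// block was compiled out unless DEBUG_LOG was defined at build time.
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
    std::cout << "grid descriptor lengths ..." << std::endl; // illustrative body
}
```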
......@@ -658,7 +658,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
{
std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl;
......@@ -672,13 +673,13 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0)
<< "}" << std::endl;
std::cout << "arg.reduce_grid_desc_m_{ "
<< arg.reduce_grid_desc_m_.GetLength(I0) << "}" << std::endl;
}
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
......
......@@ -859,7 +859,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static bool IsSupportedArgument(const RawArg& arg)
{
if(ck::is_navi3_supported() || ck::is_navi4_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......@@ -1436,7 +1436,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
#if 0
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......
......@@ -719,9 +719,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
arg.Print();
#endif
}
if(!ck::is_xdl_supported())
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -14,7 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_cgemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -80,42 +80,41 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto MPerThread = Number<4>{};
static constexpr index_t MPerThread =
MPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
static constexpr index_t NPerThread =
NPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
static constexpr auto AScalarPerVector = Number<4>{};
static constexpr auto BScalarPerVector = Number<4>{};
static constexpr auto CScalarPerVector = Number<4>{};
template <typename Desc_M>
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
template <typename Desc_M_N>
static auto PadDescriptor_M_N(Desc_M_N desc)
{
const auto M = desc_m.GetLength(I0);
const index_t loop_step = gridSize * blockSize * MPerThread;
const auto pad = math::integer_least_multiple(M, loop_step) - M;
const auto desc_m_pad =
transform_tensor_descriptor(desc_m,
make_tuple(make_right_pad_transform(M, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m_pad;
const auto M = desc.GetLength(I0);
const auto N = desc.GetLength(I1);
const auto pad_M = math::integer_divide_ceil(M, MPerThread) * MPerThread - M;
const auto pad_N = math::integer_divide_ceil(N, NPerThread) * NPerThread - N;
const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_right_pad_transform(M, pad_M), make_right_pad_transform(N, pad_N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return padded_desc;
}
static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides,
index_t gridSize,
index_t blockSize)
static auto MakeDescriptor_M_N(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides)
{
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<2>{});
auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<2>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
const auto desc_m = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<2>{})),
make_tuple(Sequence<0>{}));
return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
return PadDescriptor_M_N(desc);
}
// GridwiseGemm
......@@ -166,7 +165,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using CGridDesc_M_N = decltype(MakeDescriptor_M_N({1, 1}, {1, 1}));
// Argument
struct Argument : public tensor_operation::device::BaseArgument, public GridwiseGemm::Problem
......@@ -195,17 +194,13 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
p_c_grid_imag{p_c_grid_imag_},
p_aux_grid{p_workspace}
{
const index_t grid_size = std::get<1>(GridwiseGemm::CalculateGridSize(M_, N_));
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
c_grid_desc_m =
DeviceOp::MakeDescriptor_M({M_, N_}, {StrideC_, I1}, grid_size, BlockSize);
c_grid_desc_m_n = DeviceOp::MakeDescriptor_M_N({M_, N_}, {StrideC_, I1});
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
c_grid_desc_m =
DeviceOp::MakeDescriptor_M({M_, N_}, {I1, StrideC_}, grid_size, BlockSize);
c_grid_desc_m_n = DeviceOp::MakeDescriptor_M_N({M_, N_}, {I1, StrideC_});
}
p_aux_2_grid = p_workspace + GetCElementSpaceSize(M_, N_, StrideC_);
......@@ -220,7 +215,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType* p_c_grid_imag;
CDataType* p_aux_grid;
CDataType* p_aux_2_grid;
CGridDesc_M c_grid_desc_m;
CGridDesc_M_N c_grid_desc_m_n;
};
// Invoker
......@@ -248,39 +243,62 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
using Add = ck::tensor_operation::element_wise::Add;
using Subtract = ck::tensor_operation::element_wise::Subtract;
using GridwiseBinAdd =
GridwiseElementwise_1D<Tuple<CGridDesc_M, CGridDesc_M>,
Tuple<CGridDesc_M>,
using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
using GridwiseBinAdd = GridwiseElementwise<Tuple<CGridDesc_M_N, CGridDesc_M_N>,
Tuple<CGridDesc_M_N>,
Tuple<const CDataType*, const CDataType*>,
Tuple<CDataType*>,
Block2TileMap,
Add,
BlockSize,
MPerBlock,
NPerBlock,
MPerThread,
NPerThread,
Sequence<0, 1>,
Sequence<AScalarPerVector, BScalarPerVector>,
Sequence<CScalarPerVector>>;
Sequence<CScalarPerVector>,
I1,
I1>;
using GridwiseBinSubtract =
GridwiseElementwise_1D<Tuple<CGridDesc_M, CGridDesc_M>,
Tuple<CGridDesc_M>,
GridwiseElementwise<Tuple<CGridDesc_M_N, CGridDesc_M_N>,
Tuple<CGridDesc_M_N>,
Tuple<const CDataType*, const CDataType*>,
Tuple<CDataType*>,
Block2TileMap,
Subtract,
BlockSize,
MPerBlock,
NPerBlock,
MPerThread,
NPerThread,
Sequence<0, 1>,
Sequence<AScalarPerVector, BScalarPerVector>,
Sequence<CScalarPerVector>>;
Sequence<CScalarPerVector>,
I1,
I1>;
const index_t M = arg.c_grid_desc_m_n.GetLength(I0);
const index_t N = arg.c_grid_desc_m_n.GetLength(I1);
const auto block_2_tile_map = Block2TileMap(M, N);
const auto add_kernel = kernel_elementwise_1d<GridwiseBinAdd,
Tuple<CGridDesc_M, CGridDesc_M>,
Tuple<CGridDesc_M>,
const auto add_kernel = kernel_elementwise<GridwiseBinAdd,
Tuple<CGridDesc_M_N, CGridDesc_M_N>,
Tuple<CGridDesc_M_N>,
Tuple<const CDataType*, const CDataType*>,
Tuple<CDataType*>,
Block2TileMap,
Add>;
const auto subtract_kernel =
kernel_elementwise_1d<GridwiseBinSubtract,
Tuple<CGridDesc_M, CGridDesc_M>,
Tuple<CGridDesc_M>,
kernel_elementwise<GridwiseBinSubtract,
Tuple<CGridDesc_M_N, CGridDesc_M_N>,
Tuple<CGridDesc_M_N>,
Tuple<const CDataType*, const CDataType*>,
Tuple<CDataType*>,
Block2TileMap,
Subtract>;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
......@@ -318,11 +336,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n),
make_tuple(arg.c_grid_desc_m_n),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
const_cast<const CDataType*>(arg.p_aux_2_grid)),
make_tuple(arg.p_c_grid_real),
block_2_tile_map,
Subtract{});
ave_time += launch_and_time_kernel(stream_config,
......@@ -352,11 +371,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n),
make_tuple(arg.c_grid_desc_m_n),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
const_cast<const CDataType*>(arg.p_aux_2_grid)),
make_tuple(arg.p_c_grid_imag),
block_2_tile_map,
Add{});
}
else
......@@ -394,11 +414,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n),
make_tuple(arg.c_grid_desc_m_n),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
const_cast<const CDataType*>(arg.p_aux_2_grid)),
make_tuple(arg.p_c_grid_real),
block_2_tile_map,
Subtract{});
ave_time += launch_and_time_kernel(stream_config,
......@@ -428,11 +449,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n),
make_tuple(arg.c_grid_desc_m_n),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
const_cast<const CDataType*>(arg.p_aux_2_grid)),
make_tuple(arg.p_c_grid_imag),
block_2_tile_map,
Add{});
}
......
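The rewritten `PadDescriptor_M_N` above rounds each output dimension up to a multiple of the per-thread tile before applying right-pad transforms. A standalone check of that arithmetic with illustrative values:

```cpp
// Mirrors math::integer_divide_ceil as used by PadDescriptor_M_N above.
constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    constexpr int MPerThread = 8, NPerThread = 4; // illustrative tile sizes
    constexpr int M = 100, N = 30;                // illustrative tensor extents
    constexpr int pad_M = integer_divide_ceil(M, MPerThread) * MPerThread - M; // 4
    constexpr int pad_N = integer_divide_ceil(N, NPerThread) * NPerThread - N; // 2
    static_assert(pad_M == 4 && pad_N == 2, "dims round up to tile multiples");
}
```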
......@@ -663,7 +663,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_a_access_dim_k =
ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i];
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
if(!(valid_a_vector_size && valid_a_access_dim))
if(!((valid_a_vector_size && valid_a_access_dim) ||
ABlockTransferSrcScalarPerVector == 1))
{
valid_as_access = false;
}
......@@ -682,7 +683,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_b_access_dim_k =
BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i];
const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
if(!(valid_b_vector_size && valid_b_access_dim))
if(!((valid_b_vector_size && valid_b_access_dim) ||
BBlockTransferSrcScalarPerVector == 1))
{
valid_bs_access = false;
}
......@@ -698,7 +700,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector read of Ds is always on N dimension.
const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
if(!(valid_d_vector_size && valid_d_access_dim))
if(!((valid_d_vector_size && valid_d_access_dim) ||
CDEBlockTransferScalarPerVector_NPerBlock == 1))
{
valid_ds_access = false;
}
......@@ -712,7 +715,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector write of E is always on N dimension.
const bool valid_e_access_dim = arg.e_nz_consecutive_;
if(!(valid_e_vector_size && valid_e_access_dim))
if(!((valid_e_vector_size && valid_e_access_dim) ||
CDEBlockTransferScalarPerVector_NPerBlock == 1))
{
return false;
}
......
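The four hunks above all apply the same relaxation: a configuration whose vector dimension is not memory-consecutive is still accepted when the transfer is scalar (`ScalarPerVector == 1`), since a scalar access needs no contiguity. Condensed into one predicate (sketch, not code from the diff):

```cpp
// Condensed form of the relaxed validity check added in this commit.
bool valid_vector_access(int max_elems, int scalar_per_vector, bool dim_consecutive)
{
    const bool valid_size = (max_elems % scalar_per_vector) == 0;
    return (valid_size && dim_consecutive) || scalar_per_vector == 1;
}
```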
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -625,7 +625,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
const bool valid_a_access_dim =
valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1;
if(!(valid_a_vector_size && valid_a_access_dim))
{
return false;
......@@ -635,7 +636,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
const bool valid_b_access_dim =
valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1;
if(!(valid_b_vector_size && valid_b_access_dim))
{
return false;
......@@ -646,7 +648,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const bool valid_d_vector_size =
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector read of Ds is always on N dimension.
const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
const bool valid_d_access_dim =
arg.ds_nz_consecutive_[i] || CDEBlockTransferScalarPerVector_NPerBlock == 1;
if(!(valid_d_vector_size && valid_d_access_dim))
{
valid_ds_access = false;
......@@ -660,7 +663,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const bool valid_e_vector_size =
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector write of E is always on N dimension.
const bool valid_e_access_dim = arg.e_nz_consecutive_;
const bool valid_e_access_dim =
arg.e_nz_consecutive_ || CDEBlockTransferScalarPerVector_NPerBlock == 1;
if(!(valid_e_vector_size && valid_e_access_dim))
{
return false;
......
......@@ -516,7 +516,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
......@@ -535,7 +536,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
<< std::endl;
}
#endif
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
......
......@@ -644,7 +644,7 @@ struct
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << DeviceOp{}.GetTypeString() << std::endl;
std::cout << "N " << arg.Conv_N_ << ", "
......@@ -664,9 +664,7 @@ struct
<< arg.input_left_pads_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
<< arg.input_right_pads_[1] << ", " << std::endl;
}
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
......@@ -684,7 +682,6 @@ struct
std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
......
......@@ -614,7 +614,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << DeviceOp{}.GetTypeString() << std::endl;
std::cout << "N " << arg.Conv_N_ << ", "
......@@ -634,9 +634,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
<< arg.input_left_pads_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
<< arg.input_right_pads_[1] << ", " << std::endl;
}
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
......@@ -651,7 +649,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
......
......@@ -579,7 +579,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << DeviceOp{}.GetTypeString() << std::endl;
std::cout << "N " << arg.Conv_N_ << ", "
......@@ -599,9 +599,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
<< arg.input_left_pads_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
<< arg.input_right_pads_[1] << ", " << std::endl;
}
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
......@@ -635,7 +633,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
.GetLength(I5)
<< "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
......
......@@ -431,7 +431,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
......@@ -444,7 +444,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
{
......