Commit e599063f authored by illsilin

sync from the public repo

parents 5dbbf5d6 566b6480
......@@ -401,7 +401,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl;
std::cout << "a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
......@@ -415,7 +415,6 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
std::cout << "c_grid_desc_m_n{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
......
......@@ -1272,7 +1272,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
......@@ -1305,7 +1305,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I5)
<< " ) " << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
......@@ -1393,8 +1392,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
static bool IsSupportedArgument(const Argument& arg)
{
// check device
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
ck::is_navi3_supported() || ck::is_navi4_supported()))
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
return false;
}
......
......@@ -1220,7 +1220,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "arg.a_grid_desc_k0_m_k1{"
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
......@@ -1239,7 +1239,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
<< std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
index_t NumDim_m,
index_t NumDim_n,
index_t MPerThread,
index_t NPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
NumDim_m + NumDim_n>
{
static constexpr index_t NumDim = NumDim_m + NumDim_n;
static constexpr int NumInput = InDataTypeTuple::Size();
static constexpr int NumOutput = OutDataTypeTuple::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size(),
"Tuple size is inconsistent with the number of in/out!");
static auto GenerateInDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(nullptr);
},
Number<NumInput>{});
};
static auto GenerateOutDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(nullptr);
},
Number<NumOutput>{});
};
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
template <typename Desc_MN>
static auto PadDescriptor_MN_2d(Desc_MN desc_mn,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n)
{
std::ignore = blockSize;
std::ignore = gridSize;
const auto m = desc_mn.GetLength(I0);
const auto n = desc_mn.GetLength(I1);
const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = num_threads_n * NPerThread;
const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m;
const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n;
const auto desc_mn_pad = transform_tensor_descriptor(
desc_mn,
make_tuple(make_right_pad_transform(m, pad_m), make_right_pad_transform(n, pad_n)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return desc_mn_pad;
}
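// Worked example of the padding arithmetic above (illustrative numbers, not from
// a real config): with MPerThread = 8 and num_threads_m = 64, loop_step_m = 512.
// For an M extent of 1000, integer_least_multiple(1000, 512) = 1024, so
// pad_m = 24 and the padded descriptor exposes a 1024-wide M dimension.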
static auto MakeDescriptor_MN(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& stride,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n)
{
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDim_m, 1>::type();
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDim_m, NumDim_m + NumDim_n, 1>::type();
const auto mLengths = get_container_subset(tupleOfShape, mDimIds);
const auto nLengths = get_container_subset(tupleOfShape, nDimIds);
// merge nd to 2d desc - [s0 * s1 * ...]
if constexpr(NumDim > 2)
{
const auto desc_mn = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadDescriptor_MN_2d(desc_mn, gridSize, blockSize, num_threads_m, num_threads_n);
}
else
return PadDescriptor_MN_2d(desc, gridSize, blockSize, num_threads_m, num_threads_n);
}
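// Illustrative shape of the merge above (hypothetical sizes): with NumDim_m = 2,
// NumDim_n = 2 and lengths = {2, 3, 4, 5}, the M dims {2, 3} merge to M = 6 and
// the N dims {4, 5} merge to N = 20, giving a 2-d [6, 20] descriptor before padding.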
template <index_t TupleSize>
static auto GenerateInOutGrid2dDescTuple(Number<TupleSize>)
{
return generate_tuple(
[&](auto) {
if constexpr(NumDim > 2)
{
return MakeDescriptor_MN({1, 1}, {1, 1}, 1, 1, 1, 1);
}
else
{
return MakeDescriptor_MN({1}, {1}, 1, 1, 1, 1);
};
},
Number<TupleSize>{});
};
using OutGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumOutput>{}));
using InGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumInput>{}));
using GridwiseElementwise = GridwiseElementwise_2D<InGrid2dDescTuple,
OutGrid2dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation,
MPerThread,
NPerThread,
InScalarPerVectorSeq,
OutScalarPerVectorSeq>;
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
: lengths_(lengths),
inStridesArray_(inStridesArray),
outStridesArray_(outStridesArray),
elementwise_op_(elementwise_op),
blockSize_(256)
{
static_assert(NumDim_m > 0, "");
static_assert(NumDim_n > 0, "");
in_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(in_dev_buffers[I.value]);
},
Number<NumInput>{});
out_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(out_dev_buffers[I.value]);
},
Number<NumOutput>{});
}
InDataTypePointerTuple in_dev_buffers_;
OutDataTypePointerTuple out_dev_buffers_;
std::array<index_t, NumDim> lengths_;
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
ElementwiseOperation elementwise_op_;
index_t blockSize_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
index_t gridSize = getAvailableComputeUnitCount(stream_config);
index_t num_threads_m = (gridSize * arg.blockSize_) / 16;
index_t num_threads_n = 16;
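// Example of the split above (hypothetical device): with 64 compute units and
// blockSize_ = 256 there are 64 * 256 = 16384 threads in flight, arranged as
// num_threads_m = 1024 by num_threads_n = 16 for the 2-d sweep.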
auto in_grid_2d_desc_tuple = generate_tuple(
[&](auto I) {
return MakeDescriptor_MN(arg.lengths_,
arg.inStridesArray_[I.value],
gridSize,
arg.blockSize_,
num_threads_m,
num_threads_n);
},
Number<NumInput>{});
auto out_grid_2d_desc_tuple = generate_tuple(
[&](auto I) {
return MakeDescriptor_MN(arg.lengths_,
arg.outStridesArray_[I.value],
gridSize,
arg.blockSize_,
num_threads_m,
num_threads_n);
},
Number<NumOutput>{});
const auto kernel = kernel_elementwise_2d<GridwiseElementwise,
InGrid2dDescTuple,
OutGrid2dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(gridSize),
dim3(arg.blockSize_),
0,
in_grid_2d_desc_tuple,
out_grid_2d_desc_tuple,
arg.in_dev_buffers_,
arg.out_dev_buffers_,
arg.elementwise_op_,
num_threads_m,
num_threads_n);
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg == nullptr)
return false;
if(pArg->lengths_.back() % MPerThread != 0)
return false;
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides,
index_t scalarPerVector,
index_t vectorDim) {
if(strides[vectorDim] == 1 &&
(lengths[vectorDim] % scalarPerVector == 0 ||
lengths[vectorDim] % scalarPerVector == lengths[vectorDim]))
{
return true;
}
if(strides[vectorDim] != 1 && scalarPerVector == strides[vectorDim])
{
return true;
}
return false;
};
bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(pArg->lengths_,
pArg->inStridesArray_[I.value],
InScalarPerVectorSeq::At(I),
NumDim_m - 1))
valid = false;
});
static_for<0, NumOutput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(pArg->lengths_,
pArg->outStridesArray_[I.value],
OutScalarPerVectorSeq::At(I),
NumDim - 1))
valid = false;
});
return valid;
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op) override
{
return std::make_unique<Argument>(lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_3d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
index_t NumDim_m, // number of leading dimensions grouped into M (the N and K groups follow)
index_t NumDim_n,
index_t NumDim_k,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
NumDim_m + NumDim_n + NumDim_k>
{
static constexpr index_t NumDim = NumDim_m + NumDim_n + NumDim_k;
static constexpr int NumInput = InDataTypeTuple::Size();
static constexpr int NumOutput = OutDataTypeTuple::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size(),
"Tuple size is inconsistent with the number of in/out!");
static auto GenerateInDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(nullptr);
},
Number<NumInput>{});
}
static auto GenerateOutDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(nullptr);
},
Number<NumOutput>{});
}
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
template <typename Desc_MNK>
static auto PadDescriptor_MNK(Desc_MNK desc_mnk,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n,
index_t num_threads_k)
{
std::ignore = blockSize;
std::ignore = gridSize;
const auto m = desc_mnk.GetLength(I0);
const auto n = desc_mnk.GetLength(I1);
const auto k = desc_mnk.GetLength(I2);
const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = num_threads_n * NPerThread;
const index_t loop_step_k = num_threads_k * KPerThread;
const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m;
const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n;
const auto pad_k = math::integer_least_multiple(k, loop_step_k) - k;
const auto desc_mnk_pad =
transform_tensor_descriptor(desc_mnk,
make_tuple(make_right_pad_transform(m, pad_m),
make_right_pad_transform(n, pad_n),
make_right_pad_transform(k, pad_k)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return desc_mnk_pad;
}
static auto MakeDescriptor_MNK(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& stride,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n,
index_t num_threads_k)
{
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDim_m, 1>::type();
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDim_m, NumDim_m + NumDim_n, 1>::type();
constexpr auto kDimIds =
typename arithmetic_sequence_gen<NumDim_m + NumDim_n, NumDim, 1>::type();
const auto mLengths = get_container_subset(tupleOfShape, mDimIds);
const auto nLengths = get_container_subset(tupleOfShape, nDimIds);
const auto kLengths = get_container_subset(tupleOfShape, kDimIds);
// merge nd to 3d desc - [s0 * s1 * ...]
if constexpr(NumDim > 3)
{
const auto desc_mnk = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(mLengths),
make_merge_transform(nLengths),
make_merge_transform(kLengths)),
make_tuple(mDimIds, nDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return PadDescriptor_MNK(
desc_mnk, gridSize, blockSize, num_threads_m, num_threads_n, num_threads_k);
}
else
return PadDescriptor_MNK(
desc, gridSize, blockSize, num_threads_m, num_threads_n, num_threads_k);
}
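// Illustrative grouping (hypothetical sizes): with NumDim_m = 1, NumDim_n = 2,
// NumDim_k = 1 and lengths = {2, 3, 4, 5}, the descriptor above is merged to
// [M, N, K] = [2, 12, 5] before padding.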
template <index_t TupleSize>
static auto GenerateInOutGrid3dDescTuple(Number<TupleSize>)
{
return generate_tuple(
[&](auto) {
if constexpr(NumDim > 3)
{
return MakeDescriptor_MNK({1, 1, 1}, {1, 1, 1}, 1, 1, 1, 1, 1);
}
else
{
return MakeDescriptor_MNK({1}, {1}, 1, 1, 1, 1, 1);
};
},
Number<TupleSize>{});
}
using OutGrid3dDescTuple = decltype(GenerateInOutGrid3dDescTuple(Number<NumOutput>{}));
using InGrid3dDescTuple = decltype(GenerateInOutGrid3dDescTuple(Number<NumInput>{}));
using GridwiseElementwise = GridwiseElementwise_3D<InGrid3dDescTuple,
OutGrid3dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation,
MPerThread,
NPerThread,
KPerThread,
InScalarPerVectorSeq,
OutScalarPerVectorSeq>;
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
: lengths_(lengths),
inStridesArray_(inStridesArray),
outStridesArray_(outStridesArray),
elementwise_op_(elementwise_op),
blockSize_(256)
{
static_assert(NumDim_m > 0, "");
static_assert(NumDim_n > 0, "");
static_assert(NumDim_k > 0, "");
in_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(in_dev_buffers[I.value]);
},
Number<NumInput>{});
out_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(out_dev_buffers[I.value]);
},
Number<NumOutput>{});
}
InDataTypePointerTuple in_dev_buffers_;
OutDataTypePointerTuple out_dev_buffers_;
std::array<index_t, NumDim> lengths_;
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
ElementwiseOperation elementwise_op_;
index_t blockSize_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
index_t gridSize = getAvailableComputeUnitCount(stream_config) * arg.blockSize_;
index_t num_threads_m = gridSize / (16 * 16);
index_t num_threads_n = 16;
index_t num_threads_k = 16;
auto in_grid_3d_desc_tuple = generate_tuple(
[&](auto I) {
return MakeDescriptor_MNK(arg.lengths_,
arg.inStridesArray_[I.value],
gridSize,
arg.blockSize_,
num_threads_m,
num_threads_n,
num_threads_k);
},
Number<NumInput>{});
auto out_grid_3d_desc_tuple = generate_tuple(
[&](auto I) {
return MakeDescriptor_MNK(arg.lengths_,
arg.outStridesArray_[I.value],
gridSize,
arg.blockSize_,
num_threads_m,
num_threads_n,
num_threads_k);
},
Number<NumOutput>{});
const auto kernel = kernel_elementwise_3d<GridwiseElementwise,
InGrid3dDescTuple,
OutGrid3dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(gridSize),
dim3(arg.blockSize_),
0,
in_grid_3d_desc_tuple,
out_grid_3d_desc_tuple,
arg.in_dev_buffers_,
arg.out_dev_buffers_,
arg.elementwise_op_,
num_threads_m,
num_threads_n,
num_threads_k);
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
if((ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
{
return false;
}
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg == nullptr)
return false;
if(pArg->lengths_.back() % MPerThread != 0)
return false;
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides,
index_t scalarPerVector,
index_t vectorDim) {
if(strides[vectorDim] == 1 &&
(lengths[vectorDim] % scalarPerVector == 0 ||
lengths[vectorDim] % scalarPerVector == lengths[vectorDim]))
{
return true;
}
if(strides[vectorDim] >= scalarPerVector)
{
return true;
}
return false;
};
bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
valid = valid && IsScalarPerVectorValid(pArg->lengths_,
pArg->inStridesArray_[I.value],
InScalarPerVectorSeq::At(I),
NumDim_m - 1);
});
static_for<0, NumOutput, 1>{}([&](auto I) {
valid = valid && IsScalarPerVectorValid(pArg->lengths_,
pArg->outStridesArray_[I.value],
OutScalarPerVectorSeq::At(I),
NumDim - 1);
});
return valid;
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op) override
{
return std::make_unique<Argument>(lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -9,8 +9,9 @@
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"
......@@ -23,7 +24,12 @@ template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
index_t NumDim,
index_t MPerThread,
index_t BlockSize,
index_t M0PerBlock,
index_t M1PerBlock,
index_t M0PerThread,
index_t M1PerThread,
typename ThreadClusterArrangeOrder,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DeviceElementwiseImpl
......@@ -32,6 +38,9 @@ struct DeviceElementwiseImpl
static constexpr int NumInput = InDataTypeTuple::Size();
static constexpr int NumOutput = OutDataTypeTuple::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size(),
"Tuple size is inconsistent with the number of in/out!");
......@@ -61,76 +70,145 @@ struct DeviceElementwiseImpl
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
template <typename Desc_M>
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
static index_t GetLowestStrideDim(const std::array<index_t, NumDim>& strides)
{
index_t most_continous_dim = NumDim - 1;
index_t most_continous_dim_stride = strides[most_continous_dim];
for(index_t dim = 0; dim < NumDim; dim++)
{
if(strides[dim] < most_continous_dim_stride)
{
constexpr auto I0 = Number<0>{};
const auto m = desc_m.GetLength(I0);
const index_t loop_step = gridSize * blockSize * MPerThread;
const auto pad = math::integer_least_multiple(m, loop_step) - m;
const auto desc_m_pad =
transform_tensor_descriptor(desc_m,
make_tuple(make_right_pad_transform(m, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m_pad;
most_continous_dim_stride = strides[dim];
most_continous_dim = dim;
}
}
return most_continous_dim;
}
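// Example (hypothetical strides): for an [N, H, W, C] tensor with strides
// {H*W*C, W*C, C, 1} the scan above returns dim 3, the unit-stride
// (fastest-varying) dimension, which is then treated as the candidate
// vector-access dimension.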
static auto MakeDescriptor_M(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& stride,
index_t gridSize,
index_t blockSize)
template <typename InOutDescriptor>
static auto PadInputOutputDescriptor(const InOutDescriptor& desc)
{
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
const auto M0 = desc.GetLength(I0);
const auto M1 = desc.GetLength(I1);
const auto pad_M0 = math::integer_divide_ceil(M0, M0PerThread) * M0PerThread - M0;
const auto pad_M1 = math::integer_divide_ceil(M1, M1PerThread) * M1PerThread - M1;
// merge nd to 1d desc - [s0 * s1 * ...]
if constexpr(NumDim > 1)
{
const auto desc_m = transform_tensor_descriptor(
const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim>{})),
make_tuple(Sequence<0>{}));
make_tuple(make_right_pad_transform(M0, pad_M0), make_right_pad_transform(M1, pad_M1)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
}
else
return PadDescriptor_M_1d(desc, gridSize, blockSize);
return padded_desc;
}
template <index_t TupleSize>
static auto GenerateInOutGrid1dDescTuple(Number<TupleSize>)
static auto GenerateBatchDimsLenghtsTuple(const std::array<index_t, NumDim>& lengths,
const index_t M0_dim,
const index_t M1_dim)
{
return generate_tuple(
[&](auto) {
if constexpr(NumDim > 1)
// Generate the batch dims; they will be merged into M0.
// Allocate one more dim than strictly needed in case M0_dim equals M1_dim:
// if M0_dim equals M1_dim, there is one extra batch dim.
std::array<index_t, NumDim - 1> batch_dims;
index_t batch_dim = 0;
for(index_t i = 0; i < NumDim; i++)
{
return MakeDescriptor_M({1, 1}, {1, 1}, 1, 1);
if(i != M0_dim && i != M1_dim)
{
batch_dims[batch_dim] = lengths[i];
batch_dim++;
}
}
// Add dummy dim if M0_dim is not equal to M1_dim
if(M0_dim != M1_dim && NumDim >= 2)
batch_dims[NumDim - 2] = 1;
return generate_tuple([&](auto I) { return batch_dims[I]; }, Number<NumDim - 1>{});
}
else
static auto MakeDescriptor(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& in_strides,
const std::array<index_t, NumDim>& out_strides,
const std::array<index_t, NumDim>& desc_strides)
{
return MakeDescriptor_M({1}, {1}, 1, 1);
};
},
Number<TupleSize>{});
const auto M0_dim = GetLowestStrideDim(out_strides);
const auto M1_dim = GetLowestStrideDim(in_strides);
// If M0_dim is equal to M1_dim, then make M0_dim dummy
const auto M0 = M0_dim == M1_dim ? I1 : lengths[M0_dim];
const auto M1 = lengths[M1_dim];
const auto M0_stride = M0_dim == M1_dim ? I1 : desc_strides[M0_dim];
const auto M1_stride = desc_strides[M1_dim];
const auto batch_dims_lenghts = GenerateBatchDimsLenghtsTuple(lengths, M0_dim, M1_dim);
const auto batch_dims_strides = GenerateBatchDimsLenghtsTuple(desc_strides, M0_dim, M1_dim);
const auto desc = make_naive_tensor_descriptor(
concat_tuple(batch_dims_lenghts, make_tuple(M0), make_tuple(M1)),
concat_tuple(batch_dims_strides, make_tuple(M0_stride), make_tuple(M1_stride)));
// Merged batch dims with M0
const auto transforms =
make_tuple(make_merge_transform(concat_tuple(batch_dims_lenghts, make_tuple(M0))),
make_pass_through_transform(M1));
using BatchElemsSequence =
typename arithmetic_sequence_gen<0, decltype(batch_dims_lenghts)::Size() + 1, 1>::type;
const auto lower_dims = make_tuple(BatchElemsSequence{}, Sequence<NumDim>{});
const auto upper_dims = make_tuple(Sequence<0>{}, Sequence<1>{});
// desc: (merged_dims + M0, M1)
auto merged_desc = transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
return PadInputOutputDescriptor(merged_desc);
}
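// Worked example (hypothetical 3-d transpose-like case): lengths = {8, 16, 32},
// input strides {512, 32, 1} (so M1_dim = 2) and output strides {512, 1, 16}
// (so M0_dim = 1). Then M0 = 16, M1 = 32, the remaining batch dims are {8, 1},
// and the merged descriptor is (8 * 1 * 16, 32) = (128, 32) before padding.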
template <index_t NumTensors>
static auto GenerateInOutGridDescTuple()
{
std::array<index_t, NumDim> ones;
for(index_t d = 0; d < NumDim; d++)
{
ones[d] = 1;
}
return generate_tuple([&](auto) { return MakeDescriptor(ones, ones, ones, ones); },
Number<NumTensors>{});
};
using InGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number<NumInput>{}));
using OutGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number<NumOutput>{}));
using InGridDescTuple = decltype(GenerateInOutGridDescTuple<NumInput>());
using OutGridDescTuple = decltype(GenerateInOutGridDescTuple<NumOutput>());
using GridwiseElementwise = GridwiseElementwise_1D<InGrid1dDescTuple,
OutGrid1dDescTuple,
using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt<M0PerBlock, M1PerBlock>;
using GridwiseElementwiseOp = GridwiseElementwise<InGridDescTuple,
OutGridDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
Block2TileMap,
ElementwiseOperation,
MPerThread,
BlockSize,
M0PerBlock,
M1PerBlock,
M0PerThread,
M1PerThread,
ThreadClusterArrangeOrder,
InScalarPerVectorSeq,
OutScalarPerVectorSeq>;
OutScalarPerVectorSeq,
I1,
I0>;
using GridwiseElementwiseOpSameInOutVectorDim = GridwiseElementwise<InGridDescTuple,
OutGridDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
Block2TileMap,
ElementwiseOperation,
BlockSize,
M0PerBlock,
M1PerBlock,
M0PerThread,
M1PerThread,
ThreadClusterArrangeOrder,
InScalarPerVectorSeq,
OutScalarPerVectorSeq,
I1,
I1>;
struct Argument : public BaseArgument
{
......@@ -144,8 +222,7 @@ struct DeviceElementwiseImpl
: lengths_(lengths),
inStridesArray_(inStridesArray),
outStridesArray_(outStridesArray),
elementwise_op_(elementwise_op),
blockSize_(256)
elementwise_op_(elementwise_op)
{
in_dev_buffers_ = generate_tuple(
[&](auto I) {
......@@ -170,45 +247,67 @@ struct DeviceElementwiseImpl
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
ElementwiseOperation elementwise_op_;
index_t blockSize_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
index_t gridSize = getAvailableComputeUnitCount(stream_config);
auto in_grid_1d_desc_tuple = generate_tuple(
[&](auto I) {
return MakeDescriptor_M(
arg.lengths_, arg.inStridesArray_[I.value], gridSize, arg.blockSize_);
auto in_grid_desc_tuple = generate_tuple(
[&](auto src_i) {
// Use the strides of the first input/output tensor so that the M0 and
// M1 dims are chosen consistently for every tensor.
return MakeDescriptor(arg.lengths_,
arg.inStridesArray_[I0],
arg.outStridesArray_[I0],
arg.inStridesArray_[src_i]);
},
Number<NumInput>{});
auto out_grid_1d_desc_tuple = generate_tuple(
[&](auto I) {
return MakeDescriptor_M(
arg.lengths_, arg.outStridesArray_[I.value], gridSize, arg.blockSize_);
auto out_grid_desc_tuple = generate_tuple(
[&](auto dst_i) {
return MakeDescriptor(arg.lengths_,
arg.inStridesArray_[I0],
arg.outStridesArray_[I0],
arg.outStridesArray_[dst_i]);
},
Number<NumOutput>{});
const auto kernel = kernel_elementwise_1d<GridwiseElementwise,
InGrid1dDescTuple,
OutGrid1dDescTuple,
const index_t M0 = in_grid_desc_tuple.At(I0).GetLength(Number<I0>{});
const index_t M1 = in_grid_desc_tuple.At(I0).GetLength(Number<I1>{});
const auto block_2_tile_map = Block2TileMap(M0, M1);
const index_t grid_size = block_2_tile_map.CalculateGridSize(M0, M1);
const bool in_out_same_vector_dim = GetLowestStrideDim(arg.inStridesArray_[I0]) ==
GetLowestStrideDim(arg.outStridesArray_[I0]);
const auto kernel = in_out_same_vector_dim
? kernel_elementwise<GridwiseElementwiseOpSameInOutVectorDim,
InGridDescTuple,
OutGridDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
Block2TileMap,
ElementwiseOperation>
: kernel_elementwise<GridwiseElementwiseOp,
InGridDescTuple,
OutGridDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
Block2TileMap,
ElementwiseOperation>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(gridSize),
dim3(arg.blockSize_),
dim3(grid_size),
dim3(BlockSize),
0,
in_grid_1d_desc_tuple,
out_grid_1d_desc_tuple,
in_grid_desc_tuple,
out_grid_desc_tuple,
arg.in_dev_buffers_,
arg.out_dev_buffers_,
block_2_tile_map,
arg.elementwise_op_);
return elapsed_time;
}
......@@ -223,35 +322,40 @@ struct DeviceElementwiseImpl
static bool IsSupportedArgument(const Argument& arg)
{
if(arg.lengths_.back() % MPerThread != 0)
return false;
const index_t M0_dim = GetLowestStrideDim(arg.inStridesArray_[I0]);
const index_t M1_dim = GetLowestStrideDim(arg.outStridesArray_[I0]);
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides,
index_t scalarPerVector) {
if(strides.back() == 1 && lengths.back() % scalarPerVector == 0)
index_t scalarPerVector,
index_t M_dim) {
if(scalarPerVector == 1)
{
return true;
if(strides.back() != 1 && scalarPerVector == 1)
}
if(strides[M_dim] == 1 && lengths[M_dim] % scalarPerVector == 0)
{
return true;
}
return false;
};
bool valid = true;
bool is_valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(
arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
valid = false;
static_assert(M0PerThread % InScalarPerVectorSeq::At(I) == 0 &&
M1PerThread % InScalarPerVectorSeq::At(I) == 0);
is_valid &= IsScalarPerVectorValid(
arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I), M0_dim);
});
static_for<0, NumOutput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(
arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
valid = false;
static_assert(M0PerThread % OutScalarPerVectorSeq::At(I) == 0 &&
M1PerThread % OutScalarPerVectorSeq::At(I) == 0);
is_valid &= IsScalarPerVectorValid(
arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I), M1_dim);
});
return valid;
return is_valid;
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
......@@ -302,23 +406,18 @@ struct DeviceElementwiseImpl
auto str = std::stringstream();
// clang-format off
str << "DeviceElementwiseImpl<" ;
str << "NumDim_" << NumDim << ",";
str << "MPerThread_" << MPerThread << ",";
str << "InScalarPerVector";
static_for<0, InScalarPerVectorSeq::Size(), 1>{}([&](auto i) { str << "_" << InScalarPerVectorSeq::At(i).value; });
str << ",";
str << "OutScalarPerVector";
static_for<0, OutScalarPerVectorSeq::Size(), 1>{}([&](auto i) { str << "_" << OutScalarPerVectorSeq::At(i).value; });
str << ">";
str << "DeviceElementwiseImpl<";
str << NumDim << ", ";
str << BlockSize << ", ";
str << M0PerBlock << ", ";
str << M1PerBlock << ", ";
str << M0PerThread << ", ";
str << M1PerThread << ">";
// clang-format on
return str.str();
}
}; // namespace device
};
} // namespace device
} // namespace tensor_operation
......
......@@ -19,6 +19,10 @@ namespace ck {
namespace tensor_operation {
namespace device {
/**
* \note This structure is deprecated (left for backwards compatibility). Please use
* DeviceElementwiseImpl from device_elementwise_dynamic_vector_dims_impl.hpp.
*/
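// A minimal sketch of instantiating the replacement from
// device_elementwise_dynamic_vector_dims_impl.hpp. All template arguments below
// (data types, dims, tile sizes, PassThrough op) are illustrative assumptions,
// not values taken from this commit:
//
//   using F32  = float;
//   using Pass = ck::tensor_operation::element_wise::PassThrough;
//   using CopyImpl = ck::tensor_operation::device::DeviceElementwiseImpl<
//       ck::Tuple<F32>,     // input data types
//       ck::Tuple<F32>,     // output data types
//       Pass,               // elementwise operation
//       4,                  // NumDim
//       256,                // BlockSize
//       128, 128,           // M0PerBlock, M1PerBlock
//       8, 8,               // M0PerThread, M1PerThread
//       ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
//       ck::Sequence<8>,    // InScalarPerVectorSeq
//       ck::Sequence<8>>;   // OutScalarPerVectorSeq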
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
......
......@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported() || ck::is_navi4_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>))
......
......@@ -334,7 +334,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
......@@ -349,7 +349,6 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
......@@ -536,8 +535,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
}
}
if(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
ck::is_navi3_supported() || ck::is_navi4_supported())
if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
......
......@@ -168,7 +168,7 @@ struct DeviceGemmDpp : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& karg)
{
if(ck::is_navi2_supported() || ck::is_navi3_supported())
if(ck::is_gfx103_supported() || ck::is_gfx11_supported())
{
return GridwiseGemm::CheckValidity(karg);
}
......
......@@ -10,110 +10,30 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
template <typename GridwiseGemm,
typename AsPointer,
typename BsPointer,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AsGridDesc_AK0_M_AK1,
typename BsGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_multiple_abd_xdl_cshuffle(
AsPointer p_as_grid,
BsPointer p_bs_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid,
p_bs_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
#else
ignore = p_as_grid;
ignore = p_bs_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = as_grid_desc_ak0_m_ak1;
ignore = bs_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_etile_map;
#endif
}
} // namespace ck
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
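// Illustrative example (hypothetical element-wise ops): with one A, one B and
// two D tensors, a_op = b_op = identity and cde_op(c, d0, d1) = c + d0 + d1,
// the kernel computes E[m, n] = sum_k A[m, k] * B[n, k] + D0[m, n] + D1[m, n].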
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename CLayout,
typename AsDataType,
typename BsDataType,
typename AccDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
......@@ -132,131 +52,56 @@ template <typename AsLayout,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
CLayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
CElementwiseOperation>
{
using DeviceOp = DeviceGemmMultipleABD_Xdl_CShuffle;
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
#if 0
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideAs)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, AsLayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideAs, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, AsLayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideAs));
}
}();
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideBs)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BsLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideBs));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BsLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideBs, I1));
}
}();
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE));
}
}();
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
#endif
using ComputeDataType = EDataType;
using ALayout = remove_cvref_t<tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<tuple_element_t<0, BsLayout>>;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle<
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
ALayout,
BLayout,
CLayout,
AsDataType,
BsDataType,
ComputeDataType,
AccDataType,
GemmAccDataType,
CShuffleDataType,
DsDataType,
EDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
NumGemmKPrefetchStage,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
......@@ -285,364 +130,476 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer>;
// desc for problem definition
using AsGridDesc_M_K =
remove_cvref_t<decltype(GridwiseGemm::template MakeAsGridDescriptor_M_K<AsLayout, GemmSpec>(
{}, {}, {}))>;
using BsGridDesc_N_K =
remove_cvref_t<decltype(GridwiseGemm::template MakeBsGridDescriptor_N_K<BsLayout, GemmSpec>(
{}, {}, {}))>;
using DsGridDesc_M_N =
remove_cvref_t<decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>(
{}, {}, {}))>;
using EGridDesc_M_N =
decltype(GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(1, 1, 1));
// desc for blockwise copy
using AsGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(
AsGridDesc_M_K{}))>;
using BsGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(
BsGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument
struct Argument : public BaseArgument
{
Argument(std::array<const void*, NumATensor> p_as_grid,
std::array<const void*, NumBTensor> p_bs_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
std::array<index_t, NumATensor> StrideAs,
std::array<index_t, NumBTensor> StrideBs,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: p_as_grid_{},
p_bs_grid_{},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
as_grid_desc_m_k_{},
bs_grid_desc_n_k_{},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
MRaw, NRaw, StrideE)},
as_grid_desc_ak0_m_ak1_{},
bs_grid_desc_bk0_n_bk1_{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
MRaw_{MRaw},
NRaw_{NRaw},
KRaw_{KRaw}
{
// populate pointer, desc for As
static_for<0, NumATensor, 1>{}([&](auto i) {
using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
// A pointer
p_as_grid_(i) = static_cast<const ADataType*>(p_as_grid[i]);
// A desc
as_grid_desc_m_k_(i) =
GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(
MRaw, KRaw, StrideAs[i]);
});
// populate pointer, desc for Bs
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
// B pointer
p_bs_grid_(i) = static_cast<const BDataType*>(p_bs_grid[i]);
using Argument = typename GridwiseGemm::Argument;
// B desc
bs_grid_desc_n_k_(i) =
GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(
KRaw, NRaw, StrideBs[i]);
});
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
// D desc
ds_grid_desc_m_n_(i) =
GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(
MRaw, NRaw, StrideDs[i]);
});
// populate desc for Ds/E
if(GridwiseGemm::CheckValidity(as_grid_desc_m_k_,
bs_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
// Invoker
struct Invoker : public BaseInvoker
{
as_grid_desc_ak0_m_ak1_ =
GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_);
bs_grid_desc_bk0_n_bk1_ =
GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
void Print() const
if(!GridwiseGemm::CheckValidity(arg))
{
// std::cout << "A[M, K]: " << as_grid_desc_m_k_ << std::endl;
// std::cout << "B[N, K]: " << bs_grid_desc_n_k_ << std::endl;
// static_for<0, NumDTensor, 1>{}(
//[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
// std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
// private:
// pointers
typename GridwiseGemm::AsGridPointer p_as_grid_;
typename GridwiseGemm::BsGridPointer p_bs_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
// tensor descriptors for problem definiton
AsGridDesc_M_K as_grid_desc_m_k_;
BsGridDesc_N_K bs_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
float ave_time = 0;
// tensor descriptors for block/thread-wise copy
AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1_;
BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1_;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
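// Example of the split-K rounding above (hypothetical sizes): with KPerBlock = 32,
// arg.KBatch = 2 and arg.K = 100, k_grain = 64 and
// K_split = ((100 + 63) / 64) * 32 = 2 * 32 = 64, the per-split K extent used
// for the main-loop and tail-number decisions below.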
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
const auto Run = [&](const auto& kernel) {
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
// for checking vector load/store
index_t MRaw_;
index_t NRaw_;
index_t KRaw_;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
};
// Invoker
struct Invoker : public BaseInvoker
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
#if 0
if(arg.KBatch > 1)
{
using Argument = DeviceOp::Argument;
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
#endif
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
#if 0
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_,
arg.bs_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
const auto kernel = kernel_gemm_multiple_abd_xdl_cshuffle<
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
typename GridwiseGemm::AsGridPointer,
typename GridwiseGemm::BsGridPointer,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AsGridDesc_AK0_M_AK1,
DeviceOp::BsGridDesc_BK0_N_BK1,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::Block2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_as_grid_,
arg.p_bs_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.as_grid_desc_ak0_m_ak1_,
arg.bs_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_);
};
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
const auto K = arg.as_grid_desc_m_k_[I0].GetLength(I1);
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
return launch_kernel(integral_constant<bool, true>{});
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
else
#endif
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
return launch_kernel(integral_constant<bool, false>{});
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
};
static bool IsSupportedArgument(const Argument& arg)
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(!ck::is_xdl_supported())
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
return false;
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
// check vector load/store
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
bool all_valid = true;
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
static_for<0, NumATensor, 1>{}([&](auto i) {
using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
// check vector load of A
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
all_valid = false;
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
// FIXME: not rigorous
if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
all_valid = false;
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
}
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
#if 0
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
all_valid = false;
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
// check vector load of B
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
}
else
#endif
{
if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
all_valid = false;
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
}
else
{
// FIXME: not rigorous
if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
#if 0
if(arg.KBatch > 1)
{
all_valid = false;
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
#endif
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
all_valid = false;
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
#if 0
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
#endif
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
});
// check vector load of Ds
// only support RowMajor for now
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return ave_time;
}
if constexpr(!is_same_v<DLayout, Row>)
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
all_valid = false;
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
});
};
if(!all_valid)
static constexpr bool IsValidCompilationParameter()
{
return false;
// TODO: properly implement this check
return true;
}
// check vector store of E
// only support RowMajor for now
if constexpr(is_same_v<ELayout, Row>)
static bool IsSupportedArgument(const Argument& arg)
{
if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
if(!ck::is_xdl_supported())
{
return false;
}
}
else
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
}
return GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_,
arg.bs_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
......@@ -664,8 +621,27 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
CElementwiseOperation c_element_op)
{
static_for<0, NumATensor, 1>{}([&](auto i) {
using ALayout_ = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
static_assert(is_same<ALayout_, ALayout>::value, "");
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BLayout_ = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
static_assert(is_same<BLayout_, BLayout>::value, "");
});
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout_ = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
static_assert(is_same<DLayout_, CLayout>::value, "");
});
return Argument{p_as,
p_bs,
p_ds,
......@@ -677,16 +653,16 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
StrideBs,
StrideDs,
StrideE,
1,
a_element_op,
b_element_op,
cde_element_op};
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
std::array<const void*, NumBTensor> p_bs,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
......@@ -699,7 +675,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(p_as,
p_bs,
......@@ -712,9 +688,10 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
StrideBs,
StrideDs,
StrideE,
1,
a_element_op,
b_element_op,
cde_element_op);
c_element_op);
}
// polymorphic
......@@ -728,35 +705,41 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemmMultipleABD_Xdl_CShuffle"
str << "DeviceGemmXdlUniversal"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerXDL<<"x"<<NPerXDL << ", "
<< "WaveMap: "
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
// clang-format on
return str.str();
......
......@@ -553,7 +553,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported() || ck::is_navi4_supported())
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
......
......@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported() || ck::is_navi4_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
......
......@@ -510,7 +510,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
......@@ -528,7 +528,6 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0)
<< "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
......
......@@ -450,7 +450,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported() || ck::is_navi4_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>))
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
using Argument = typename GridwiseGemm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
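// K_split is the K extent one K-batch slice handles, rounded up to a multiple of
// KPerBlock; has_main_k_block_loop and the tail-number dispatch below are derived
// from it. Illustrative numbers (not from this commit): K = 1000, KBatch = 4,
// KPerBlock = 64 gives k_grain = 256 and K_split = ceil(1000 / 256) * 64 = 256,
// i.e. four K-blocks per split.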
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
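// Benchmark-only path: rotate through shadow copies of A/B and flush the
// instruction cache between timed launches so the timing is not skewed by warm
// caches (intent as understood from the rotating-mem and flush helpers in
// ck/host_utility/flush_cache.hpp).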
Argument arg_ = arg;
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
arg_,
stream_config.rotating_count,
arg_.M * arg_.K * sizeof(ADataType),
arg_.K * arg_.N * sizeof(BDataType));
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(arg_.KBatch > 1)
hipGetErrorString(
hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
}
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg_);
}
else
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
}
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(arg.KBatch > 1)
{
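// Split-K path: every K-batch accumulates its partial product into C with atomic
// adds, which is why the Run helper above zero-fills C before launch when
// KBatch > 1; bf16 output is excluded here, presumably because no suitable
// atomic add is available for it (assumption, not stated in this file).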
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
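// The v2 pipeline prefetches several K-blocks ahead, so the number of K-blocks
// left over when the main loop exits (TailNumber::One .. Seven, or Full) has to
// be a compile-time choice; the checks below pick the kernel instantiation whose
// epilogue matches that count, and the PrefetchStages guards keep impossible
// tail lengths from being instantiated.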
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
}
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
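// Example: K = 100 with AK1 = 8 leaves a remainder of 4, so such a shape is only
// accepted when GemmSpec pads K (KPadding, MKPadding, NKPadding or MNKPadding).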
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t KBatch,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
{
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t KBatch,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC,
KBatch);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemmXdlUniversal"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerXDL<<"x"<<NPerXDL << ", "
<< "WaveMap: "
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
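For orientation, a minimal host-side sketch of how the DeviceGemm_Xdl_CShuffleV3 op above could be instantiated and launched. Every concrete tile/vector value below is an illustrative assumption modeled on typical CK example configurations rather than something this commit prescribes, and p_a/p_b/p_c, M/N/K and the strides are presumed to exist already.
// Hypothetical instantiation; all numeric tuning parameters are assumptions.
using Row         = ck::tensor_layout::gemm::RowMajor;
using Col         = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
    Row, Col, Row,                                    // ALayout, BLayout, CLayout
    ck::half_t, ck::half_t, ck::half_t,               // ADataType, BDataType, CDataType
    float, ck::half_t,                                // GemmAccDataType, CShuffleDataType
    PassThrough, PassThrough, PassThrough,            // elementwise operations
    ck::tensor_operation::device::GemmSpecialization::Default,
    256,                                              // BlockSize
    128, 128, 64,                                     // MPerBlock, NPerBlock, KPerBlock
    8, 8,                                             // AK1, BK1
    32, 32,                                           // MPerXDL, NPerXDL
    2, 2,                                             // MXdlPerWave, NXdlPerWave
    S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,  // A block transfer
    S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,  // B block transfer
    1, 1,                                             // CShuffle M/N XdlPerWavePerShuffle
    S<1, 32, 1, 8>, 8,                                // CShuffle cluster lengths, vector width
    ck::BlockGemmPipelineScheduler::Intrawave,
    ck::BlockGemmPipelineVersion::v3>;
// p_a, p_b, p_c are assumed device pointers; M, N, K and the strides are assumed sizes.
auto gemm     = DeviceGemmInstance{};
auto invoker  = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC,
                                  /*KBatch=*/1, PassThrough{}, PassThrough{}, PassThrough{});
if(DeviceGemmInstance::IsSupportedArgument(argument))
{
    float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
}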
......@@ -514,7 +514,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
......@@ -529,7 +529,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
......
......@@ -299,7 +299,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
......@@ -312,7 +312,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
......
......@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
// check device
if(ck::is_navi3_supported() || ck::is_navi4_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -138,34 +138,6 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
WeiElementwiseOperation,
OutElementwiseOperation>
{
// 1d
static constexpr bool is_NWGK_GKXC_NWGC =
is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NWGK>;
static constexpr bool is_GNWK_GKXC_GNWC =
is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
// 2d
static constexpr bool is_NHWGK_GKYXC_NHWGC =
is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NHWGK>;
static constexpr bool is_GNHWK_GKYXC_GNHWC =
is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
// 3d
static constexpr bool is_NDHWGK_GKZYXC_NDHWGC =
is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NDHWGK>;
static constexpr bool is_GNDHWK_GKZYXC_GNDHWC =
is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
using DeviceOp = DeviceGroupedConvBwdWeight_Dl;
using ADataType = OutDataType;
......@@ -1066,9 +1038,15 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
if(arg.k_batch_ != 1)
return false;
if constexpr(!((NDimSpatial == 1 && (is_NWGK_GKXC_NWGC || is_GNWK_GKXC_GNWC)) ||
(NDimSpatial == 2 && (is_NHWGK_GKYXC_NHWGC || is_GNHWK_GKYXC_GNHWC)) ||
(NDimSpatial == 3 && (is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC))))
if constexpr(!((NDimSpatial == 1 &&
(is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())) ||
(NDimSpatial == 2 &&
(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>())) ||
(NDimSpatial == 3 &&
(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))))
{
return false;
}
......
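The layout predicates removed above come back in this check as shared function templates; a sketch of what one of them presumably looks like (the real definitions are assumed to live in a common convolution layout utility elsewhere in the library):
// Sketch only: mirrors the removed member constant as a reusable function template.
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NWGK_GKXC_NWGC()
{
    return ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWGC> &&
           ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKXC> &&
           ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NWGK>;
}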