"vscode:/vscode.git/clone" did not exist on "f5073f49272badee72350e9288dc3c99780c929d"
Commit 5f4c1ddb authored by Bartlomiej Wroblewski

Clean the code and comments

parent 0661e8d2
@@ -44,7 +44,7 @@ if(USE_BITINT_EXTENSION_INT4)
     add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4)
 endif(USE_BITINT_EXTENSION_INT4)
-# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
+# FIXME: re-enable this example as test when SWDEV-335738 is fixed
 add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
@@ -58,9 +58,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
 if(GPU_TARGETS MATCHES "gfx90a")
     add_example_executable(example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp)
-    if(result EQUAL 0)
-        add_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32)
-    endif()
+    add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32)
 endif()
 add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
......
@@ -6,26 +6,12 @@
 #include "common.hpp"

 #define USING_DIRECT_LOADS 1
-#include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/utility/data_type.hpp"
 #if USING_DIRECT_LOADS
 #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp"
 #else
 #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
 #endif
-#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
-#include "ck/library/utility/check_err.hpp"
-#include "ck/library/utility/device_memory.hpp"
-#include "ck/library/utility/host_tensor.hpp"
-#include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/utility/literals.hpp"

 using F32 = float;
 using ADataType = F32;
......
@@ -67,17 +67,13 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                       "The number of threads cannot be less than the number of elements in "
                       "thread cluster lengths.");
-        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
-           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
-        {
-            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-                make_multi_index(ThreadGroup::GetThreadId()));
-            const auto thread_data_idx_begin = thread_cluster_idx;
-            SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
-            SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
-        }
+        const auto thread_cluster_idx =
+            thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
+        const auto thread_data_idx_begin = thread_cluster_idx;
+        SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
+        SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
     }

     __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -103,11 +99,6 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                           const DstDesc& dst_desc,
                           DstBuffer& dst_buf)
     {
-        if(ThreadGroup::GetNumOfThread() != thread_cluster_desc_.GetElementSize() &&
-           ThreadGroup::GetThreadId() >= thread_cluster_desc_.GetElementSize())
-        {
-            return;
-        }
         static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global,
                       "Source data must come from a global memory buffer.");
         static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
@@ -120,21 +111,19 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
             is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
             "DstBuffer and DstData data types must be consistent.");

         constexpr auto dst_access_lengths = thread_slice_lengths;
-        constexpr auto dst_dim_access_order = Sequence<0, 1, 2>{};
-        constexpr auto ordered_dst_access_lengths =
-            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);

         const auto dst_forward_steps  = generate_steps(dst_desc, 1);
         const auto dst_backward_steps = generate_steps(dst_desc, -1);
         const auto src_forward_steps  = generate_steps(src_desc, 1);
         const auto src_backward_steps = generate_steps(src_desc, -1);

-        // loop over tensor and copy
-        static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
+        // Loop over the destination block and copy data.
+        static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
             const auto src_offset = src_coord_.GetOffset();
             const auto dst_offset = dst_coord_.GetOffset();

+            // Check if src data is not in the logic padding area.
             const bool is_src_valid =
                 coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
@@ -145,11 +134,10 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                 StaticallyIndexedArray<bool, nDim> move_on_dim_;

                 static_for<0, nDim, 1>{}([&](auto i) {
-                    move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
+                    move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1;

                     static_for<i + 1, nDim, 1>{}([&](auto j) {
-                        move_on_dim_(i) &=
-                            ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
+                        move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1;
                     });
                 });
@@ -157,7 +145,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
             }
             ();

-            // judge move forward or move backward
+            // Decide whether to move forward or backward.
             constexpr auto forward_sweep = [&]() {
                 StaticallyIndexedArray<bool, nDim> forward_sweep_;
@@ -167,7 +155,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                     index_t tmp = ordered_dst_access_idx[I0];

                     static_for<1, i, 1>{}([&](auto j) {
-                        tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
+                        tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j];
                     });

                     forward_sweep_(i) = tmp % 2 == 0;
@@ -181,33 +169,26 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                 {
                     if constexpr(forward_sweep[i])
                     {
-                        move_tensor_coordinate(
-                            dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
-                        move_tensor_coordinate(
-                            src_desc, src_coord_, src_forward_steps[dst_dim_access_order[i]]);
+                        move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]);
+                        move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]);
                     }
                     else
                     {
-                        move_tensor_coordinate(
-                            dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
-                        move_tensor_coordinate(
-                            src_desc, src_coord_, src_backward_steps[dst_dim_access_order[i]]);
+                        move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]);
+                        move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]);
                     }
                 }
             });
         });

+        // Reset the destination slice since the entire buffer has been already filled.
         ResetDstSliceWindow(dst_desc);
     }

     __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
     {
-        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
-           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
-        {
-            src_slice_origin_ = src_slice_origin_ + step;
-            src_coord_        = make_tensor_coordinate(src_desc, src_slice_origin_);
-        }
+        src_slice_origin_ = src_slice_origin_ + step;
+        src_coord_        = make_tensor_coordinate(src_desc, src_slice_origin_);
     }

     template <typename DescType>
......
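Note on the hunks above: the copy loop walks the destination slice in a snake (zig-zag) order, flipping a dimension's sweep direction according to the parity of the flattened index of the dimensions outside it; once the access order is fixed to the identity Sequence<0, 1, 2>, the dst_dim_access_order indirection is dead code, which is what this commit removes. The following standalone sketch (plain C++17, no CK types, made-up lengths) illustrates the same parity rule; it is an editorial illustration, not code from this repository.

#include <array>
#include <cstdio>

int main()
{
    // Small 3-D index space, visited lexicographically; odd "outer" flattened
    // indices reflect the inner dimension, producing a snake traversal in which
    // consecutive positions differ by +/-1 in exactly one coordinate.
    constexpr std::array<int, 3> len{2, 3, 4};

    for(int a0 = 0; a0 < len[0]; ++a0)
        for(int a1 = 0; a1 < len[1]; ++a1)
            for(int a2 = 0; a2 < len[2]; ++a2)
            {
                // forward_sweep for dim 1 depends on a0, for dim 2 on (a0, a1).
                const bool fwd1 = (a0 % 2 == 0);
                const bool fwd2 = ((a0 * len[1] + a1) % 2 == 0);

                const int p0 = a0;
                const int p1 = fwd1 ? a1 : len[1] - 1 - a1;
                const int p2 = fwd2 ? a2 : len[2] - 1 - a2;

                std::printf("(%d, %d, %d)\n", p0, p1, p2);
            }
    return 0;
}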
@@ -191,7 +191,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
             }
         }

-        // polymorphic
         float Run(const BaseArgument* p_arg,
                   const StreamConfig& stream_config = StreamConfig{}) override
         {
@@ -206,12 +205,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
             return false;
         }

-        // check vector load/store
+        // Check vector load/store.
         {
             using Row = ck::tensor_layout::gemm::RowMajor;
             using Col = ck::tensor_layout::gemm::ColumnMajor;

-            // check vector load of A
+            // Check vector load of A.
             if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
             {
                 if(arg.KRaw_ % ABlockTransferScalarPerVector != 0)
@@ -221,7 +220,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
             }
             else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
             {
-                // FIXME: not rigorous
                 if(arg.MRaw_ % ABlockTransferScalarPerVector != 0)
                 {
                     return false;
@@ -232,7 +230,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                 return false;
             }

-            // check vector load of B
+            // Check vector load of B.
             if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
             {
                 if(arg.KRaw_ % BBlockTransferScalarPerVector != 0)
@@ -242,7 +240,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
             }
             else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
             {
-                // FIXME: not rigorous
                 if(arg.NRaw_ % BBlockTransferScalarPerVector != 0)
                 {
                     return false;
@@ -253,8 +250,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                 return false;
             }

-            // check vector load of Ds
-            // only support RowMajor for now
+            // Check vector load of Ds.
+            // For now, only the RowMajor layout is supported.
             bool all_valid = true;

             static_for<0, NumDTensor, 1>{}([&](auto i) {
@@ -271,8 +268,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                 return false;
             }

-            // check vector store of E
-            // only support RowMajor for now
+            // Check vector load of E.
+            // For now, only the RowMajor layout is supported.
             if constexpr(is_same_v<ELayout, Row>)
             {
                 if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
@@ -293,7 +290,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                               arg.block_2_etile_map_);
     }

-    // polymorphic
     bool IsSupportedArgument(const BaseArgument* p_arg) override
     {
         return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
@@ -332,7 +328,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
     static auto MakeInvoker() { return Invoker{}; }

-    // polymorphic
     std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
@@ -365,13 +360,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                                           cde_element_op);
     }

-    // polymorphic
     std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
     {
         return std::make_unique<Invoker>(Invoker{});
     }

-    // polymorphic
     std::string GetTypeString() const override
     {
         auto str = std::stringstream();
......
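Aside: the IsSupportedArgument hunks above all apply the same rule: a vectorized global access along the contiguous dimension is only legal when that dimension's raw extent is a multiple of the vector width. A minimal host-side sketch of that rule, using made-up names (can_vectorize, ScalarPerVector) and extents rather than the CK API:

#include <cstdint>

// Vectorized access along the contiguous dimension requires that dimension's
// raw extent to divide evenly by the vector width.
template <int ScalarPerVector>
bool can_vectorize(std::int64_t m_raw, std::int64_t k_raw, bool is_row_major)
{
    // Row-major A is contiguous along K, column-major A along M.
    const std::int64_t contiguous_extent = is_row_major ? k_raw : m_raw;
    return contiguous_extent % ScalarPerVector == 0;
}

int main()
{
    // e.g. 4-wide loads: K = 1024 passes the check, K = 1022 does not.
    const bool ok  = can_vectorize<4>(/*m_raw=*/512, /*k_raw=*/1024, /*is_row_major=*/true);
    const bool bad = can_vectorize<4>(/*m_raw=*/512, /*k_raw=*/1022, /*is_row_major=*/true);
    return (ok && !bad) ? 0 : 1;
}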
@@ -118,7 +118,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
     using Argument = typename GridwiseGemm::Argument;

-    // Invoker
     struct Invoker : public BaseInvoker
     {
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
@@ -186,7 +185,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
             }
         }

-        // polymorphic
         float Run(const BaseArgument* p_arg,
                   const StreamConfig& stream_config = StreamConfig{}) override
         {
@@ -201,12 +199,12 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
             return false;
         }

-        // check vector load/store
+        // Check vector load/store.
         {
             using Row = ck::tensor_layout::gemm::RowMajor;
             using Col = ck::tensor_layout::gemm::ColumnMajor;

-            // check vector load of A
+            // Check vector load of A.
             if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
             {
                 if(arg.KRaw_ % ABlockTransferScalarPerVector != 0)
@@ -216,7 +214,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
             }
             else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
             {
-                // FIXME: not rigorous
                 if(arg.MRaw_ % ABlockTransferScalarPerVector != 0)
                 {
                     return false;
@@ -227,7 +224,7 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
                 return false;
             }

-            // check vector load of B
+            // Check vector load of B.
             if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
             {
                 if(arg.KRaw_ % BBlockTransferScalarPerVector != 0)
@@ -237,7 +234,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
             }
             else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
             {
-                // FIXME: not rigorous
                 if(arg.NRaw_ % BBlockTransferScalarPerVector != 0)
                 {
                     return false;
@@ -248,8 +244,8 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
                 return false;
             }

-            // check vector store of E
-            // only support RowMajor for now
+            // Check vector load of E.
+            // For now, only the RowMajor layout is supported.
             if constexpr(is_same_v<ELayout, Row>)
             {
                 if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
@@ -270,7 +266,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
                               arg.block_2_etile_map_);
     }

-    // polymorphic
     bool IsSupportedArgument(const BaseArgument* p_arg) override
     {
         return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
@@ -310,7 +305,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
     static auto MakeInvoker() { return Invoker{}; }

-    // polymorphic
     std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
@@ -344,13 +338,11 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
                                           cde_element_op);
     }

-    // polymorphic
     std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
     {
         return std::make_unique<Invoker>(Invoker{});
     }

-    // polymorphic
     std::string GetTypeString() const override
     {
         auto str = std::stringstream();
......
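Aside: the deleted "// polymorphic" markers above annotated overrides of the type-erased invoker/argument interface, which the override keyword already communicates. A minimal sketch of that interface pattern, using illustrative names (MyGemm, BaseArgument, BaseInvoker) rather than the actual CK classes:

#include <memory>

// A type-erased base interface lets callers drive any concrete operation
// through base pointers; the concrete operation overrides Run().
struct BaseArgument
{
    virtual ~BaseArgument() = default;
};

struct BaseInvoker
{
    virtual ~BaseInvoker()                       = default;
    virtual float Run(const BaseArgument* p_arg) = 0;
};

struct MyGemm
{
    struct Argument : BaseArgument
    {
        int m = 0, n = 0, k = 0;
    };

    struct Invoker : BaseInvoker
    {
        // `override` already documents the polymorphic entry point.
        float Run(const BaseArgument* p_arg) override
        {
            const auto& arg = *dynamic_cast<const Argument*>(p_arg);
            return static_cast<float>(arg.m) * 0.0f; // placeholder "elapsed time"
        }
    };

    static std::unique_ptr<BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>();
    }
};

int main()
{
    MyGemm::Argument arg;
    auto invoker = MyGemm::MakeInvokerPointer();
    return invoker->Run(&arg) == 0.0f ? 0 : 1;
}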
@@ -55,8 +55,8 @@ __global__ void
                 e_grid_desc_mblock_mperblock_nblock_nperblock,
             const Block2ETileMap block_2_etile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx940__) || \
+    defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -173,7 +173,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
-        // A matrix in LDS memory, dst of blockwise copy
+        // A matrix in LDS memory, destination of blockwise copy.
         return make_naive_tensor_descriptor(
             make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
             make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
@@ -181,7 +181,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
     __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
     {
-        // B matrix in LDS memory, dst of blockwise copy
+        // B matrix in LDS memory, destination of blockwise copy.
         return make_naive_tensor_descriptor(
             make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
             make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
@@ -217,11 +217,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
-        // LDS allocation for A and B: be careful of alignment
+        // LDS allocation for A and B: be careful of alignment.
         constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
         constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

-        // lds max alignment
         constexpr auto max_lds_align = math::lcm(AK1, BK1);

         constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
@@ -230,7 +229,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
         constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
             b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

-        // LDS allocation for C shuffle in LDS
+        // LDS allocation for C shuffle.
         constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
             GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
@@ -316,11 +315,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                                   const std::array<index_t, NumDTensor>& DsStride)
     {
         return generate_tuple(
-            [&](auto i) {
-                // using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
-                return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]);
-            },
+            [&](auto i) { return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); },
             Number<NumDTensor>{});
     }
@@ -329,7 +324,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
     using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N(1, 1, 1));

-    // A desc for source in blockwise copy
+    // A desc for source in blockwise copy.
     __host__ __device__ static constexpr auto
     MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
     {
@@ -345,7 +340,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
             make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }

-    // B desc for source in blockwise copy
+    // B desc for source in blockwise copy.
     __host__ __device__ static constexpr auto
     MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
     {
@@ -361,7 +356,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
             make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }

-    // E desc for destination in blockwise copy
+    // E desc for destination in blockwise copy.
     __host__ __device__ static constexpr auto
     MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
     {
@@ -381,7 +376,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
         return e_grid_desc_mblock_mperblock_nblock_nperblock;
     }

-    // Ds desc for source in blockwise copy
+    // Ds desc for source in blockwise copy.
     __host__ __device__ static constexpr auto
     MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
     {
@@ -392,7 +387,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
             Number<NumDTensor>{});
     }

-    // return block_id to E matrix tile idx (m0, n0) mapping
     __host__ __device__ static constexpr auto
     MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
     {
@@ -411,10 +405,8 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
         remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
             EGridDesc_M_N{}))>;

-    // block-to-e-tile map
     using Block2ETileMap = remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;

-    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
     __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
                                                             const BGridDesc_N_K& b_grid_desc_n_k,
                                                             const DsGridDesc_M_N& ds_grid_desc_m_n,
@@ -439,7 +431,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
         const auto AK = a_grid_desc_m_k.GetLength(I1);
         const auto BK = b_grid_desc_n_k.GetLength(I1);

-        // check consistency of desc
+        // Check the consistency of descriptors.
         if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
         {
             return false;
@@ -457,28 +449,26 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
             return false;
         }

-        // check tile size
+        // Check the tile size.
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
         {
             return false;
         }

-        // check gridwise gemm pipeline
+        // Check gridwise gemm pipeline.
         const auto num_k_loop = AK / KPerBlock;
         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
         }

-        // check block-to-E-tile
+        // Check block-to-E-tile.
         if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
         {
             return false;
         }

-        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
-        // check tensor size: cannot be larger than 2GB each
+        // Check tensor size: cannot exceed 2GB.
         constexpr long_index_t TwoGB = (long_index_t{1} << 31);

         if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
@@ -522,7 +512,8 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
             e_grid_desc_mblock_mperblock_nblock_nperblock,
         const Block2ETileMap& block_2_etile_map)
     {
-        // elementwise operations are not supported for A and B, left only for the API consistency
+        // Elementwise operations are not supported for A and B, arguments left only for the API
+        // consistency.
         (void)a_element_op;
         (void)b_element_op;
@@ -543,7 +534,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
         auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

-        // divide block work by [M, N]
+        // Divide block work by [M, N].
         const auto block_work_idx =
             block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
@@ -555,7 +546,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
             return;
         }

-        // HACK: this forces m/n_block_data_idx_on_grid into SGPR
+        // This forces m/n_block_data_idx_on_grid into SGPR.
         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
@@ -564,13 +555,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
         constexpr auto max_lds_align = math::lcm(AK1, BK1);

-        // A matrix in LDS memory, dst of blockwise copy
+        // A matrix in LDS memory, destination of blockwise copy.
         constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

-        // B matrix in LDS memory, dst of blockwise copy
+        // B matrix in LDS memory, destination of blockwise copy.
         constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

-        // A matrix blockwise copy
         auto a_blockwise_copy =
             ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
                                                       Sequence<AK0PerBlock, MPerBlock, AK1>,
@@ -588,7 +578,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                 a_block_desc_ak0_m_ak1,
                 make_multi_index(0, 0, 0));

-        // B matrix blockwise copy
         auto b_blockwise_copy =
             ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
                                                       Sequence<BK0PerBlock, NPerBlock, BK1>,
@@ -612,7 +601,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
         // b_mtx[K0PerBlock, NPerBlock] is in LDS
         // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         // register
-        // sanity check
         constexpr index_t KPack = math::max(
             math::lcm(AK1, BK1),
             MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::selected_mfma
@@ -634,7 +622,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

-        // LDS allocation for A and B: be careful of alignment
+        // LDS allocation for A and B: be careful of alignment.
         constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
@@ -648,7 +636,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);

-        // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
             GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
@@ -672,7 +659,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                                                                c_thread_buf,
                                                                num_k_block_main_loop);

-        // shuffle C and write out
+        // Shuffle C and write out.
         {
             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
@@ -723,8 +710,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                     make_tuple(
                         Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

-            // calculate origin of thread output tensor on global memory
-            // blockwise GEMM c matrix starting index
+            // Calculate the origin of thread output tensor on global memory.
             const auto c_thread_mtx_on_block =
                 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
@@ -751,7 +737,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                     make_multi_index(n_thread_data_on_block));

-            // shuffle: threadwise copy C from VGPR to LDS
+            // Shuffle: threadwise copy C from VGPR to LDS.
             auto c_thread_copy_vgpr_to_lds =
                 ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                    CShuffleDataType,
@@ -783,7 +769,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                                                        n_thread_data_on_block_idx[I2]),
                                                    ck::tensor_operation::element_wise::PassThrough{}};

-            // tuple of reference to C/Ds tensor descriptors
+            // A tuple of reference to C/Ds tensor descriptors.
             const auto c_ds_desc_refs = concat_tuple_of_reference(
                 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                 generate_tie(
@@ -791,7 +777,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                     { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                     Number<NumDTensor>{}));

-            // tuple of reference to C/Ds tensor descriptors
+            // A tuple of reference to C/Ds grid buffers.
             const auto c_ds_buf_refs = concat_tuple_of_reference(
                 tie(c_shuffle_block_buf),
                 generate_tie(
@@ -799,7 +785,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                     { return ds_grid_buf[i]; },
                     Number<NumDTensor>{}));

-            // tuple of starting index of C/Ds blockwise copy
+            // A tuple of starting index of C/Ds blockwise copy.
             const auto idx_c_ds_block_begin = container_concat(
                 make_tuple(make_multi_index(0, 0, 0, 0)),
                 generate_tuple(
@@ -808,7 +794,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                     },
                     Number<NumDTensor>{}));

-            // blockwise copy C/D/E between LDS and global
+            // Blockwise copy C/D/E between LDS and global.
             auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
                 ThisThreadBlock,
                 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
@@ -816,8 +802,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                 decltype(c_ds_desc_refs),
                 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
                 CDEElementwiseOperation,
-                Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
-                                                                            // support arbitray type
+                Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>,
                 Sequence<1,
                          CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                          1,
@@ -838,7 +823,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                 make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
                 cde_element_op};

-            // space filling curve for threadwise C in VGPR before shuffle
+            // Space filling curve for threadwise C in VGPR before shuffle.
             constexpr auto sfc_c_vgpr =
                 SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
@@ -851,7 +836,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                                             M4,
                                             1>>{};

-            // space filling curve for shuffled blockwise C/D/E
+            // Space filling curve for shuffled blockwise C/D/E.
             constexpr auto sfc_cde_block =
                 SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                   Sequence<0, 2, 1, 3>,
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) { static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS // Make sure it's safe to write to LDS.
block_sync_lds(); block_sync_lds();
// each thread write its data from VGPR to LDS // Each thread write its data from VGPR to LDS.
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf, c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf); c_shuffle_block_buf);
// make sure it's safe to read from LDS // Make sure it's safe to read from LDS.
block_sync_lds(); block_sync_lds();
// each block copy its data from LDS to global // Each block copy its data from LDS to global.
cde_block_copy_lds_and_global.Run( cde_block_copy_lds_and_global.Run(
c_ds_desc_refs, c_ds_desc_refs,
c_ds_buf_refs, c_ds_buf_refs,
...@@ -890,13 +875,13 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad ...@@ -890,13 +875,13 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
constexpr auto cde_lds_and_global_step = constexpr auto cde_lds_and_global_step =
sfc_cde_block.GetForwardStep(access_id); sfc_cde_block.GetForwardStep(access_id);
// move on Ds // Move on Ds.
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_block_copy_lds_and_global.MoveSrcSliceWindow( cde_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_lds_and_global_step); c_ds_desc_refs, i + I1, cde_lds_and_global_step);
}); });
// move on E // Move on E.
cde_block_copy_lds_and_global.MoveDstSliceWindow( cde_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
I0, I0,
...@@ -942,19 +927,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad ...@@ -942,19 +927,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
NRaw_{NRaw}, NRaw_{NRaw},
KRaw_{KRaw} KRaw_{KRaw}
{ {
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
// using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>; p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
// D desc
ds_grid_desc_m_n_(i) = MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]); ds_grid_desc_m_n_(i) = MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
}); });
// populate desc for Ds/E
if(CheckValidity(a_grid_desc_m_k_, if(CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_, b_grid_desc_n_k_,
ds_grid_desc_m_n_, ds_grid_desc_m_n_,
...@@ -978,19 +956,19 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad ...@@ -978,19 +956,19 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
} }
// pointers // Pointers
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
DsGridPointer p_ds_grid_; DsGridPointer p_ds_grid_;
EDataType* p_e_grid_; EDataType* p_e_grid_;
// tensor descriptors for problem definiton // Tensor descriptors for problem definiton
AGridDesc_M_K a_grid_desc_m_k_; AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_; BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_; DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy // Tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -1000,12 +978,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad ...@@ -1000,12 +978,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
// block-to-e-tile map // block-to-e-tile map
Block2ETileMap block_2_etile_map_; Block2ETileMap block_2_etile_map_;
// element-wise op // element-wise ops
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
// for checking vector load/store // For checking vector load/store
index_t MRaw_; index_t MRaw_;
index_t NRaw_; index_t NRaw_;
index_t KRaw_; index_t KRaw_;
......
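Aside: GetSharedMemoryNumberOfByte() above sizes the LDS from the A and B tile descriptors, padded to a common alignment (the lcm of the two K1 vector widths), with the C-shuffle staging tile presumably reusing the same allocation after the main loop. A rough standalone sketch of that arithmetic, with made-up tile sizes and an assumed fp16 element size:

#include <cstdio>
#include <numeric>

int main()
{
    constexpr int ak1 = 8, bk1 = 4;           // vector widths along K1 (illustrative)
    constexpr int a_elems = 128 * 32;         // MPerBlock * KPerBlock tile, in elements
    constexpr int b_elems = 256 * 32;         // NPerBlock * KPerBlock tile, in elements
    constexpr int c_shuffle_elems = 128 * 64; // C-shuffle staging tile, in elements

    constexpr int max_lds_align = std::lcm(ak1, bk1);

    // Round each tile up to the alignment, mirroring integer_least_multiple().
    constexpr int a_aligned = (a_elems + max_lds_align - 1) / max_lds_align * max_lds_align;
    constexpr int b_aligned = (b_elems + max_lds_align - 1) / max_lds_align * max_lds_align;

    constexpr int ab_bytes = (a_aligned + b_aligned) * 2; // assuming fp16 = 2 bytes
    constexpr int c_bytes  = c_shuffle_elems * 2;

    // Assumption: the C-shuffle buffer reuses the A/B LDS space after the main loop,
    // so the block allocation is the larger of the two.
    const int lds_bytes = ab_bytes > c_bytes ? ab_bytes : c_bytes;
    std::printf("LDS bytes per block: %d\n", lds_bytes);
    return 0;
}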