Commit 5f4c1ddb authored by Bartlomiej Wroblewski

Clean the code and comments

parent 0661e8d2
......@@ -44,7 +44,7 @@ if(USE_BITINT_EXTENSION_INT4)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4)
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
# FIXME: re-enable this example as test when SWDEV-335738 is fixed
add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
......@@ -58,9 +58,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
if(GPU_TARGETS MATCHES "gfx90a")
add_example_executable(example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32)
endif()
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32)
endif()
add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
......
......@@ -6,26 +6,12 @@
#include "common.hpp"
#define USING_DIRECT_LOADS 1
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#if USING_DIRECT_LOADS
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp"
#else
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#endif
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
using F32 = float;
using ADataType = F32;
......
......@@ -67,18 +67,14 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
"The number of threads cannot be less than the number of elements in "
"thread cluster lengths.");
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_cluster_idx =
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx;
SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
}
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
......@@ -103,11 +99,6 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
if(ThreadGroup::GetNumOfThread() != thread_cluster_desc_.GetElementSize() &&
ThreadGroup::GetThreadId() >= thread_cluster_desc_.GetElementSize())
{
return;
}
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global,
"Source data must come from a global memory buffer.");
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
......@@ -121,20 +112,18 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
"DstBuffer and DstData data types must be consistent.");
constexpr auto dst_access_lengths = thread_slice_lengths;
constexpr auto dst_dim_access_order = Sequence<0, 1, 2>{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
const auto dst_forward_steps = generate_steps(dst_desc, 1);
const auto dst_backward_steps = generate_steps(dst_desc, -1);
const auto src_forward_steps = generate_steps(src_desc, 1);
const auto src_backward_steps = generate_steps(src_desc, -1);
// loop over tensor and copy
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// Loop over the destination block and copy data.
static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
const auto src_offset = src_coord_.GetOffset();
const auto dst_offset = dst_coord_.GetOffset();
// Check if the src data is not in the logical padding area.
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
......@@ -145,11 +134,10 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1;
});
});
......@@ -157,7 +145,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
}
();
// judge move forward or move backward
// Decide whether to move forward or backward.
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
......@@ -167,7 +155,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
index_t tmp = ordered_dst_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
......@@ -181,34 +169,27 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
move_tensor_coordinate(
src_desc, src_coord_, src_forward_steps[dst_dim_access_order[i]]);
move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]);
move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]);
}
else
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
move_tensor_coordinate(
src_desc, src_coord_, src_backward_steps[dst_dim_access_order[i]]);
move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]);
move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]);
}
}
});
});
// Reset the destination slice since the entire buffer has already been filled.
ResetDstSliceWindow(dst_desc);
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
src_slice_origin_ = src_slice_origin_ + step;
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_);
}
}
template <typename DescType>
__device__ auto generate_steps(const DescType& desc, int sign)
......
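Note on the traversal order used in the copy loop above: the forward_sweep logic encodes a simple parity rule. For each dimension i, the access indices of the dimensions before it are flattened into a single number, and dimension i is swept forward when that number is even and backward when it is odd, which yields the zig-zag ("snake") traversal that lets the source and destination coordinates advance by single forward/backward steps. A standalone host-side sketch of that rule, not part of this commit; the lengths and index are hypothetical, and the dimension-0 handling is assumed from context since it sits outside the hunks shown:

#include <array>
#include <cstdio>

int main()
{
    constexpr int nDim = 3;
    const std::array<int, nDim> lengths = {2, 2, 4}; // hypothetical dst access lengths
    const std::array<int, nDim> idx     = {1, 0, 2}; // one access index inside those lengths

    // Mirrors the forward_sweep computation above: dimension 0 is assumed to always sweep
    // forward, every other dimension flips direction based on the parity of the flattened
    // outer index.
    std::array<bool, nDim> forward_sweep{};
    forward_sweep[0] = true;
    for(int i = 1; i < nDim; ++i)
    {
        int tmp = idx[0];
        for(int j = 1; j < i; ++j)
        {
            tmp = tmp * lengths[j] + idx[j];
        }
        forward_sweep[i] = (tmp % 2 == 0);
    }

    for(int i = 0; i < nDim; ++i)
    {
        std::printf("dim %d: sweep %s\n", i, forward_sweep[i] ? "forward" : "backward");
    }
    return 0;
}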
......@@ -191,7 +191,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
......@@ -206,12 +205,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
return false;
}
// check vector load/store
// Check vector load/store.
{
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// check vector load of A
// Check vector load of A.
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % ABlockTransferScalarPerVector != 0)
......@@ -221,7 +220,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
}
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
{
// FIXME: not rigorous
if(arg.MRaw_ % ABlockTransferScalarPerVector != 0)
{
return false;
......@@ -232,7 +230,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
return false;
}
// check vector load of B
// Check vector load of B.
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % BBlockTransferScalarPerVector != 0)
......@@ -242,7 +240,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
}
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
{
// FIXME: not rigorous
if(arg.NRaw_ % BBlockTransferScalarPerVector != 0)
{
return false;
......@@ -253,8 +250,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
return false;
}
// check vector load of Ds
// only support RowMajor for now
// Check vector load of Ds.
// For now, only the RowMajor layout is supported.
bool all_valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
......@@ -271,8 +268,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
return false;
}
// check vector store of E
// only support RowMajor for now
// Check vector load of E.
// For now, only the RowMajor layout is supported.
if constexpr(is_same_v<ELayout, Row>)
{
if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
......@@ -293,7 +290,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
arg.block_2_etile_map_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
......@@ -332,7 +328,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
......@@ -365,13 +360,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
cde_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
......
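The vector load/store checks in IsSupportedArgument() above all reduce to one divisibility rule: the raw length of the dimension that is accessed contiguously must be a multiple of the number of scalars moved per vector access. A minimal host-side sketch with hypothetical sizes; the helper name is not part of the library:

#include <cstdio>

// Hypothetical helper mirroring the checks above: vectorized access along a dimension is
// only allowed when its raw length is a multiple of the per-vector scalar count.
static bool vector_access_ok(int raw_length, int scalar_per_vector)
{
    return raw_length % scalar_per_vector == 0;
}

int main()
{
    // Row-major A with ABlockTransferSrcVectorDim == 2: K is the contiguous dimension.
    std::printf("K=4096, 8 scalars per vector -> %s\n",
                vector_access_ok(4096, 8) ? "supported" : "rejected");
    // A raw length that is not a multiple of the vector width makes the device op
    // reject the argument.
    std::printf("K=1020, 8 scalars per vector -> %s\n",
                vector_access_ok(1020, 8) ? "supported" : "rejected");
    return 0;
}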
......@@ -118,7 +118,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
using Argument = typename GridwiseGemm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
......@@ -186,7 +185,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
......@@ -201,12 +199,12 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
return false;
}
// check vector load/store
// Check vector load/store.
{
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// check vector load of A
// Check vector load of A.
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % ABlockTransferScalarPerVector != 0)
......@@ -216,7 +214,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
}
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
{
// FIXME: not rigorous
if(arg.MRaw_ % ABlockTransferScalarPerVector != 0)
{
return false;
......@@ -227,7 +224,7 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
return false;
}
// check vector load of B
// Check vector load of B.
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % BBlockTransferScalarPerVector != 0)
......@@ -237,7 +234,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
}
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
{
// FIXME: not rigorous
if(arg.NRaw_ % BBlockTransferScalarPerVector != 0)
{
return false;
......@@ -248,8 +244,8 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
return false;
}
// check vector store of E
// only support RowMajor for now
// Check vector load of E.
// For now, only the RowMajor layout is supported.
if constexpr(is_same_v<ELayout, Row>)
{
if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
......@@ -270,7 +266,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
arg.block_2_etile_map_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
......@@ -310,7 +305,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
......@@ -344,13 +338,11 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
cde_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
......
......@@ -55,8 +55,8 @@ __global__ void
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx940__) || \
defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......@@ -173,7 +173,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, destination of blockwise copy.
return make_naive_tensor_descriptor(
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
......@@ -181,7 +181,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
// B matrix in LDS memory, destination of blockwise copy.
return make_naive_tensor_descriptor(
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
......@@ -217,11 +217,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
// LDS allocation for A and B: be careful of alignment.
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
......@@ -230,7 +229,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle in LDS
// LDS allocation for C shuffle.
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
......@@ -316,11 +315,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
// using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]);
},
[&](auto i) { return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); },
Number<NumDTensor>{});
}
......@@ -329,7 +324,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1));
// A desc for source in blockwise copy
// A desc for source in blockwise copy.
__host__ __device__ static constexpr auto
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
......@@ -345,7 +340,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B desc for source in blockwise copy
// B desc for source in blockwise copy.
__host__ __device__ static constexpr auto
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
......@@ -361,7 +356,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// E desc for destination in blockwise copy
// E desc for destination in blockwise copy.
__host__ __device__ static constexpr auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
{
......@@ -381,7 +376,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
return e_grid_desc_mblock_mperblock_nblock_nperblock;
}
// Ds desc for source in blockwise copy
// Ds desc for source in blockwise copy.
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
{
......@@ -392,7 +387,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
Number<NumDTensor>{});
}
// return block_id to E matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
......@@ -411,10 +405,8 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map
using Block2ETileMap = remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
......@@ -439,7 +431,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
const auto AK = a_grid_desc_m_k.GetLength(I1);
const auto BK = b_grid_desc_n_k.GetLength(I1);
// check consistency of desc
// Check the consistency of descriptors.
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
{
return false;
......@@ -457,28 +449,26 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
return false;
}
// check tile size
// Check the tile size.
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
{
return false;
}
// check gridwise gemm pipeline
// Check gridwise gemm pipeline.
const auto num_k_loop = AK / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
// check block-to-E-tile
// Check block-to-E-tile.
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check tensor size: cannot be larger than 2GB each
// Check tensor size: cannot exceed 2GB.
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
......@@ -522,7 +512,8 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap& block_2_etile_map)
{
// elementwise operations are not supported for A and B, left only for the API consistency
// Elementwise operations are not supported for A and B; the arguments are kept only
// for API consistency.
(void)a_element_op;
(void)b_element_op;
......@@ -543,7 +534,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N]
// Divide block work by [M, N].
const auto block_work_idx =
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
......@@ -555,7 +546,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
return;
}
// HACK: this forces m/n_block_data_idx_on_grid into SGPR
// This forces m/n_block_data_idx_on_grid into SGPR.
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
......@@ -564,13 +555,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
constexpr auto max_lds_align = math::lcm(AK1, BK1);
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, destination of blockwise copy.
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
// B matrix in LDS memory, destination of blockwise copy.
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<AK0PerBlock, MPerBlock, AK1>,
......@@ -588,7 +578,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<BK0PerBlock, NPerBlock, BK1>,
......@@ -612,7 +601,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1),
MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::selected_mfma
......@@ -634,7 +622,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
// LDS allocation for A and B: be careful of alignment.
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
......@@ -648,7 +636,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
......@@ -672,7 +659,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
// Shuffle C and write out.
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
......@@ -723,8 +710,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
// Calculate the origin of thread output tensor on global memory.
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
......@@ -751,7 +737,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
// Shuffle: threadwise copy C from VGPR to LDS.
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
......@@ -783,7 +769,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// tuple of reference to C/Ds tensor descriptors
// A tuple of references to the C/Ds tensor descriptors.
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie(
......@@ -791,7 +777,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor descriptors
// A tuple of references to the C/Ds grid buffers.
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie(
......@@ -799,7 +785,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
// A tuple of starting indices for the C/Ds blockwise copy.
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
......@@ -808,7 +794,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
},
Number<NumDTensor>{}));
// blockwise copy C/D/E between LDS and global
// Blockwise copy C/D/E between LDS and global.
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
......@@ -816,8 +802,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitray type
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
......@@ -838,7 +823,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
cde_element_op};
// space filling curve for threadwise C in VGPR before shuffle
// Space filling curve for threadwise C in VGPR before shuffle.
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
......@@ -851,7 +836,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
M4,
1>>{};
// space filling curve for shuffled blockwise C/D/E
// Space filling curve for shuffled blockwise C/D/E.
constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
......@@ -865,20 +850,20 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
// Make sure it's safe to write to LDS.
block_sync_lds();
// each thread write its data from VGPR to LDS
// Each thread writes its data from VGPR to LDS.
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
// Make sure it's safe to read from LDS.
block_sync_lds();
// each block copy its data from LDS to global
// Each block copies its data from LDS to global.
cde_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
......@@ -890,13 +875,13 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
constexpr auto cde_lds_and_global_step =
sfc_cde_block.GetForwardStep(access_id);
// move on Ds
// Move on Ds.
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_lds_and_global_step);
});
// move on E
// Move on E.
cde_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
I0,
......@@ -942,19 +927,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
NRaw_{NRaw},
KRaw_{KRaw}
{
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
// using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
// D desc
ds_grid_desc_m_n_(i) = MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
});
// populate desc for Ds/E
if(CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
......@@ -978,19 +956,19 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
}
// pointers
// Pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// tensor descriptors for problem definiton
// Tensor descriptors for problem definition
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
// Tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -1000,12 +978,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
// element-wise op
// element-wise ops
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// for checking vector load/store
// For checking vector load/store
index_t MRaw_;
index_t NRaw_;
index_t KRaw_;
......
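CheckValidity() above also enforces that each grid tensor's element space stays below 2 GB (long_index_t{1} << 31 bytes). A standalone sketch, with hypothetical extents, of how that limit plays out for a 4-byte data type such as F32:

#include <cstdint>
#include <cstdio>

int main()
{
    using long_index_t = int64_t;
    constexpr long_index_t TwoGB = long_index_t{1} << 31;

    // Hypothetical A[M, K] extents, 4 bytes per element.
    const long_index_t small = long_index_t{8192} * 8192 * 4;   // 256 MiB
    const long_index_t large = long_index_t{65536} * 16384 * 4; // 4 GiB

    std::printf("A = 8192 x 8192   -> %s\n",
                small <= TwoGB ? "within the 2GB limit" : "rejected");
    std::printf("A = 65536 x 16384 -> %s\n",
                large <= TwoGB ? "within the 2GB limit" : "rejected");
    return 0;
}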