Commit 5f4c1ddb authored by Bartlomiej Wroblewski

Clean the code and comments

parent 0661e8d2
@@ -44,7 +44,7 @@ if(USE_BITINT_EXTENSION_INT4)
     add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4)
 endif(USE_BITINT_EXTENSION_INT4)
 
-# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
+# FIXME: re-enable this example as test when SWDEV-335738 is fixed
 add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
@@ -58,9 +58,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
 if(GPU_TARGETS MATCHES "gfx90a")
     add_example_executable(example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp)
-    if(result EQUAL 0)
-        add_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32)
-    endif()
+    add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32)
 endif()
 
 add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
...
@@ -6,26 +6,12 @@
 #include "common.hpp"
 
 #define USING_DIRECT_LOADS 1
-#include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/utility/data_type.hpp"
 #if USING_DIRECT_LOADS
 #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp"
 #else
 #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
 #endif
-#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
-#include "ck/library/utility/check_err.hpp"
-#include "ck/library/utility/device_memory.hpp"
-#include "ck/library/utility/host_tensor.hpp"
-#include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/utility/literals.hpp"
 
 using F32 = float;
 using ADataType = F32;
...
@@ -67,17 +67,13 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                       "The number of threads cannot be less than the number of elements in "
                       "thread cluster lengths.");
 
-        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
-           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
-        {
-            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-                make_multi_index(ThreadGroup::GetThreadId()));
+        const auto thread_cluster_idx =
+            thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
 
-            const auto thread_data_idx_begin = thread_cluster_idx;
+        const auto thread_data_idx_begin = thread_cluster_idx;
 
-            SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
-            SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
-        }
+        SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
+        SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
     }
 
     __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
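For context on the change above: the initialization now unconditionally derives each thread's starting slice origin from its flat thread id, relying on the static_assert instead of a runtime guard. Below is a minimal host-side sketch of that flat-id-to-cluster-index decomposition, assuming the row-major ordering a cluster descriptor typically uses; the function name flat_to_cluster_index is illustrative and not part of CK.

    #include <array>
    #include <cstddef>

    // Decompose a flat thread id into an N-dimensional cluster index, assuming
    // row-major cluster lengths (the last dimension varies fastest).
    template <std::size_t NDim>
    std::array<std::size_t, NDim> flat_to_cluster_index(std::size_t thread_id,
                                                        const std::array<std::size_t, NDim>& lengths)
    {
        std::array<std::size_t, NDim> idx{};
        for(std::size_t d = NDim; d-- > 0;)
        {
            idx[d] = thread_id % lengths[d]; // position within this dimension
            thread_id /= lengths[d];         // carry the remainder to the outer dimensions
        }
        return idx;
    }

For example, with cluster lengths {4, 64}, thread id 70 maps to {1, 6}; that index is then added to the block slice origin to place the thread's own slice window.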
@@ -103,11 +99,6 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                         const DstDesc& dst_desc,
                         DstBuffer& dst_buf)
     {
-        if(ThreadGroup::GetNumOfThread() != thread_cluster_desc_.GetElementSize() &&
-           ThreadGroup::GetThreadId() >= thread_cluster_desc_.GetElementSize())
-        {
-            return;
-        }
-
         static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global,
                       "Source data must come from a global memory buffer.");
         static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
@@ -120,21 +111,19 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
             is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
             "DstBuffer and DstData data types must be consistent.");
 
         constexpr auto dst_access_lengths = thread_slice_lengths;
-        constexpr auto dst_dim_access_order = Sequence<0, 1, 2>{};
-
-        constexpr auto ordered_dst_access_lengths =
-            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
 
         const auto dst_forward_steps  = generate_steps(dst_desc, 1);
         const auto dst_backward_steps = generate_steps(dst_desc, -1);
         const auto src_forward_steps  = generate_steps(src_desc, 1);
         const auto src_backward_steps = generate_steps(src_desc, -1);
 
-        // loop over tensor and copy
-        static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
+        // Loop over the destination block and copy data.
+        static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
             const auto src_offset = src_coord_.GetOffset();
             const auto dst_offset = dst_coord_.GetOffset();
 
+            // Check if src data is not in the logic padding area.
             const bool is_src_valid =
                 coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
@@ -145,11 +134,10 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                 StaticallyIndexedArray<bool, nDim> move_on_dim_;
 
                 static_for<0, nDim, 1>{}([&](auto i) {
-                    move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
+                    move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1;
 
                     static_for<i + 1, nDim, 1>{}([&](auto j) {
-                        move_on_dim_(i) &=
-                            ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
+                        move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1;
                     });
                 });
@@ -157,7 +145,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
             }
             ();
 
-            // judge move forward or move backward
+            // Decide whether to move forward or backward.
             constexpr auto forward_sweep = [&]() {
                 StaticallyIndexedArray<bool, nDim> forward_sweep_;
@@ -167,7 +155,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                     index_t tmp = ordered_dst_access_idx[I0];
 
                     static_for<1, i, 1>{}([&](auto j) {
-                        tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
+                        tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j];
                     });
 
                     forward_sweep_(i) = tmp % 2 == 0;
@@ -181,33 +169,26 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                 {
                     if constexpr(forward_sweep[i])
                     {
-                        move_tensor_coordinate(
-                            dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
-                        move_tensor_coordinate(
-                            src_desc, src_coord_, src_forward_steps[dst_dim_access_order[i]]);
+                        move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]);
+                        move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]);
                     }
                     else
                    {
-                        move_tensor_coordinate(
-                            dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
-                        move_tensor_coordinate(
-                            src_desc, src_coord_, src_backward_steps[dst_dim_access_order[i]]);
+                        move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]);
+                        move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]);
                     }
                 }
             });
         });
 
+        // Reset the destination slice since the entire buffer has been already filled.
         ResetDstSliceWindow(dst_desc);
     }
 
     __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
     {
-        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
-           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
-        {
-            src_slice_origin_ = src_slice_origin_ + step;
-            src_coord_        = make_tensor_coordinate(src_desc, src_slice_origin_);
-        }
+        src_slice_origin_ = src_slice_origin_ + step;
+        src_coord_        = make_tensor_coordinate(src_desc, src_slice_origin_);
     }
 
     template <typename DescType>
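The copy loop above uses a snake (zigzag) access order: dimension 0 is always swept forward, and each inner dimension reverses direction depending on the parity of the outer indices, so consecutive accesses differ by a single coordinate step and can reuse the precomputed forward/backward steps. A simplified, self-contained sketch of that ordering for two dimensions follows; the lengths are made up for illustration, while the real code derives them from thread_slice_lengths.

    #include <cstdio>

    int main()
    {
        constexpr int len0 = 3; // outer access length (always swept forward)
        constexpr int len1 = 4; // inner access length (direction flips per outer step)

        for(int i0 = 0; i0 < len0; ++i0)
        {
            const bool forward = (i0 % 2 == 0); // mirrors the forward_sweep parity rule
            for(int k = 0; k < len1; ++k)
            {
                const int i1 = forward ? k : len1 - 1 - k;
                std::printf("visit (%d, %d)\n", i0, i1);
            }
        }
        return 0;
    }

The visiting order is (0,0) (0,1) (0,2) (0,3) (1,3) (1,2) (1,1) (1,0) (2,0) ...: each transition moves exactly one coordinate, which is what move_on_dim_ and the forward/backward step arrays encode in the generic N-dimensional version.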
...
@@ -191,7 +191,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
            }
        }
 
-        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
@@ -206,12 +205,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
            return false;
        }
 
-        // check vector load/store
+        // Check vector load/store.
        {
            using Row = ck::tensor_layout::gemm::RowMajor;
            using Col = ck::tensor_layout::gemm::ColumnMajor;
 
-            // check vector load of A
+            // Check vector load of A.
            if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % ABlockTransferScalarPerVector != 0)
@@ -221,7 +220,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
            }
            else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
            {
-                // FIXME: not rigorous
                if(arg.MRaw_ % ABlockTransferScalarPerVector != 0)
                {
                    return false;
@@ -232,7 +230,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                return false;
            }
 
-            // check vector load of B
+            // Check vector load of B.
            if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % BBlockTransferScalarPerVector != 0)
@@ -242,7 +240,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
            }
            else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
            {
-                // FIXME: not rigorous
                if(arg.NRaw_ % BBlockTransferScalarPerVector != 0)
                {
                    return false;
@@ -253,8 +250,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                return false;
            }
 
-            // check vector load of Ds
-            // only support RowMajor for now
+            // Check vector load of Ds.
+            // For now, only the RowMajor layout is supported.
            bool all_valid = true;
 
            static_for<0, NumDTensor, 1>{}([&](auto i) {
@@ -271,8 +268,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                return false;
            }
 
-            // check vector store of E
-            // only support RowMajor for now
+            // Check vector load of E.
+            // For now, only the RowMajor layout is supported.
            if constexpr(is_same_v<ELayout, Row>)
            {
                if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
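The layout checks above all apply the same rule: a vectorized access of ScalarPerVector elements along the contiguous dimension is only valid when the raw extent of that dimension divides evenly by the vector width, otherwise the final vector would reach past the tensor. A standalone sketch of that predicate follows; the names are illustrative, not CK API.

    #include <cstdint>

    // True when an extent can be covered exactly by vectors of the given width.
    constexpr bool vector_access_supported(std::int64_t raw_extent, std::int64_t scalars_per_vector)
    {
        return scalars_per_vector > 0 && raw_extent % scalars_per_vector == 0;
    }

    static_assert(vector_access_supported(4096, 8), "K = 4096 works with 8-wide accesses");
    static_assert(!vector_access_supported(1000, 16), "K = 1000 cannot be covered by 16-wide accesses");

IsSupportedArgument applies this per tensor: along K for A and B when the vector dimension is the K axis, along M or N otherwise, and along N for the Ds and E tensors.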
@@ -293,7 +290,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                          arg.block_2_etile_map_);
    }
 
-    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
@@ -332,7 +328,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
    static auto MakeInvoker() { return Invoker{}; }
 
-    // polymorphic
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
@@ -365,13 +360,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                                          cde_element_op);
    }
 
-    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }
 
-    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();
...
@@ -118,7 +118,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
    using Argument = typename GridwiseGemm::Argument;
 
-    // Invoker
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
@@ -186,7 +185,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
            }
        }
 
-        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
@@ -201,12 +199,12 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
            return false;
        }
 
-        // check vector load/store
+        // Check vector load/store.
        {
            using Row = ck::tensor_layout::gemm::RowMajor;
            using Col = ck::tensor_layout::gemm::ColumnMajor;
 
-            // check vector load of A
+            // Check vector load of A.
            if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % ABlockTransferScalarPerVector != 0)
@@ -216,7 +214,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
            }
            else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
            {
-                // FIXME: not rigorous
                if(arg.MRaw_ % ABlockTransferScalarPerVector != 0)
                {
                    return false;
@@ -227,7 +224,7 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
                return false;
            }
 
-            // check vector load of B
+            // Check vector load of B.
            if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % BBlockTransferScalarPerVector != 0)
@@ -237,7 +234,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
            }
            else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
            {
-                // FIXME: not rigorous
                if(arg.NRaw_ % BBlockTransferScalarPerVector != 0)
                {
                    return false;
@@ -248,8 +244,8 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
                return false;
            }
 
-            // check vector store of E
-            // only support RowMajor for now
+            // Check vector load of E.
+            // For now, only the RowMajor layout is supported.
            if constexpr(is_same_v<ELayout, Row>)
            {
                if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
@@ -270,7 +266,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
                          arg.block_2_etile_map_);
    }
 
-    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
@@ -310,7 +305,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
    static auto MakeInvoker() { return Invoker{}; }
 
-    // polymorphic
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
@@ -344,13 +338,11 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
                                        cde_element_op);
    }
 
-    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }
 
-    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();
...