"docs/zh_cn/vscode:/vscode.git/clone" did not exist on "f57c0702f7ff7f190013e6636f8f0ad0a8141985"
Commit 7f4e416b authored by Adam Osewski
Browse files

GridwiseGemm with direct C writeout parameterized.

Added a few template parameters that determine the store access pattern.
Right now it is designed to be warp-raked.
parent d452452d
......@@ -12,7 +12,8 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_direct_c_write_out.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_direct_c_write_out_roofline.hpp"
// #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_direct_c_write_out.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -56,6 +57,9 @@ template <typename ALayout,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
typename CThreadTransferDstAccessOrder,
index_t CThreadTransferDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v2>
struct DeviceGemm_Xdl_DirectCWriteOut : public DeviceGemm<ALayout,
......@@ -381,6 +385,9 @@ struct DeviceGemm_Xdl_DirectCWriteOut : public DeviceGemm<ALayout,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CThreadTransferDstAccessOrder,
CThreadTransferDstVectorDim,
CThreadTransferDstScalarPerVector,
LoopSched,
PipelineVer>;
......
......@@ -17,7 +17,8 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_direct_c_write_out.hpp"
// #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_direct_c_write_out.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_direct_c_write_out_roofline.hpp"
namespace ck {
namespace tensor_operation {
......
......@@ -751,8 +751,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
// N3 - mfma_instr.num_input_blks
// N4 - mfma_instr.group_size
// {M0, N0, 1, 1, 1, 4, 1, 4}
// constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
// blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment