Commit 7f4e416b authored by Adam Osewski's avatar Adam Osewski
Browse files

GridwiseGemm with direct C writeout parameterized.

Added a few template parameters deciding on the store access pattern.
Right now it's designed to be warp-raked.
parent d452452d
...@@ -12,7 +12,8 @@ ...@@ -12,7 +12,8 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_direct_c_write_out.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_direct_c_write_out_roofline.hpp"
// #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_direct_c_write_out.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -56,6 +57,9 @@ template <typename ALayout, ...@@ -56,6 +57,9 @@ template <typename ALayout,
index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1, index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN, bool BBlockLdsExtraN,
typename CThreadTransferDstAccessOrder,
index_t CThreadTransferDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
LoopScheduler LoopSched = make_default_loop_scheduler(), LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v2> PipelineVersion PipelineVer = PipelineVersion::v2>
struct DeviceGemm_Xdl_DirectCWriteOut : public DeviceGemm<ALayout, struct DeviceGemm_Xdl_DirectCWriteOut : public DeviceGemm<ALayout,
...@@ -381,6 +385,9 @@ struct DeviceGemm_Xdl_DirectCWriteOut : public DeviceGemm<ALayout, ...@@ -381,6 +385,9 @@ struct DeviceGemm_Xdl_DirectCWriteOut : public DeviceGemm<ALayout,
BBlockTransferDstScalarPerVector_BK1, BBlockTransferDstScalarPerVector_BK1,
false, false,
BBlockLdsExtraN, BBlockLdsExtraN,
CThreadTransferDstAccessOrder,
CThreadTransferDstVectorDim,
CThreadTransferDstScalarPerVector,
LoopSched, LoopSched,
PipelineVer>; PipelineVer>;
......
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_direct_c_write_out.hpp" // #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_direct_c_write_out.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_direct_c_write_out_roofline.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
...@@ -751,8 +751,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out ...@@ -751,8 +751,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
// N3 - mfma_instr.num_input_blks // N3 - mfma_instr.num_input_blks
// N4 - mfma_instr.group_size // N4 - mfma_instr.group_size
// {M0, N0, 1, 1, 1, 4, 1, 4} // {M0, N0, 1, 1, 1, 4, 1, 4}
// constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
// blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 is only used to get lengths // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment