"docs/vscode:/vscode.git/clone" did not exist on "328e0d20a7b996f9bdb04180457eb08c1b42a76e"
Commit f9f2cdf9 authored by root's avatar root
Browse files

Do not clear cthread buffer if needed.

- Add output stream operators for LoopSched and PiplineVer
parent d14aaa52
...@@ -50,3 +50,15 @@ constexpr auto GridwiseGemmPipeline_Selector() ...@@ -50,3 +50,15 @@ constexpr auto GridwiseGemmPipeline_Selector()
} }
} // namespace ck } // namespace ck
inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
{
switch(p)
{
case ck::PipelineVersion::v1: os << "PipelineVersion::v1"; break;
case ck::PipelineVersion::v2: os << "PipelineVersion::v2"; break;
case ck::PipelineVersion::v4: os << "PipelineVersion::v4"; break;
default: os << "";
}
return os;
}
...@@ -155,7 +155,8 @@ struct GridwiseGemmPipeline_v1<2> ...@@ -155,7 +155,8 @@ struct GridwiseGemmPipeline_v1<2>
const BBlockTransferStep& b_block_copy_step, const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm, const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
index_t num_loop) index_t num_loop,
bool clear_c_thread_buf = true)
{ {
// preload data into LDS // preload data into LDS
{ {
...@@ -173,6 +174,7 @@ struct GridwiseGemmPipeline_v1<2> ...@@ -173,6 +174,7 @@ struct GridwiseGemmPipeline_v1<2>
} }
// Initialize C // Initialize C
if(clear_c_thread_buf)
c_thread_buf.Clear(); c_thread_buf.Clear();
// main body // main body
...@@ -298,7 +300,8 @@ struct GridwiseGemmPipelineInterwave_v1<1> ...@@ -298,7 +300,8 @@ struct GridwiseGemmPipelineInterwave_v1<1>
const BBlockTransferStep& b_block_copy_step, const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm, const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
index_t num_loop) index_t num_loop,
bool clear_c_thread_buf = true)
{ {
// preload data into LDS // preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
...@@ -308,6 +311,7 @@ struct GridwiseGemmPipelineInterwave_v1<1> ...@@ -308,6 +311,7 @@ struct GridwiseGemmPipelineInterwave_v1<1>
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C // Initialize C
if(clear_c_thread_buf)
c_thread_buf.Clear(); c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
......
...@@ -49,7 +49,8 @@ struct GridwiseGemmPipeline_v2 ...@@ -49,7 +49,8 @@ struct GridwiseGemmPipeline_v2
const BBlockTransferStep& b_block_copy_step, const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm, const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
index_t num_loop) index_t num_loop,
bool clear_c_thread_buf = true)
{ {
// global read 0 // global read 0
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
...@@ -60,6 +61,7 @@ struct GridwiseGemmPipeline_v2 ...@@ -60,6 +61,7 @@ struct GridwiseGemmPipeline_v2
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C // Initialize C
if(clear_c_thread_buf)
c_thread_buf.Clear(); c_thread_buf.Clear();
// LDS write 0 // LDS write 0
......
...@@ -68,7 +68,8 @@ struct GridwiseGemmPipeline_v4<1> ...@@ -68,7 +68,8 @@ struct GridwiseGemmPipeline_v4<1>
const BBlockTransferStep& b_block_copy_step, const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm, const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
index_t num_loop) index_t num_loop,
bool clear_c_thread_buf = true)
{ {
static_assert(ABlockBuffers::Size() == 1 && BBlockBuffers::Size() == 1); static_assert(ABlockBuffers::Size() == 1 && BBlockBuffers::Size() == 1);
auto& a_block_buf = a_block_bufs.At(I0); auto& a_block_buf = a_block_bufs.At(I0);
...@@ -81,6 +82,7 @@ struct GridwiseGemmPipeline_v4<1> ...@@ -81,6 +82,7 @@ struct GridwiseGemmPipeline_v4<1>
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C // Initialize C
if(clear_c_thread_buf)
c_thread_buf.Clear(); c_thread_buf.Clear();
// main body // main body
...@@ -164,7 +166,8 @@ struct GridwiseGemmPipeline_v4<2> ...@@ -164,7 +166,8 @@ struct GridwiseGemmPipeline_v4<2>
const BBlockTransferStep& b_block_copy_step, const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm, const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
index_t num_loop) index_t num_loop,
bool clear_c_thread_buf = true)
{ {
static_assert(ABlockBuffers::Size() == 2 && BBlockBuffers::Size() == 2); static_assert(ABlockBuffers::Size() == 2 && BBlockBuffers::Size() == 2);
auto& a_block_buf1 = a_block_bufs.At(I0); auto& a_block_buf1 = a_block_bufs.At(I0);
...@@ -179,6 +182,7 @@ struct GridwiseGemmPipeline_v4<2> ...@@ -179,6 +182,7 @@ struct GridwiseGemmPipeline_v4<2>
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C // Initialize C
if(clear_c_thread_buf)
c_thread_buf.Clear(); c_thread_buf.Clear();
// main body // main body
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
#pragma once #pragma once
#include <ostream>
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tensor_description/tensor_adaptor.hpp"
...@@ -24,3 +26,14 @@ constexpr LoopScheduler make_default_loop_scheduler() ...@@ -24,3 +26,14 @@ constexpr LoopScheduler make_default_loop_scheduler()
} }
} // namespace ck } // namespace ck
inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
{
switch(s)
{
case ck::LoopScheduler::Default: os << "Default"; break;
case ck::LoopScheduler::Interwave: os << "Interwave"; break;
default: os << "";
}
return os;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment