Commit 885ff00a authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Add enumerator for new pipeline version (optimized pipeline v1)

parent 833a31bb
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
#pragma once #pragma once
#include <iostream>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
...@@ -11,9 +13,22 @@ namespace ck { ...@@ -11,9 +13,22 @@ namespace ck {
enum struct PipelineVersion enum struct PipelineVersion
{ {
v1, v1,
v1_opt0,
v2, v2,
}; };
inline std::ostream& operator<<(std::ostream& stream, PipelineVersion version)
{
switch(version)
{
case PipelineVersion::v1: return stream << "v1";
case PipelineVersion::v1_opt0: return stream << "v1_opt0";
case PipelineVersion::v2: return stream << "v2";
}
__builtin_unreachable();
}
template <PipelineVersion PipelineVer, template <PipelineVersion PipelineVer,
index_t NumPrefetch = 1, index_t NumPrefetch = 1,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
...@@ -30,6 +45,10 @@ constexpr auto GridwiseGemmPipeline_Selector() ...@@ -30,6 +45,10 @@ constexpr auto GridwiseGemmPipeline_Selector()
return GridwiseGemmPipelineInterwave_v1<NumPrefetch>{}; return GridwiseGemmPipelineInterwave_v1<NumPrefetch>{};
} }
} }
else if constexpr(PipelineVer == PipelineVersion::v1_opt0)
{
return GridwiseGemmPipeline_v1<NumPrefetch, 0>{};
}
else if constexpr(PipelineVer == PipelineVersion::v2) else if constexpr(PipelineVer == PipelineVersion::v2)
{ {
return GridwiseGemmPipeline_v2{}; return GridwiseGemmPipeline_v2{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment