Commit c5893dc6 authored by Adam Osewski's avatar Adam Osewski
Browse files

Add pipeline version to GroupedGEMM device op type string.

parent aeaccf69
...@@ -736,7 +736,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -736,7 +736,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
<< CShuffleMXdlPerWavePerShuffle << ", " << CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", " << CShuffleNXdlPerWavePerShuffle << ", "
<< ABlockTransferThreadClusterLengths_K0_M_K1{} << ", " << ABlockTransferThreadClusterLengths_K0_M_K1{} << ", "
<< getGemmSpecializationString(GemmSpec) << getGemmSpecializationString(GemmSpec) << ", "
<< PipelineVer{}
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <ostream>
#include <string>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
...@@ -42,4 +44,20 @@ constexpr auto GridwiseGemmPipeline_Selector() ...@@ -42,4 +44,20 @@ constexpr auto GridwiseGemmPipeline_Selector()
} }
} }
std::string getPipelineVersionString(const PipelineVersion& pv)
{
switch(pv)
{
case PipelineVersion::v1: return "PipelineVersion::v1";
case PipelineVersion::v2: return "PipelineVersion::v2";
default: return "Unrecognized pipeline version!";
}
}
} // namespace ck } // namespace ck
std::ostream& operator<<(std::ostream& os, const PipelineVersion pv)
{
os << getPipelineVersionString(pv);
return os;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment