Commit 600a9870 authored by Alan Turner
Browse files

Clean up

parent 497c30e0
...@@ -37,42 +37,6 @@ struct BaseInvoker ...@@ -37,42 +37,6 @@ struct BaseInvoker
virtual ~BaseInvoker() {} virtual ~BaseInvoker() {}
}; };
/// Polymorphic base for objects that describe an operation's parameters.
///
/// Every hook has a do-nothing default: the setters ignore their argument and
/// the getters return zero or an empty string. Derived types override the hooks
/// they support.
struct BaseParameters
{
    BaseParameters()                      = default;
    BaseParameters(const BaseParameters&) = default;
    BaseParameters& operator=(const BaseParameters&) = default;
    virtual ~BaseParameters()                        = default;

    // String-valued setters; the base implementation discards the value.
    virtual void SetAElementOp(const std::string& /*op*/) {}
    virtual void SetBElementOp(const std::string& /*op*/) {}
    virtual void SetCDEElementOp(const std::string& /*op*/) {}
    virtual void SetDsLayout(const std::string& /*layout*/) {}
    virtual void SetDsDataType(const std::string& /*data_type*/) {}

    // Takes three extents; the base implementation discards them.
    virtual void SetGemmSpec(const index_t /*m*/, const index_t /*n*/, const index_t /*k*/) {}

    /// Default grid size: always 0.
    virtual index_t GetGridSize(const index_t /*m*/, const index_t /*n*/) { return 0; }

    /// Default block size: always 0.
    virtual index_t GetBlockSize() { return 0; }

    /// Default parameter description: an empty string.
    virtual std::string GetParametersString() { return std::string{}; }
};
struct BaseOperator struct BaseOperator
{ {
BaseOperator() = default; BaseOperator() = default;
......
...@@ -51,26 +51,6 @@ struct DeviceGemmMultipleD : public BaseOperator ...@@ -51,26 +51,6 @@ struct DeviceGemmMultipleD : public BaseOperator
CDEElementwiseOperation cde_element_op) = 0; CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual std::unique_ptr<BaseParameters> MakeParametersPointer()
{
return std::make_unique<BaseParameters>(BaseParameters{});
}
/// Default block size: 0. Virtual so derived operations can report their own value.
virtual index_t GetBlockSize() const { return 0; }
/// Default MPerBlock: 0. Virtual so derived operations can report their own value.
virtual index_t GetMPerBlock() const { return 0; }
/// Default NPerBlock: 0. Virtual so derived operations can report their own value.
virtual index_t GetNPerBlock() const { return 0; }
}; };
} // namespace device } // namespace device
......
...@@ -699,195 +699,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -699,195 +699,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return str.str(); return str.str();
} }
struct Parameters : BaseParameters
{
template <class S>
static std::string GetSequenceString(S s)
{
auto str = std::stringstream();
str << "ck::Sequence<";
auto size = s.Size();
for(int i = 0; i < size; ++i)
{
str << s.At(i);
if(i < size - 1)
str << ",";
}
str << ">";
return str.str();
}
template <class T>
static std::string GetTypeString(T)
{
return "";
}
template <>
static std::string GetTypeString<float>(float)
{
return "float";
}
template <>
static std::string GetTypeString<ck::half_t>(ck::half_t)
{
return "ck::half_t";
}
template <>
static std::string
GetTypeString<tensor_layout::gemm::RowMajor>(tensor_layout::gemm::RowMajor)
{
return "ck::tensor_layout::gemm::RowMajor";
}
template <>
static std::string
GetTypeString<tensor_layout::gemm::ColumnMajor>(tensor_layout::gemm::ColumnMajor)
{
return "ck::tensor_layout::gemm::ColumnMajor";
}
template <class T>
static std::string GetTupleString(T t)
{
auto str = std::stringstream();
str << "ck::Tuple<";
static_for<0, t.Size(), 1>{}([&](auto i) {
str << GetTypeString(t.At(i));
if(i < t.Size() - 1)
str << ",";
});
str << ">";
return str.str();
}
template <>
static std::string GetTupleString<Tuple<>>(Tuple<>)
{
return "ck::Tuple<>";
}
void SetAElementOp(const std::string& s) override { a_elementwise_op = s; }
void SetBElementOp(const std::string& s) override { b_elementwise_op = s; }
void SetCDEElementOp(const std::string& s) override { cde_elementwise_op = s; }
void SetDsLayout(const std::string& s) override { ds_layout = s; }
void SetDsDataType(const std::string& s) override { ds_data_type = s; }
void SetGemmSpec(const index_t m, const index_t n, const index_t k) override
{
std::string spec = "";
if(math::integer_divide_ceil(m, MPerBlock) * MPerBlock - m != 0)
spec += "M";
if(math::integer_divide_ceil(n, NPerBlock) * NPerBlock - n != 0)
spec += "N";
if(math::integer_divide_ceil(k, KPerBlock) * KPerBlock - k != 0)
spec += "K";
if(spec == "")
gemm_spec = "ck::tensor_operation::device::GemmSpecialization::Default";
else
gemm_spec = "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
}
index_t GetGridSize(const index_t m, const index_t n) override
{
return math::integer_divide_ceil(m, MPerBlock) *
math::integer_divide_ceil(n, NPerBlock);
}
index_t GetBlockSize() override { return BlockSize; }
std::string GetParametersString() override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "ck::LoopScheduler::Default"},
{LoopScheduler::Interwave, "ck::LoopScheduler::Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{
{PipelineVersion::v1, "ck::PipelineVersion::v1"},
{PipelineVersion::v2, "ck::PipelineVersion::v2"}};
// clang-format off
str << "ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle"
<< "<"
<< GetTypeString(ALayout{}) << ", "
<< GetTypeString(BLayout{}) << ", "
<< ds_layout << ", "
<< GetTypeString(ELayout{}) << ", "
<< GetTypeString(ADataType{}) << ", "
<< GetTypeString(BDataType{}) << ", "
<< GetTypeString(AccDataType{}) << ", "
<< GetTypeString(CShuffleDataType{}) << ", "
<< ds_data_type << ", "
<< GetTypeString(EDataType{}) << ", "
<< a_elementwise_op << ", "
<< b_elementwise_op << ", "
<< cde_elementwise_op << ", "
<< gemm_spec << ", "
<< NumGemmKPrefetchStage << ", "
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< GetSequenceString(ABlockTransferThreadClusterLengths_AK0_M_AK1{}) << ", "
<< GetSequenceString(ABlockTransferThreadClusterArrangeOrder{}) << ", "
<< GetSequenceString(ABlockTransferSrcAccessOrder{}) << ", "
<< ABlockTransferSrcVectorDim << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_AK1 << ", "
<< ABlockLdsExtraM << ", "
<< GetSequenceString(BBlockTransferThreadClusterLengths_BK0_N_BK1{}) << ", "
<< GetSequenceString(BBlockTransferThreadClusterArrangeOrder{}) << ", "
<< GetSequenceString(BBlockTransferSrcAccessOrder{}) << ", "
<< BBlockTransferSrcVectorDim << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_BK1 << ", "
<< BBlockLdsExtraN << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< GetSequenceString(CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}) << ", "
<< CDEBlockTransferScalarPerVector_NPerBlock << ", "
<< LoopSchedToString[LoopSched] << ", "
<< PipelineVersionToString[PipelineVer]
<< ">";
// clang-format on
return str.str();
}
std::string a_elementwise_op = "ck::tensor_operation::element_wise::PassThrough";
std::string b_elementwise_op = "ck::tensor_operation::element_wise::PassThrough";
std::string cde_elementwise_op = "ck::tensor_operation::element_wise::PassThrough";
std::string ds_layout = "ck::Tuple<>";
std::string ds_data_type = "ck::Tuple<>";
std::string gemm_spec = "ck::tensor_operation::device::GemmSpecialization::" +
getGemmSpecializationString(GemmSpec);
};
std::unique_ptr<BaseParameters> MakeParametersPointer() override
{
return std::make_unique<Parameters>(Parameters{});
}
/// Reports the BlockSize value this operation was instantiated with.
index_t GetBlockSize() const override
{
    return BlockSize;
}
/// Reports the MPerBlock value this operation was instantiated with.
index_t GetMPerBlock() const override
{
    return MPerBlock;
}
/// Reports the NPerBlock value this operation was instantiated with.
index_t GetNPerBlock() const override
{
    return NPerBlock;
}
template <class ADesc, class BDesc, class DsDesc, class EDesc> template <class ADesc, class BDesc, class DsDesc, class EDesc>
struct Descriptor struct Descriptor
{ {
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#pragma once #pragma once
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp" #include "ck/utility/math_v2.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/quantization_operation.hpp" #include "ck/tensor_operation/gpu/element/quantization_operation.hpp"
......
...@@ -21,6 +21,7 @@ ENDFOREACH() ...@@ -21,6 +21,7 @@ ENDFOREACH()
add_library(device_operations STATIC ${CK_DEVICE_INSTANCES}) add_library(device_operations STATIC ${CK_DEVICE_INSTANCES})
add_library(composablekernels::device_operations ALIAS device_operations) add_library(composablekernels::device_operations ALIAS device_operations)
set(DEV_OPS_INC_DIRS set(DEV_OPS_INC_DIRS
${PROJECT_SOURCE_DIR}/include/ck/ ${PROJECT_SOURCE_DIR}/include/ck/
${PROJECT_SOURCE_DIR}/library/include/ck/ ${PROJECT_SOURCE_DIR}/library/include/ck/
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment