Commit 25205ea5 authored by Adam Osewski's avatar Adam Osewski
Browse files

Uncomment previously commented code for debug purposes.

parent 089f610b
......@@ -17,45 +17,45 @@ namespace tensor_operation {
namespace device {
namespace instance {
// void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(
// std::vector<std::unique_ptr<
// DeviceGemmSplitK<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
// instances);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
// void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(
// std::vector<std::unique_ptr<
// DeviceGemmSplitK<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
// instances);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
// void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
// std::vector<std::unique_ptr<
// DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
// instances);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
// void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
// std::vector<std::unique_ptr<
// DeviceGemmSplitK<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
// instances);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
// void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(
// std::vector<std::unique_ptr<
// DeviceGemmSplitK<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
// instances);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
// void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(
// std::vector<std::unique_ptr<
// DeviceGemmSplitK<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
// instances);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
// void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
// std::vector<std::unique_ptr<
// DeviceGemmSplitK<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
// instances);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
template <typename ADataType,
typename BDataType,
......@@ -91,26 +91,26 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<CDataType, float>)
{
// if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
// is_same_v<CLayout, Row>)
// {
// add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
// }
// else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
// is_same_v<CLayout, Row>)
// {
// add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
// }
// else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
// is_same_v<CLayout, Row>)
// {
// add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(op_ptrs);
// }
// else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
// is_same_v<CLayout, Row>)
// {
// add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs);
// }
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
......@@ -118,7 +118,7 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
// add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
......@@ -128,12 +128,12 @@ struct DeviceOperationInstanceFactory<
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
// add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
// add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment