Commit 9205784f authored by Adam Osewski's avatar Adam Osewski
Browse files

Split instances into multiple files.

parent 7a5d11c9
......@@ -31,7 +31,33 @@ void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregu
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances(
void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v1(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v1_interwave(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v2(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
......@@ -90,7 +116,11 @@ struct DeviceOperationInstanceFactory<
is_same_v<ELayout, Row>)
{
#if defined(CK_ENABLE_FP16)
add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances(
add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v1(
op_ptrs);
add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v1_interwave(
op_ptrs);
add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v2(
op_ptrs);
#endif
}
......
add_instance_library(device_grouped_gemm_multiple_d_instance
device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance_pipeline_v1.cpp
device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance_pipeline_v1_interwave.cpp
device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance_pipeline_v2.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multiple_d/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using GemmInstances =
device_ggemm_md_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_tile_instances<
PipelineVersion::v1>;
void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v1(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances, GemmInstances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multiple_d/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using GemmInstances =
device_ggemm_md_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_tile_instances<
PipelineVersion::v1,
LoopScheduler::Interwave>;
void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v1_interwave(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances, GemmInstances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multiple_d/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using GemmInstances =
device_ggemm_md_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_tile_instances<
PipelineVersion::v2>;
void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v2(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances, GemmInstances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -303,8 +303,9 @@ bool profile_ggemm_multid_splitk(int do_verification,
}
}
std::cout << "Instance: " << gemm_name << " verification "
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
std::cout << "Instance: " << gemm_name << " kbatch: " << kbatch_curr
<< " verification " << (instance_pass ? "SUCCEED" : "FAILED")
<< std::endl;
pass = pass && instance_pass;
// std::cout << ">>>>>CPU verification end!" << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment