"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "0ab63ff6478b7cc6b5ae0d46c7c386d476cfa87f"
Commit b83c791e authored by Adam Osewski's avatar Adam Osewski
Browse files

Testing instances for store access pattern.

parent 7f4e416b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_direct_c_write_out.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Element data types used by the instances in this file.
using F32 = float;
using F16 = ck::half_t;

// Matrix memory layouts.
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row = ck::tensor_layout::gemm::RowMajor;

// Element-wise operation type passed for the A, B and C operands.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Shorthand for a compile-time integer sequence, e.g. S<4, 64, 1>.
template <ck::index_t... Values>
using S = ck::Sequence<Values...>;

// Named constants for the non-type template arguments shared by every instance.
static constexpr auto LoopSchedDefault = ck::LoopScheduler::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
//
// Tuple of DeviceGemm_Xdl_DirectCWriteOut instantiations for F16 inputs/outputs
// with F32 accumulation, row-major A, column-major B and row-major C. Each tuple
// element is one fully specified template instantiation; the column headers
// below name the template arguments in the order they are passed.
// NOTE(review): the individual tile/vector values are tuning data; their validity
// is enforced by the device op itself, not by anything visible in this file.
using device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
//############################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| CThreadTransfer| LoopScheduler| Pipeline|
//############################| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| DstAccessOrder| DstVectorDim| DstScalarPerVector| | |
//############################| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | | | | |
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
DeviceGemm_Xdl_DirectCWriteOut< Row, Col, Row, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<0, 1, 2, 3, 4, 5, 6, 7>, 7, 8, LoopSchedDefault, PipelineVersion::v1>,
DeviceGemm_Xdl_DirectCWriteOut< Row, Col, Row, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<0, 1, 2, 3, 4, 5, 6, 7>, 7, 8, LoopSchedDefault, PipelineVersion::v1>,
DeviceGemm_Xdl_DirectCWriteOut< Row, Col, Row, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 48, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<0, 1, 2, 3, 4, 5, 6, 7>, 7, 8, LoopSchedDefault, PipelineVersion::v1>,
DeviceGemm_Xdl_DirectCWriteOut< Row, Col, Row, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<0, 1, 2, 3, 4, 5, 6, 7>, 7, 8, LoopSchedDefault, PipelineVersion::v1>,
DeviceGemm_Xdl_DirectCWriteOut< Row, Col, Row, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<0, 1, 2, 3, 4, 5, 6, 7>, 7, 8, LoopSchedDefault, PipelineVersion::v1>,
DeviceGemm_Xdl_DirectCWriteOut< Row, Col, Row, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<0, 1, 2, 3, 4, 5, 6, 7>, 7, 8, LoopSchedDefault, PipelineVersion::v1>,
DeviceGemm_Xdl_DirectCWriteOut< Row, Col, Row, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<0, 1, 2, 3, 4, 5, 6, 7>, 7, 8, LoopSchedDefault, PipelineVersion::v1>,
DeviceGemm_Xdl_DirectCWriteOut< Row, Col, Row, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<0, 1, 2, 3, 4, 5, 6, 7>, 7, 8, LoopSchedDefault, PipelineVersion::v1>
// The pipeline v2 variants below are disabled (commented out), mirroring the v1 list.
// NOTE(review): the last commented entry ends with a trailing comma; if this section
// is re-enabled that comma must be removed, because the closing `>` follows directly.
// #if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
// // pipeline v2, 1 wave
// ,
// DeviceGemm_Xdl_DirectCWriteOut< Row, Col, Row, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<0, 1, 2, 3, 4, 5, 6, 7>, 7, 8, LoopSchedDefault, PipelineVersion::v2>,
// DeviceGemm_Xdl_DirectCWriteOut< Row, Col, Row, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<0, 1, 2, 3, 4, 5, 6, 7>, 7, 8, LoopSchedDefault, PipelineVersion::v2>,
// DeviceGemm_Xdl_DirectCWriteOut< Row, Col, Row, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 48, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<0, 1, 2, 3, 4, 5, 6, 7>, 7, 8, LoopSchedDefault, PipelineVersion::v2>,
// DeviceGemm_Xdl_DirectCWriteOut< Row, Col, Row, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<0, 1, 2, 3, 4, 5, 6, 7>, 7, 8, LoopSchedDefault, PipelineVersion::v2>,
// DeviceGemm_Xdl_DirectCWriteOut< Row, Col, Row, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<0, 1, 2, 3, 4, 5, 6, 7>, 7, 8, LoopSchedDefault, PipelineVersion::v2>,
// DeviceGemm_Xdl_DirectCWriteOut< Row, Col, Row, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<0, 1, 2, 3, 4, 5, 6, 7>, 7, 8, LoopSchedDefault, PipelineVersion::v2>,
// DeviceGemm_Xdl_DirectCWriteOut< Row, Col, Row, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<0, 1, 2, 3, 4, 5, 6, 7>, 7, 8, LoopSchedDefault, PipelineVersion::v2>,
// DeviceGemm_Xdl_DirectCWriteOut< Row, Col, Row, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<0, 1, 2, 3, 4, 5, 6, 7>, 7, 8, LoopSchedDefault, PipelineVersion::v2>,
// #endif
// clang-format on
>;
// Registers the direct-C-write-out f16 GEMM instances (A row-major, B column-major,
// C row-major) with the caller's collection.
//
// instances: collection of owning DeviceGemm pointers that is handed, together with a
//            default-constructed instance tuple, to add_device_operation_instances.
//            NOTE(review): presumably one object per tuple element is appended; confirm
//            against add_device_operation_instance.hpp.
void add_device_gemm_xdl_direct_c_write_out_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
    using InstanceTuple = device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances;

    add_device_operation_instances(instances, InstanceTuple{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment