Commit bbe74503 authored by Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 8b76b832 f53ede26
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_quantization_dl_c_shuffle_i8_i8_i8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Layout(A, B, C) = [Col, Col, Row]
//
// Appends every kernel configuration listed in
// device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances<Mul_Clamp>
// to `instances` by handing the default-constructed instance list to
// add_device_operation_instances.
void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col, Col, Empty_Tuple, Row,
                                                    int8_t, int8_t, Empty_Tuple, int8_t,
                                                    PassThrough, PassThrough, Mul_Clamp>>>&
        instances)
{
    // Materialise the compile-time instance list once, then register it.
    const auto mul_clamp_kernels =
        device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances<Mul_Clamp>{};

    add_device_operation_instances(instances, mul_clamp_kernels);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_quantization_dl_c_shuffle_i8_i8_i8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Layout(A, B, C) = [Row, Row, Row]
//
// Appends every kernel configuration listed in
// device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances<Mul_Clamp>
// to `instances` by handing the default-constructed instance list to
// add_device_operation_instances.
void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Row, Empty_Tuple, Row,
                                                    int8_t, int8_t, Empty_Tuple, int8_t,
                                                    PassThrough, PassThrough, Mul_Clamp>>>&
        instances)
{
    // Materialise the compile-time instance list once, then register it.
    const auto mul_clamp_kernels =
        device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances<Mul_Clamp>{};

    add_device_operation_instances(instances, mul_clamp_kernels);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_quantization_dl_c_shuffle_i8_i8_i8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Layout(A, B, C) = [Row, Col, Row]
//
// Appends every kernel configuration listed in
// device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances<Mul_Clamp>
// to `instances` by handing the default-constructed instance list to
// add_device_operation_instances.
void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Empty_Tuple, Row,
                                                    int8_t, int8_t, Empty_Tuple, int8_t,
                                                    PassThrough, PassThrough, Mul_Clamp>>>&
        instances)
{
    // Materialise the compile-time instance list once, then register it.
    const auto mul_clamp_kernels =
        device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances<Mul_Clamp>{};

    add_device_operation_instances(instances, mul_clamp_kernels);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_quantization_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compile-time list (std::tuple) of 16 DeviceGemmMultipleD_Xdl_CShuffle configurations
// for int8 quantized GEMM with Layout(A, B, E) = [Col, Row, Row].
//
// Template parameters (forwarded unchanged to every entry):
//   OutElementOp      - CDE element-wise operation of each kernel.
//   GemmLoopScheduler - LoopScheduler value of each kernel.
//   GemmPipeline      - PipelineVersion value of each kernel.
//
// Every entry uses int8_t A/B data, int32_t for the AccData and CShuffle types, int8_t
// E data, no D tensors (Empty_Tuple), PassThrough for the A and B element-wise
// operations and the MNKPadding GEMM specialization. Entries differ only in the
// block-size / tiling / block-transfer arguments named by the column header below.
template <typename OutElementOp, LoopScheduler GemmLoopScheduler, PipelineVersion GemmPipeline>
using device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances = std::tuple<
// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 256, 128, 64, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 256, 64, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 128, 64, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 64, 128, 64, 4, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 64, 64, 4, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 64, 128, 64, 4, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>
// clang-format on
>;
// Compile-time list (std::tuple) of 16 DeviceGemmMultipleD_Xdl_CShuffle configurations
// for int8 quantized GEMM with Layout(A, B, E) = [Col, Col, Row].
//
// Template parameters (forwarded unchanged to every entry):
//   OutElementOp      - CDE element-wise operation of each kernel.
//   GemmLoopScheduler - LoopScheduler value of each kernel.
//   GemmPipeline      - PipelineVersion value of each kernel.
//
// Every entry uses int8_t A/B data, int32_t for the AccData and CShuffle types, int8_t
// E data, no D tensors (Empty_Tuple), PassThrough for the A and B element-wise
// operations and the MNKPadding GEMM specialization. Entries differ only in the
// block-size / tiling / block-transfer arguments named by the column header below.
template <typename OutElementOp, LoopScheduler GemmLoopScheduler, PipelineVersion GemmPipeline>
using device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances = std::tuple<
// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 256, 128, 64, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 256, 64, 4, 16, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 128, 64, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 64, 128, 64, 4, 16, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 64, 64, 4, 16, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 64, 128, 64, 4, 16, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>
// clang-format on
>;
// Compile-time list (std::tuple) of 16 DeviceGemmMultipleD_Xdl_CShuffle configurations
// for int8 quantized GEMM with Layout(A, B, E) = [Row, Row, Row].
//
// Template parameters (forwarded unchanged to every entry):
//   OutElementOp      - CDE element-wise operation of each kernel.
//   GemmLoopScheduler - LoopScheduler value of each kernel.
//   GemmPipeline      - PipelineVersion value of each kernel.
//
// Every entry uses int8_t A/B data, int32_t for the AccData and CShuffle types, int8_t
// E data, no D tensors (Empty_Tuple), PassThrough for the A and B element-wise
// operations and the MNKPadding GEMM specialization. Entries differ only in the
// block-size / tiling / block-transfer arguments named by the column header below.
template <typename OutElementOp, LoopScheduler GemmLoopScheduler, PipelineVersion GemmPipeline>
using device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances = std::tuple<
// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 64, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 64, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 64, 64, 16, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 64, 128, 64, 16, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>
// clang-format on
>;
// Compile-time list (std::tuple) of 13 DeviceGemmMultipleD_Xdl_CShuffle configurations
// for int8 quantized GEMM with Layout(A, B, E) = [Row, Col, Row].
//
// Template parameters (forwarded unchanged to every entry):
//   OutElementOp      - CDE element-wise operation of each kernel.
//   GemmLoopScheduler - LoopScheduler value of each kernel.
//   GemmPipeline      - PipelineVersion value of each kernel.
//
// Every entry uses int8_t A/B data, int32_t for the AccData and CShuffle types, int8_t
// E data, no D tensors (Empty_Tuple), PassThrough for the A and B element-wise
// operations and the MNKPadding GEMM specialization. Entries differ only in the
// block-size / tiling / block-transfer arguments named by the column header below.
template <typename OutElementOp, LoopScheduler GemmLoopScheduler, PipelineVersion GemmPipeline>
using device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances = std::tuple<
// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Layout(A, B, C) = [Col, Row, Row]
// Appends the int8 XDL C-shuffle quantization GEMM instance lists for this
// layout to `instances`.
//
// The Default-scheduler / pipeline-v1 list is always appended. The Interwave
// scheduler list and the pipeline-v2 list are appended only when the matching
// CK_EXPERIMENTAL_* macro evaluates to non-zero at preprocessing time.
void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
                                                    Row,
                                                    Empty_Tuple,
                                                    Row,
                                                    int8_t,
                                                    int8_t,
                                                    Empty_Tuple,
                                                    int8_t,
                                                    PassThrough,
                                                    PassThrough,
                                                    Mul_Clamp>>>& instances)
{
    // Baseline configuration: default loop scheduler, pipeline v1.
    using DefaultSchedulerV1 =
        device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances<Mul_Clamp,
                                                                           LoopScheduler::Default,
                                                                           PipelineVersion::v1>;
    add_device_operation_instances(instances, DefaultSchedulerV1{});

#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
    {
        // Optional configuration: interwave loop scheduler, pipeline v1.
        using InterwaveSchedulerV1 =
            device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances<
                Mul_Clamp,
                LoopScheduler::Interwave,
                PipelineVersion::v1>;
        add_device_operation_instances(instances, InterwaveSchedulerV1{});
    }
#endif

#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
    {
        // Optional configuration: default loop scheduler, pipeline v2.
        using DefaultSchedulerV2 =
            device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances<
                Mul_Clamp,
                LoopScheduler::Default,
                PipelineVersion::v2>;
        add_device_operation_instances(instances, DefaultSchedulerV2{});
    }
#endif
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Layout(A, B, C) = [Col, Col, Row]
// Appends the int8 XDL C-shuffle quantization GEMM instance lists for this
// layout to `instances`.
//
// The Default-scheduler / pipeline-v1 list is always appended. The Interwave
// scheduler list and the pipeline-v2 list are appended only when the matching
// CK_EXPERIMENTAL_* macro evaluates to non-zero at preprocessing time.
void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
                                                    Col,
                                                    Empty_Tuple,
                                                    Row,
                                                    int8_t,
                                                    int8_t,
                                                    Empty_Tuple,
                                                    int8_t,
                                                    PassThrough,
                                                    PassThrough,
                                                    Mul_Clamp>>>& instances)
{
    // Baseline configuration: default loop scheduler, pipeline v1.
    using DefaultSchedulerV1 =
        device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances<Mul_Clamp,
                                                                           LoopScheduler::Default,
                                                                           PipelineVersion::v1>;
    add_device_operation_instances(instances, DefaultSchedulerV1{});

#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
    {
        // Optional configuration: interwave loop scheduler, pipeline v1.
        using InterwaveSchedulerV1 =
            device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances<
                Mul_Clamp,
                LoopScheduler::Interwave,
                PipelineVersion::v1>;
        add_device_operation_instances(instances, InterwaveSchedulerV1{});
    }
#endif

#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
    {
        // Optional configuration: default loop scheduler, pipeline v2.
        using DefaultSchedulerV2 =
            device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances<
                Mul_Clamp,
                LoopScheduler::Default,
                PipelineVersion::v2>;
        add_device_operation_instances(instances, DefaultSchedulerV2{});
    }
#endif
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Layout(A, B, C) = [Row, Row, Row]
// Appends the int8 XDL C-shuffle quantization GEMM instance lists for this
// layout to `instances`.
//
// The Default-scheduler / pipeline-v1 list is always appended. The Interwave
// scheduler list and the pipeline-v2 list are appended only when the matching
// CK_EXPERIMENTAL_* macro evaluates to non-zero at preprocessing time.
void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
                                                    Row,
                                                    Empty_Tuple,
                                                    Row,
                                                    int8_t,
                                                    int8_t,
                                                    Empty_Tuple,
                                                    int8_t,
                                                    PassThrough,
                                                    PassThrough,
                                                    Mul_Clamp>>>& instances)
{
    // Baseline configuration: default loop scheduler, pipeline v1.
    using DefaultSchedulerV1 =
        device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances<Mul_Clamp,
                                                                           LoopScheduler::Default,
                                                                           PipelineVersion::v1>;
    add_device_operation_instances(instances, DefaultSchedulerV1{});

#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
    {
        // Optional configuration: interwave loop scheduler, pipeline v1.
        using InterwaveSchedulerV1 =
            device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances<
                Mul_Clamp,
                LoopScheduler::Interwave,
                PipelineVersion::v1>;
        add_device_operation_instances(instances, InterwaveSchedulerV1{});
    }
#endif

#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
    {
        // Optional configuration: default loop scheduler, pipeline v2.
        using DefaultSchedulerV2 =
            device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances<
                Mul_Clamp,
                LoopScheduler::Default,
                PipelineVersion::v2>;
        add_device_operation_instances(instances, DefaultSchedulerV2{});
    }
#endif
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Layout(A, B, C) = [Row, Col, Row]
// Appends the int8 XDL C-shuffle quantization GEMM instance lists for this
// layout to `instances`.
//
// The Default-scheduler / pipeline-v1 list is always appended. The Interwave
// scheduler list and the pipeline-v2 list are appended only when the matching
// CK_EXPERIMENTAL_* macro evaluates to non-zero at preprocessing time.
void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
                                                    Col,
                                                    Empty_Tuple,
                                                    Row,
                                                    int8_t,
                                                    int8_t,
                                                    Empty_Tuple,
                                                    int8_t,
                                                    PassThrough,
                                                    PassThrough,
                                                    Mul_Clamp>>>& instances)
{
    // Baseline configuration: default loop scheduler, pipeline v1.
    using DefaultSchedulerV1 =
        device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances<Mul_Clamp,
                                                                           LoopScheduler::Default,
                                                                           PipelineVersion::v1>;
    add_device_operation_instances(instances, DefaultSchedulerV1{});

#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
    {
        // Optional configuration: interwave loop scheduler, pipeline v1.
        using InterwaveSchedulerV1 =
            device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances<
                Mul_Clamp,
                LoopScheduler::Interwave,
                PipelineVersion::v1>;
        add_device_operation_instances(instances, InterwaveSchedulerV1{});
    }
#endif

#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
    {
        // Optional configuration: default loop scheduler, pipeline v2.
        using DefaultSchedulerV2 =
            device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances<
                Mul_Clamp,
                LoopScheduler::Default,
                PipelineVersion::v2>;
        add_device_operation_instances(instances, DefaultSchedulerV2{});
    }
#endif
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Short aliases shared by the GEMM quantization instance lists and their
// registration functions.

// S<...> abbreviates ck::Sequence<...>, the compile-time index pack used in
// the instance template argument lists.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
// Matrix layout tags.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// Tuple aliases: the empty tuple and two-element layout tuples.
using Empty_Tuple = ck::Tuple<>;
using Row_Row_Tuple = ck::Tuple<Row, Row>;
using Col_Col_Tuple = ck::Tuple<Col, Col>;
// Element-wise operations used directly or as the activation parameter below.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu;
// Output element-wise operations: Activation_Mul_Clamp and
// Add_Activation_Mul_Clamp, each instantiated with PassThrough or Relu as the
// activation.
using Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<PassThrough>;
using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Relu>;
using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<PassThrough>;
using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Relu>;
// Shorthand for the GemmSpecialization::MNKPadding enumerator.
// NOTE(review): presumably selects padding of the M, N and K dimensions --
// confirm against gemm_specialization.hpp.
static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -72,8 +72,8 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -72,8 +72,8 @@ bool profile_gemm_splitk_impl(int do_verification,
{ {
case 0: break; case 0: break;
case 1: case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
...@@ -94,7 +94,7 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -94,7 +94,7 @@ bool profile_gemm_splitk_impl(int do_verification,
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
c_device_buf.ToDevice(c_m_n_device_result.mData.data()); c_device_buf.SetZero();
using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<ALayout, using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<ALayout,
BLayout, BLayout,
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"
...@@ -39,7 +40,8 @@ bool profile_grouped_gemm_impl(int do_verification, ...@@ -39,7 +40,8 @@ bool profile_grouped_gemm_impl(int do_verification,
const std::vector<int>& Ks, const std::vector<int>& Ks,
const std::vector<int>& StrideAs, const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs, const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs) const std::vector<int>& StrideCs,
int kbatch = 1)
{ {
bool pass = true; bool pass = true;
...@@ -96,8 +98,6 @@ bool profile_grouped_gemm_impl(int do_verification, ...@@ -96,8 +98,6 @@ bool profile_grouped_gemm_impl(int do_verification,
a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread); a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread); b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
} }
c_m_n_device_results[i].GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
} }
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough;
...@@ -132,13 +132,12 @@ bool profile_grouped_gemm_impl(int do_verification, ...@@ -132,13 +132,12 @@ bool profile_grouped_gemm_impl(int do_verification,
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
b_device_buf.emplace_back( b_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
c_device_buf.emplace_back(std::make_unique<DeviceMem>( c_device_buf.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize())); sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize()));
a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
c_device_buf[i]->ToDevice(c_m_n_device_results[i].mData.data()); c_device_buf[i]->SetZero();
gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
...@@ -197,6 +196,28 @@ bool profile_grouped_gemm_impl(int do_verification, ...@@ -197,6 +196,28 @@ bool profile_grouped_gemm_impl(int do_verification,
{ {
std::string gemm_name = gemm_ptr->GetTypeString(); std::string gemm_name = gemm_ptr->GetTypeString();
if(kbatch > 1)
{
using DeviceOpSplitK =
ck::tensor_operation::device::DeviceGroupedGemmSplitK<ALayout,
BLayout,
ck::Tuple<>,
CLayout,
ADataType,
BDataType,
ck::Tuple<>,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
if(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) != nullptr)
{
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
->SetKBatchSize(argument_ptr.get(), kbatch);
}
}
float ave_time = float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
......
...@@ -190,9 +190,9 @@ bool profile_groupnorm_impl(int do_verification, ...@@ -190,9 +190,9 @@ bool profile_groupnorm_impl(int do_verification,
if(time_kernel) if(time_kernel)
{ {
LogRange(std::cout << "length = ", length, ",") << ", "; LogRange(std::cout << "length = ", length, ",") << std::endl;
std::cout << "num_kernel = " << num_kernel << ", best perf = " << best_avg_time << " ms, " std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_gb_per_sec << " GB/s, " << best_instance_name << std::endl; << best_instance_name << std::endl;
} }
if(num_kernel == 0) if(num_kernel == 0)
......
...@@ -52,20 +52,24 @@ std::vector<int> argToIntArray(char* input) ...@@ -52,20 +52,24 @@ std::vector<int> argToIntArray(char* input)
int profile_grouped_gemm(int argc, char* argv[]) int profile_grouped_gemm(int argc, char* argv[])
{ {
if(!(argc == 14)) if(argc < 14)
{ {
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); std::cout
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); << "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); << " 1: A[m, k] * B[n, k] = C[m, n];\n"
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); << " 2: A[k, m] * B[k, n] = C[m, n];\n"
printf("arg4: verification (0: no; 1: yes)\n"); << " 3: A[k, m] * B[n, k] = C[m, n])\n"
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); << "arg4: verification (0: no; 1: yes)\n"
printf("arg6: print tensor value (0: no; 1: yes)\n"); << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
printf("arg7: time kernel (0=n0, 1=yes)\n"); << "arg6: print tensor value (0: no; 1: yes)\n"
printf("arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " << "arg7: time kernel (0=n0, 1=yes)\n"
"64,64 64,64 128,128)\n"); << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n"
<< "arg14: kbatch value (default 1)\n"
<< std::endl;
exit(1); exit(1);
} }
...@@ -83,6 +87,7 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -83,6 +87,7 @@ int profile_grouped_gemm(int argc, char* argv[])
const auto StrideAs = argToIntArray(argv[11]); const auto StrideAs = argToIntArray(argv[11]);
const auto StrideBs = argToIntArray(argv[12]); const auto StrideBs = argToIntArray(argv[12]);
const auto StrideCs = argToIntArray(argv[13]); const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
...@@ -101,7 +106,8 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -101,7 +106,8 @@ int profile_grouped_gemm(int argc, char* argv[])
Ks, Ks,
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs); StrideCs,
kbatch);
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
...@@ -120,7 +126,8 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -120,7 +126,8 @@ int profile_grouped_gemm(int argc, char* argv[])
Ks, Ks,
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs); StrideCs,
kbatch);
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
...@@ -139,7 +146,8 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -139,7 +146,8 @@ int profile_grouped_gemm(int argc, char* argv[])
Ks, Ks,
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs); StrideCs,
kbatch);
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
...@@ -158,7 +166,8 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -158,7 +166,8 @@ int profile_grouped_gemm(int argc, char* argv[])
Ks, Ks,
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs); StrideCs,
kbatch);
} }
else else
{ {
......
...@@ -8,12 +8,12 @@ MY_PROJECT_SOURCE=$1 ...@@ -8,12 +8,12 @@ MY_PROJECT_SOURCE=$1
cmake \ cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \ -D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker \
-save-temps=$PWD" \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \ -D BUILD_DEV=ON \
-D GPU_TARGETS="gfx908;gfx90a" \ -D GPU_TARGETS="gfx908;gfx90a;gfx940" \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \ -D USE_BITINT_EXTENSION_INT4=OFF \
${MY_PROJECT_SOURCE} ${MY_PROJECT_SOURCE}
#-D AMDGPU_TARGETS=gfx90a;gfx908
...@@ -11,9 +11,8 @@ cmake ...@@ -11,9 +11,8 @@ cmake
-D CMAKE_CXX_FLAGS="-O3" \ -D CMAKE_CXX_FLAGS="-O3" \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=OFF \ -D BUILD_DEV=OFF \
-D GPU_TARGETS="gfx908;gfx90a" \ -D GPU_TARGETS="gfx908;gfx90a;gfx940" \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \ -D USE_BITINT_EXTENSION_INT4=OFF \
${MY_PROJECT_SOURCE} ${MY_PROJECT_SOURCE}
#-D AMDGPU_TARGETS=gfx90a;gfx908
...@@ -43,7 +43,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ...@@ -43,7 +43,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
DataType, DataType,
DataType, DataType,
DataType>(true, // do_verification DataType>(true, // do_verification
1, // init_method integer value 1, // init_method: integer value
false, // do_log false, // do_log
false, // time_kernel false, // time_kernel
param, param,
...@@ -60,9 +60,9 @@ TYPED_TEST_SUITE(TestGroupedConvndBwdWeight, KernelTypes); ...@@ -60,9 +60,9 @@ TYPED_TEST_SUITE(TestGroupedConvndBwdWeight, KernelTypes);
TYPED_TEST(TestGroupedConvndBwdWeight, Test1D) TYPED_TEST(TestGroupedConvndBwdWeight, Test1D)
{ {
this->conv_params.clear(); this->conv_params.clear();
this->conv_params.push_back({1, 4, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
this->conv_params.push_back({1, 4, 64, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); this->conv_params.push_back({1, 2, 32, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
this->conv_params.push_back({1, 4, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
this->template Run<1>(); this->template Run<1>();
} }
...@@ -70,11 +70,11 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test2D) ...@@ -70,11 +70,11 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test2D)
{ {
this->conv_params.clear(); this->conv_params.clear();
this->conv_params.push_back( this->conv_params.push_back(
{2, 4, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); {2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back( this->conv_params.push_back(
{2, 4, 8, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); {2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back( this->conv_params.push_back(
{2, 4, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->template Run<2>(); this->template Run<2>();
} }
...@@ -82,10 +82,10 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test3D) ...@@ -82,10 +82,10 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test3D)
{ {
this->conv_params.clear(); this->conv_params.clear();
this->conv_params.push_back( this->conv_params.push_back(
{3, 4, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); {3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back( this->conv_params.push_back(
{3, 4, 8, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back( this->conv_params.push_back(
{3, 4, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->template Run<3>(); this->template Run<3>();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment