Unverified Commit 8aba2724 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Add bf16 and int8 wmma gemms for Navi3x and Navi4x. (#1671)

* add bf16 gemms for gfx11/gfx12

* reduce the input values in test_gemm

* add int8 wmma gemm instances for gfx11/gfx12

* add example gemm_wmma_int8

* fix bug in gemm_wmma_int8 test

* increase bf16 gemm test tolerance

* update the dates and clean-up commented-out instances
parent 754adc70
......@@ -83,3 +83,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
add_example_executable(example_gemm_wmma_bf16 gemm_wmma_bf16.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_bf16)
add_example_executable(example_gemm_wmma_int8 gemm_wmma_int8.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_int8)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
using ADataType = ck::bhalf_t;
using BDataType = ck::bhalf_t;
using AccDataType = float;
using CShuffleDataType = float;
using CDataType = ck::bhalf_t;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CElementOp,
GemmDefault,
1, // Prefetch stage
128, // BlockSize
64, // MPerBlock
128, // NPerBlock
64, // KPerBlock
2, // K1
16, // MPerWmma
16, // NPerWmma
2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
2,
2,
true,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
2,
2,
true,
1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store
S<1, 32, 1, 4>,
8>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
using ADataType = int8_t;
using BDataType = int8_t;
using AccDataType = int32_t;
using CShuffleDataType = int32_t;
using CDataType = int8_t;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CElementOp,
GemmDefault,
1, // Prefetch stage
128, // BlockSize
64, // MPerBlock
128, // NPerBlock
64, // KPerBlock
2, // K1
16, // MPerWmma
16, // NPerWmma
2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
2,
2,
true,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
2,
2,
true,
1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store
S<1, 32, 1, 4>,
8>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
......@@ -13,6 +13,11 @@ namespace ck {
defined(__gfx1103__) || defined(__gfx11_generic__)
#define __gfx11__
#endif
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
#define __gfx12__
#endif
/********************************WAVE32 MODE***********************************************/
// src: fp16, dst: fp32
......@@ -99,7 +104,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
#if defined(__gfx11__)
#if defined(__gfx11__) || defined(__gfx12__)
reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
......@@ -261,10 +266,6 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
// gfx12
/********************************WAVE32 MODE***********************************************/
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
#define __gfx12__
#endif
// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32_gfx12;
......
......@@ -180,6 +180,58 @@ struct DeviceOperationInstanceFactory<
}
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<ADataType, ck::bhalf_t> && is_same_v<BDataType, ck::bhalf_t> &&
is_same_v<CDataType, ck::bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_bf16_bf16_bf16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, int8_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_int8_int8_int8_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_int8_int8_int8_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_int8_int8_int8_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_int8_int8_int8_km_nk_mn_instances(op_ptrs);
}
}
#endif
#endif
#ifdef CK_USE_XDL
......
......@@ -28,6 +28,46 @@ void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_bf16_bf16_bf16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_bf16_bf16_bf16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_int8_int8_int8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_int8_int8_int8_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_int8_int8_int8_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_int8_int8_int8_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
......@@ -206,7 +206,7 @@ typename std::enable_if<
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3,
double rtol = 1e-1,
double atol = 1e-3)
{
if(out.size() != ref.size())
......
......@@ -2,9 +2,7 @@ set(GEMM_INSTANCES)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES
device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp
......@@ -21,9 +19,6 @@ list(APPEND GEMM_INSTANCES
device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp
)
list(APPEND GEMM_INSTANCES
device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp
......@@ -78,9 +73,6 @@ list(APPEND GEMM_INSTANCES
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp
)
list(APPEND GEMM_INSTANCES
device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp
device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp
device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp
......@@ -92,15 +84,11 @@ list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_default_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_default_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_default_instance.cpp
......@@ -109,14 +97,19 @@ list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_padded_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp
device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_wmma_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_wmma_f16_f16_f16_km_nk_mn_instance.cpp)
device_gemm_wmma_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_wmma_bf16_bf16_bf16_mk_kn_mn_instance.cpp
device_gemm_wmma_bf16_bf16_bf16_mk_nk_mn_instance.cpp
device_gemm_wmma_bf16_bf16_bf16_km_kn_mn_instance.cpp
device_gemm_wmma_bf16_bf16_bf16_km_nk_mn_instance.cpp
device_gemm_wmma_int8_int8_int8_mk_kn_mn_instance.cpp
device_gemm_wmma_int8_int8_int8_mk_nk_mn_instance.cpp
device_gemm_wmma_int8_int8_int8_km_kn_mn_instance.cpp
device_gemm_wmma_int8_int8_int8_km_nk_mn_instance.cpp)
add_instance_library(device_gemm_instance ${GEMM_INSTANCES})
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_wmma_bf16_bf16_bf16_km_kn_mn_instances = std::tuple<
// clang-format off
//######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer|
//######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector|
//######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| |
//######################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
/* Prefetch 2, consume enormous vgpr resource*/
// 8 Waves
DeviceGemmWmma_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
// 4 Waves
DeviceGemmWmma_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
// 2 Waves
DeviceGemmWmma_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
// 1 Wave
DeviceGemmWmma_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
/* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/
// 8 Waves
DeviceGemmWmma_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmWmma_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmWmma_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 8>,
// 4 Waves
DeviceGemmWmma_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 64, 1, 2>, 8>,
// 2 Waves
DeviceGemmWmma_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
// 1 Wave
DeviceGemmWmma_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGemmWmma_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>
// clang-format on
>;
void add_device_gemm_wmma_bf16_bf16_bf16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances, device_gemm_wmma_bf16_bf16_bf16_km_kn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_gemm_wmma_bf16_bf16_bf16_km_nk_mn_instances = std::tuple<
// clang-format off
//######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer|
//######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector|
//######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| |
//######################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
/* Prefetch 2, consume enormous vgpr resource*/
// 8 Waves
DeviceGemmWmma_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
// 4 Waves
DeviceGemmWmma_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
// 2 Waves
DeviceGemmWmma_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
// 1 Wave
DeviceGemmWmma_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
/* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/
// 8 Waves
DeviceGemmWmma_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmWmma_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmWmma_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 8>,
// 4 Waves
DeviceGemmWmma_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 2>, 8>,
// 2 Waves
DeviceGemmWmma_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
// 1 Wave
DeviceGemmWmma_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGemmWmma_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>
// clang-format on
>;
void add_device_gemm_wmma_bf16_bf16_bf16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances, device_gemm_wmma_bf16_bf16_bf16_km_nk_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_wmma_bf16_bf16_bf16_mk_kn_mn_instances = std::tuple<
// clang-format off
//######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer|
//######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector|
//######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| |
//######################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
/* Prefetch 2, consume enormous vgpr resource*/
// 8 Waves
DeviceGemmWmma_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
// 4 Waves
DeviceGemmWmma_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
// 2 Waves
DeviceGemmWmma_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
// 1 Wave
DeviceGemmWmma_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
/* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/
// 8 Waves
DeviceGemmWmma_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmWmma_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmWmma_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 8>,
// 4 Waves
DeviceGemmWmma_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 64, 1, 2>, 8>,
// 2 Waves
DeviceGemmWmma_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
// 1 Wave
DeviceGemmWmma_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGemmWmma_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>
// clang-format on
>;
void add_device_gemm_wmma_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances, device_gemm_wmma_bf16_bf16_bf16_mk_kn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemm_wmma_bf16_bf16_bf16_mk_nk_mn_instances = std::tuple<
// clang-format off
//######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer|
//######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector|
//######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| |
//######################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
/* Prefetch 2, consume enormous vgpr resource*/
// 8 Waves
DeviceGemmWmma_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
// 4 Waves
DeviceGemmWmma_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
// 2 Waves
DeviceGemmWmma_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
// 1 Wave
DeviceGemmWmma_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
/* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/
// 8 Waves
DeviceGemmWmma_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmWmma_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmWmma_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 8>,
// 4 Waves
DeviceGemmWmma_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 2>, 8>,
// 2 Waves
DeviceGemmWmma_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
// 1 Wave
DeviceGemmWmma_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGemmWmma_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>
// clang-format on
>;
void add_device_gemm_wmma_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances, device_gemm_wmma_bf16_bf16_bf16_mk_nk_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I8 = int8_t;
using I32 = int32_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_wmma_int8_int8_int8_km_kn_mn_instances = std::tuple<
// clang-format off
//######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer|
//######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector|
//######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| |
//######################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
/* Prefetch 2, consume enormous vgpr resource*/
// 8 Waves
DeviceGemmWmma_CShuffle< Col, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
// 4 Waves
DeviceGemmWmma_CShuffle< Col, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
// 2 Waves
DeviceGemmWmma_CShuffle< Col, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
// 1 Wave
DeviceGemmWmma_CShuffle< Col, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
/* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/
// 8 Waves
DeviceGemmWmma_CShuffle< Col, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmWmma_CShuffle< Col, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmWmma_CShuffle< Col, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 8>,
// 4 Waves
DeviceGemmWmma_CShuffle< Col, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 64, 1, 2>, 8>,
// 2 Waves
DeviceGemmWmma_CShuffle< Col, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
// 1 Wave
DeviceGemmWmma_CShuffle< Col, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGemmWmma_CShuffle< Col, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>
// clang-format on
>;
void add_device_gemm_wmma_int8_int8_int8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, I8, I8, I8, PassThrough, PassThrough, PassThrough>>>& instances)
{
add_device_operation_instances(instances, device_gemm_wmma_int8_int8_int8_km_kn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I8 = int8_t;
using I32 = int32_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_gemm_wmma_int8_int8_int8_km_nk_mn_instances = std::tuple<
// clang-format off
//######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer|
//######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector|
//######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| |
//######################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
/* Prefetch 2, consume enormous vgpr resource*/
// 8 Waves
DeviceGemmWmma_CShuffle< Col, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
// 4 Waves
DeviceGemmWmma_CShuffle< Col, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
// 2 Waves
DeviceGemmWmma_CShuffle< Col, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
// 1 Wave
DeviceGemmWmma_CShuffle< Col, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
/* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/
// 8 Waves
DeviceGemmWmma_CShuffle< Col, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmWmma_CShuffle< Col, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmWmma_CShuffle< Col, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 8>,
// 4 Waves
DeviceGemmWmma_CShuffle< Col, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 2>, 8>,
// 2 Waves
DeviceGemmWmma_CShuffle< Col, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Col, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
// 1 Wave
DeviceGemmWmma_CShuffle< Col, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGemmWmma_CShuffle< Col, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>
// clang-format on
>;
void add_device_gemm_wmma_int8_int8_int8_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, I8, I8, I8, PassThrough, PassThrough, PassThrough>>>& instances)
{
add_device_operation_instances(instances, device_gemm_wmma_int8_int8_int8_km_nk_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I8 = int8_t;
using I32 = int32_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_wmma_int8_int8_int8_mk_kn_mn_instances = std::tuple<
// clang-format off
//######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer|
//######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector|
//######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| |
//######################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
/* Prefetch 2, consume enormous vgpr resource*/
// 8 Waves
DeviceGemmWmma_CShuffle< Row, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
// 4 Waves
DeviceGemmWmma_CShuffle< Row, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
// 2 Waves
DeviceGemmWmma_CShuffle< Row, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
// 1 Wave
DeviceGemmWmma_CShuffle< Row, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
/* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/
// 8 Waves
DeviceGemmWmma_CShuffle< Row, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmWmma_CShuffle< Row, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmWmma_CShuffle< Row, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 8>,
// 4 Waves
DeviceGemmWmma_CShuffle< Row, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 64, 1, 2>, 8>,
// 2 Waves
DeviceGemmWmma_CShuffle< Row, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
// 1 Wave
DeviceGemmWmma_CShuffle< Row, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGemmWmma_CShuffle< Row, Row, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>
// clang-format on
>;
void add_device_gemm_wmma_int8_int8_int8_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, I8, I8, I8, PassThrough, PassThrough, PassThrough>>>& instances)
{
add_device_operation_instances(instances, device_gemm_wmma_int8_int8_int8_mk_kn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I8 = int8_t;
using I32 = int32_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemm_wmma_int8_int8_int8_mk_nk_mn_instances = std::tuple<
// clang-format off
//######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer|
//######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector|
//######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| |
//######################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
/* Prefetch 2, consume enormous vgpr resource*/
// 8 Waves
DeviceGemmWmma_CShuffle< Row, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
// 4 Waves
DeviceGemmWmma_CShuffle< Row, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
// 2 Waves
DeviceGemmWmma_CShuffle< Row, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
// 1 Wave
DeviceGemmWmma_CShuffle< Row, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
/* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/
// 8 Waves
DeviceGemmWmma_CShuffle< Row, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmWmma_CShuffle< Row, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmWmma_CShuffle< Row, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 8>,
// 4 Waves
DeviceGemmWmma_CShuffle< Row, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 2>, 8>,
// 2 Waves
DeviceGemmWmma_CShuffle< Row, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmWmma_CShuffle< Row, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
// 1 Wave
DeviceGemmWmma_CShuffle< Row, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGemmWmma_CShuffle< Row, Col, Row, I8, I8, I8, I32, I8, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>
// clang-format on
>;
void add_device_gemm_wmma_int8_int8_int8_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, I8, I8, I8, PassThrough, PassThrough, PassThrough>>>& instances)
{
add_device_operation_instances(instances, device_gemm_wmma_int8_int8_int8_mk_nk_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment