Commit b97c6876 authored by aska-0096's avatar aska-0096
Browse files

update ck_a8w8 library, update flush cache timing api

parent b3e5048f
......@@ -36,7 +36,7 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
// Element types used by this example: A and B operands, accumulator,
// C-shuffle staging type, and the two D tensors.
using A0DataType = I8;
using B0DataType = I8;
using AccDataType = I32;
// Declared exactly once. F16 matches the MultiplyMultiply overload that takes
// a ck::half_t `c` operand; a second alias with a different type would be an
// ill-formed redeclaration.
using CShuffleDataType = F16;
using D0DataType = F16;
using D1DataType = F16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
......@@ -78,8 +78,16 @@ struct MultiplyMultiply
__host__ __device__ constexpr void operator()<ck::half_t, int, ck::half_t, ck::half_t>(
ck::half_t& e, const int& c, const ck::half_t& d0, const ck::half_t& d1) const
{
const ck::half_t x0_f =
ck::type_convert<ck::half_t>(c) * d0* d1;
const ck::half_t x0_f = ck::type_convert<ck::half_t>(c) * d0 * d1;
e = x0_f;
}
// Specialization for a half-precision `c` operand: stores c * d0 * d1 in `e`.
template <>
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t, ck::half_t, ck::half_t>(
    ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const
{
    // Product is evaluated left to right and written straight to the output.
    e = c * d0 * d1;
}
......@@ -115,12 +123,12 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
< Row, Col, DsLayout, ELayout,
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
128, 128, 128,
64, 128, 256,
16, 16,
32, 32,
2, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 2,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, I8>;
// clang-format on
......@@ -216,6 +224,12 @@ int main(int argc, char* argv[])
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{0, 2});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{0, 2});
break;
case 2:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-25, 25});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 25});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{0, 200});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{0, 200});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
......@@ -271,7 +285,10 @@ int main(int argc, char* argv[])
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 20, 50, true, 50});
hipStream_t stream;
hip_check_error(hipStreamCreate(&stream));
float ave_time = invoker.Run(argument, StreamConfig{stream, time_kernel, 0, 20, 50, true, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
......
......@@ -4,6 +4,7 @@
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_ext.h>
#include <set>
#include <vector>
......@@ -42,8 +43,8 @@ struct RotatingMemWrapperMultiD
{
{
void* pADeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pADeviceBuf),
const_cast<void*>(p_a_grids[0]),
size_a_,
hipMemcpyDeviceToDevice));
......@@ -52,8 +53,8 @@ struct RotatingMemWrapperMultiD
{
void* pBDeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pBDeviceBuf),
const_cast<void*>(p_b_grids[0]),
size_b_,
hipMemcpyDeviceToDevice));
......@@ -65,8 +66,8 @@ struct RotatingMemWrapperMultiD
DsGridPointer ds_buffer;
static_for<0, NumDs, 1>{}([&](auto j) {
void* pDDeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pDDeviceBuf),
static_cast<const void*>(p_ds_grids[0][j]),
size_ds_[j],
hipMemcpyDeviceToDevice));
......@@ -94,9 +95,8 @@ struct RotatingMemWrapperMultiD
// Prints the A/B sizes, one size entry per D tensor, and the rotating count
// to std::cout as a single line.
void Print()
{
    std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b;
    // Emit ", size_d<j>: <size>" once for each of the NumDs D tensors.
    static_for<0, NumDs, 1>{}(
        [&](auto j) { std::cout << ", size_d" << j.value << ": " << size_ds[j]; });
    std::cout << ", rotating_count: " << rotating_count << "}" << std::endl;
}
~RotatingMemWrapperMultiD()
......@@ -111,13 +111,35 @@ struct RotatingMemWrapperMultiD
// free device mem
for(size_t i = 1; i < rotating_count; i++)
{
hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
try
{
HIP_CHECK_ERROR(hipFree(const_cast<void*>(p_a_grids[i])));
}
catch(std::runtime_error& re)
{
std::cerr << re.what() << std::endl;
}
try
{
HIP_CHECK_ERROR(hipFree(const_cast<void*>(p_b_grids[i])));
}
catch(std::runtime_error& re)
{
std::cerr << re.what() << std::endl;
}
static_for<0, NumDs, 1>{}([&](auto j) {
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
hip_check_error(
hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
try
{
HIP_CHECK_ERROR(
hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
}
catch(std::runtime_error& re)
{
std::cerr << re.what() << std::endl;
}
});
}
}
......@@ -154,8 +176,8 @@ struct RotatingMemWrapper
{
{
void* pADeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pADeviceBuf),
const_cast<void*>(p_a_grids[0]),
size_a_,
hipMemcpyDeviceToDevice));
......@@ -164,8 +186,8 @@ struct RotatingMemWrapper
{
void* pBDeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pBDeviceBuf),
const_cast<void*>(p_b_grids[0]),
size_b_,
hipMemcpyDeviceToDevice));
......@@ -199,8 +221,23 @@ struct RotatingMemWrapper
// free device mem
for(size_t i = 1; i < rotating_count; i++)
{
hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
try
{
HIP_CHECK_ERROR(hipFree(const_cast<void*>(p_a_grids[i])));
}
catch(std::runtime_error& re)
{
std::cerr << re.what() << std::endl;
}
try
{
HIP_CHECK_ERROR(hipFree(const_cast<void*>(p_b_grids[i])));
}
catch(std::runtime_error& re)
{
std::cerr << re.what() << std::endl;
}
}
}
}
......@@ -218,11 +255,11 @@ struct RotatingMemWrapper
// Launches the ck::flush_icache kernel and checks that the launch succeeded.
// The grid size is 60 blocks per multiprocessor of device 0, with 64 threads
// per block; the kernel is launched with a null stream argument.
// HIP_CHECK_ERROR throws std::runtime_error on any HIP failure.
inline void flush_icache()
{
    hipDeviceProp_t deviceProps;
    // Query the device properties once; only multiProcessorCount is used.
    HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
    int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
    ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
    // Single launch-status check right after the kernel launch.
    HIP_CHECK_ERROR(hipGetLastError());
}
// if TimePreprocess == false, the returned time does not include the preprocess time
template <bool TimePreprocess,
......@@ -260,7 +297,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
for(int i = 0; i < stream_config.cold_niters_; ++i)
{
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
hip_check_error(hipGetLastError());
HIP_CHECK_ERROR(hipGetLastError());
}
const int nrepeat = stream_config.nrepeat_;
......@@ -280,45 +317,36 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
#endif
hipEvent_t start, stop;
hip_check_error(hipEventCreate(&start));
hip_check_error(hipEventCreate(&stop));
hip_check_error(hipDeviceSynchronize());
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
HIP_CHECK_ERROR(hipEventCreate(&start));
HIP_CHECK_ERROR(hipEventCreate(&stop));
for(int i = 0; i < nrepeat; ++i)
{
if constexpr(!TimePreprocess)
{
preprocess();
}
// hipEvent_t start, stop;
preprocess();
// hip_check_error(hipEventCreate(&start));
// hip_check_error(hipEventCreate(&stop));
// hip_check_error(hipDeviceSynchronize());
// hip_check_error(hipEventRecord(start, stream_config.stream_id_));
// calculate preprocess time
if constexpr(TimePreprocess)
{
preprocess();
}
// run real kernel
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
hip_check_error(hipGetLastError());
hipExtLaunchKernelGGL(kernel,
grid_dim,
block_dim,
lds_byte,
stream_config.stream_id_,
start,
stop,
0,
gemm_args);
HIP_CHECK_ERROR(hipGetLastError());
// end real kernel
// hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
// hip_check_error(hipEventSynchronize(stop));
// float cur_time = 0;
// hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
// #if MEDIAN
// times.insert(cur_time);
// #else
// total_time += cur_time;
// #endif
HIP_CHECK_ERROR(hipEventRecord(stop, stream_config.stream_id_));
HIP_CHECK_ERROR(hipEventSynchronize(stop));
float cur_time = 0;
HIP_CHECK_ERROR(hipEventElapsedTime(&cur_time, start, stop));
#if MEDIAN
times.insert(cur_time);
#else
total_time += cur_time;
#endif
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
......@@ -329,15 +357,6 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
static_cast<const void*>(gemm_args.p_b_grid));
}
}
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error(hipEventSynchronize(stop));
float cur_time = 0;
hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
#if MEDIAN
times.insert(cur_time);
#else
total_time += cur_time;
#endif
#if MEDIAN
auto mid = times.begin();
......@@ -353,24 +372,20 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
return (*mid + *mid_next) / 2;
}
#else
// return total_time / nrepeat;
hipDeviceProp_t deviceProps;
hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
float preprocess_offset = deviceProps.multiProcessorCount==80? 0.005 : 0.01;
return (total_time - preprocess_offset * nrepeat) / nrepeat;
return total_time / nrepeat;
#endif
}
else
{
preprocess();
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
hip_check_error(hipGetLastError());
HIP_CHECK_ERROR(hipGetLastError());
return 0;
}
#else
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
hip_check_error(hipGetLastError());
HIP_CHECK_ERROR(hipGetLastError());
return 0;
#endif
......
......@@ -19,16 +19,15 @@ inline void hip_check_error(hipError_t x)
}
}
// Evaluates a HIP call (or a hipError_t value) exactly once and throws
// std::runtime_error if the result is not hipSuccess. The message contains the
// file and line of the call site followed by hipGetErrorString() for the error.
// The do { ... } while(0) wrapper makes the macro behave as one statement.
#define HIP_CHECK_ERROR(retval_or_funcall)                                         \
    do                                                                             \
    {                                                                              \
        hipError_t _tmpVal = (retval_or_funcall);                                  \
        if(_tmpVal != hipSuccess)                                                  \
        {                                                                          \
            std::ostringstream ostr;                                               \
            ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__ << ") " \
                 << hipGetErrorString(_tmpVal);                                    \
            throw std::runtime_error(ostr.str());                                  \
        }                                                                          \
    } while(0)
......@@ -96,81 +96,81 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_i
MultiplyMultiply>>>& instances);
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instances(
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_INT8))
void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
I8,
I8,
Tuple<F32, F32>,
BF16,
Tuple<F16, F16>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instances(
void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
I8,
I8,
Tuple<F32, F32>,
BF16,
Tuple<F16, F16>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instances(
void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
I8,
I8,
Tuple<F32, F32>,
BF16,
Tuple<F16, F16>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
I8,
I8,
Tuple<F32, F32>,
BF16,
Tuple<F16, F16>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instances(
void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
I8,
I8,
Tuple<F32, F32>,
BF16,
Tuple<F16, F16>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
I8,
I8,
Tuple<F32, F32>,
BF16,
Tuple<F16, F16>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
......@@ -180,6 +180,7 @@ void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_i
template <typename ADataType,
typename BDataType,
typename CDataType,
typename DsDataType,
typename ALayout,
typename BLayout,
typename CLayout>
......@@ -190,7 +191,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
CLayout,
ADataType,
BDataType,
Tuple<F32, F32>,
DsDataType,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
......@@ -203,7 +204,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
CLayout,
ADataType,
BDataType,
Tuple<F32, F32>,
DsDataType,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
......@@ -237,26 +238,26 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
}
}
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_INT8))
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, bhalf_t>)
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instances(
add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instances(
add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instances(
add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instances(
add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs);
}
}
......
......@@ -9,18 +9,18 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
)
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
add_instance_library(device_gemm_multiply_multiply_instance ${GEMM_MULTIPLY_MULTIPLY_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Scalar type shorthands used by the instance lists below.
using I8 = int8_t;
using I32 = int;
using BF16 = bhalf_t;
using F32 = float;
// Matrix layout shorthands.
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
// S<...> abbreviates a compile-time integer Sequence.
template <index_t... Is>
using S = Sequence<Is...>;
// Element-wise operation shorthands for the A/B and CDE stages.
using PassThrough = element_wise::PassThrough;
using MultiplyMultiply = element_wise::MultiplyMultiply;
// GEMM specialization constants.
// NOTE(review): none of these four are referenced in the visible instance
// lists, which take GemmSpec as a template parameter — confirm they are used
// by the files that instantiate the lists.
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
// Block-GEMM pipeline scheduler constants.
// NOTE(review): the visible lists spell out BlockGemmPipelineScheduler::...
// directly or take the scheduler as a template parameter — confirm usage.
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Compute friendly
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>
// clang-format on
>;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, I8>,
// Memory friendly
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 32, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 16, 128, 16, 16, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 64, 128, 16, 16, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 128, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I8 = int8_t;
using I32 = int;
using F16 = half_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
using MultiplyMultiply = element_wise::MultiplyMultiply;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
// "Compute friendly" instance list for the i8 x i8 -> f16 multiply-multiply GEMM
// (row-major A, column-major B, row-major E), parameterized on the GEMM specialization.
// Every entry uses the Intrawave scheduler; the first entry uses pipeline v4, the rest use v3.
// The legend below names the template argument held by each column of the table.
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Compute friendly
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 32, 32, 2, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 192, 128, 16, 16, 32, 32, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 160, 128, 16, 16, 32, 32, 2, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 96, 128, 16, 16, 32, 32, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 64, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 32, 32, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 224, 128, 16, 16, 32, 32, 1, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 192, 128, 16, 16, 32, 32, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 128, 16, 16, 32, 32, 1, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 256, 16, 16, 32, 32, 1, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 128, 16, 16, 32, 32, 1, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 224, 128, 16, 16, 16, 16, 2, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 192, 256, 16, 16, 32, 32, 1, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 192, 128, 16, 16, 32, 32, 1, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 160, 256, 16, 16, 16, 16, 2, 5, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 256, 16, 16, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 96, 256, 16, 16, 16, 16, 2, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 64, 512, 16, 16, 32, 32, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 224, 256, 16, 16, 16, 16, 1, 7, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 192, 256, 16, 16, 16, 16, 1, 6, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 160, 256, 16, 16, 16, 16, 1, 5, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 96, 256, 16, 16, 16, 16, 1, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 64, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 128, 8, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 192, 256, 16, 16, 16, 16, 1, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>
// clang-format on
>;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, I8>,
// Memory friendly
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 32, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 16, 128, 16, 16, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 64, 128, 16, 16, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 128, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F16, F16>, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
#include "device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instances(
void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
I8,
I8,
Tuple<F32, F32>,
BF16,
Tuple<F16, F16>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_instances<GemmDefault>{});
device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_instances<GemmDefault>{});
}
} // namespace instance
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
#include "device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instances(
void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
I8,
I8,
Tuple<F32, F32>,
BF16,
Tuple<F16, F16>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_instances<GemmKPadding>{});
device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_instances<GemmKPadding>{});
}
} // namespace instance
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
#include "device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instances(
void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
I8,
I8,
Tuple<F32, F32>,
BF16,
Tuple<F16, F16>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances<Interwave,
GemmDefault>{});
device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_instances<Intrawave,
GemmDefault>{});
}
} // namespace instance
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
#include "device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instances(
void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
I8,
I8,
Tuple<F32, F32>,
BF16,
Tuple<F16, F16>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances<Intrawave,
GemmDefault>{});
device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_instances<Intrawave,
GemmKPadding>{});
}
} // namespace instance
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
#include "device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
I8,
I8,
Tuple<F32, F32>,
BF16,
Tuple<F16, F16>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances<Intrawave,
GemmKPadding>{});
device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_instances<Interwave,
GemmDefault>{});
}
} // namespace instance
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
#include "device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
I8,
I8,
Tuple<F32, F32>,
BF16,
Tuple<F16, F16>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances<Interwave,
GemmKPadding>{});
device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_instances<Interwave,
GemmKPadding>{});
}
} // namespace instance
......
......@@ -190,7 +190,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
{
// Performance measurement seems to have a bug when splitK is large
// std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38};
std::vector<int> kbatch_list = {1, 2, 4};
std::vector<int> kbatch_list = {1, 2, 4, 8, 16};
if(KBatch > 0)
{
......@@ -251,8 +251,11 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
std::string op_name = op_ptr->GetTypeString();
hipStream_t stream;
hip_check_error(hipStreamCreate(&stream));
float ave_time = invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr,
StreamConfig{stream,
time_kernel,
0,
n_warmup,
......
......@@ -27,7 +27,7 @@ enum struct GemmDataType
F16_F8_F16, // 5
F16_F16_F16_F8, // 6
F8_F8_BF16, // 7
INT8_INT8_BF16, // 8
INT8_INT8_F16, // 8
};
#define OP_NAME "gemm_multiply_multiply"
......@@ -40,7 +40,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, "
"comp f8; 8: int8->bf16)\n");
"comp f8; 8: int8->f16)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
......@@ -89,6 +89,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
using F32 = float;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F8 = ck::f8_t;
using I8 = int8_t;
using I32 = int;
......@@ -165,10 +166,10 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
return profile(
F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, BF16{}, Row{}, Col{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::INT8_INT8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
else if(data_type == GemmDataType::INT8_INT8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(
I8{}, I8{}, I8{}, I32{}, F32{}, F32{}, BF16{}, Row{}, Col{}, Row{}, Col{}, Row{});
I8{}, I8{}, I8{}, I32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Col{}, Row{});
}
else
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment