Commit 55cdf2b9 authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into ck_tile/layernorm_fusion

parents 4b59b5c9 b098b71b
...@@ -23,6 +23,130 @@ ...@@ -23,6 +23,130 @@
namespace ck { namespace ck {
namespace utils { namespace utils {
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_relative_threshold(const int numberOfAccumulations = 1)
{
using F8 = ck::f8_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> ||
is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> ||
is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
is_same_v<ComputeDataType, int>,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
double compute_error = 0;
if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
is_same_v<ComputeDataType, int>)
{
return 0;
}
else
{
compute_error = std::pow(2, -NumericUtils<ComputeDataType>::mant) * 0.5;
}
static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> ||
is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> ||
is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
is_same_v<OutDataType, int>,
"Warning: Unhandled OutDataType for setting up the relative threshold!");
double output_error = 0;
if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
is_same_v<OutDataType, int>)
{
return 0;
}
else
{
output_error = std::pow(2, -NumericUtils<OutDataType>::mant) * 0.5;
}
double midway_error = std::max(compute_error, output_error);
static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> ||
is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> ||
is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
is_same_v<AccDataType, int>,
"Warning: Unhandled AccDataType for setting up the relative threshold!");
double acc_error = 0;
if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
is_same_v<AccDataType, int>)
{
return 0;
}
else
{
acc_error = std::pow(2, -NumericUtils<AccDataType>::mant) * 0.5 * numberOfAccumulations;
}
return std::max(acc_error, midway_error);
}
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_absolute_threshold(const double max_possible_num, const int numberOfAccumulations = 1)
{
using F8 = ck::f8_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> ||
is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> ||
is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
is_same_v<ComputeDataType, int>,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
auto expo = std::log2(std::abs(max_possible_num));
double compute_error = 0;
if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
is_same_v<ComputeDataType, int>)
{
return 0;
}
else
{
compute_error = std::pow(2, expo - NumericUtils<ComputeDataType>::mant) * 0.5;
}
static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> ||
is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> ||
is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
is_same_v<OutDataType, int>,
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
double output_error = 0;
if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
is_same_v<OutDataType, int>)
{
return 0;
}
else
{
output_error = std::pow(2, expo - NumericUtils<OutDataType>::mant) * 0.5;
}
double midway_error = std::max(compute_error, output_error);
static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> ||
is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> ||
is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
is_same_v<AccDataType, int>,
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
double acc_error = 0;
if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
is_same_v<AccDataType, int>)
{
return 0;
}
else
{
acc_error =
std::pow(2, expo - NumericUtils<AccDataType>::mant) * 0.5 * numberOfAccumulations;
}
return std::max(acc_error, midway_error);
}
template <typename Range, typename RefRange> template <typename Range, typename RefRange>
typename std::enable_if< typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> && std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
...@@ -253,11 +377,13 @@ check_err(const Range& out, ...@@ -253,11 +377,13 @@ check_err(const Range& out,
int err_count = 0; int err_count = 0;
double err = 0; double err = 0;
double max_err = std::numeric_limits<float>::min(); double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 0; i < ref.size(); ++i) for(std::size_t i = 0; i < ref.size(); ++i)
{ {
const double o = type_convert<float>(*std::next(std::begin(out), i)); const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i)); const double r = type_convert<float>(*std::next(std::begin(ref), i));
err = std::abs(o - r); err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
...@@ -270,6 +396,7 @@ check_err(const Range& out, ...@@ -270,6 +396,7 @@ check_err(const Range& out,
res = false; res = false;
} }
} }
if(!res) if(!res)
{ {
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
......
...@@ -8,9 +8,19 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES ...@@ -8,9 +8,19 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
) )
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
add_instance_library(device_gemm_multiply_multiply_instance ${GEMM_MULTIPLY_MULTIPLY_INSTANCES}) add_instance_library(device_gemm_multiply_multiply_instance ${GEMM_MULTIPLY_MULTIPLY_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I8 = int8_t;
using I32 = int;
using BF16 = bhalf_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
using MultiplyMultiply = element_wise::MultiplyMultiply;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Compute friendly
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>
// clang-format oI
>;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, I8>,
// Memory friendly
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 32, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 16, 128, 16, 16, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 64, 128, 16, 16, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 128, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, I8, I8, Tuple<F32, F32>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>
// clang-format oI
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
I8,
I8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_instances<GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
I8,
I8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_instances<GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
I8,
I8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances<Intrawave,
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
I8,
I8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances<Intrawave,
GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
I8,
I8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances<Interwave,
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
I8,
I8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances<Interwave,
GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -6,6 +6,7 @@ set(GROUPED_CONV2D_BWD_WEIGHT ...@@ -6,6 +6,7 @@ set(GROUPED_CONV2D_BWD_WEIGHT
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<2,
NHWGC,
GKYXC,
NHWGK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<
2,
NHWGC,
GKYXC,
NHWGK,
ConvBwdWeightFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
# ONLY XDL_KERNELS
set(GROUPED_CONV2D_FWD_DYNAMIC_OP
xdl/device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_int8_instance.cpp)
add_instance_library(device_grouped_conv2d_fwd_dynamic_op_instance ${GROUPED_CONV2D_FWD_DYNAMIC_OP})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
ck::Tuple<>,
NHWGK,
BF16,
BF16,
ck::Tuple<>,
BF16,
PassThrough,
PassThrough,
DynamicUnaryOp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_dynamic_op_bf16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_dynamic_op_bf16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_dynamic_op_bf16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
ck::Tuple<>,
NHWGK,
F16,
F16,
ck::Tuple<>,
F16,
PassThrough,
PassThrough,
DynamicUnaryOp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_dynamic_op_f16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_dynamic_op_f16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_dynamic_op_f16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
ck::Tuple<>,
NHWGK,
F32,
F32,
ck::Tuple<>,
F32,
PassThrough,
PassThrough,
DynamicUnaryOp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_dynamic_op_f32_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_dynamic_op_f32_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_dynamic_op_f32_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
ck::Tuple<>,
NHWGK,
int8_t,
int8_t,
ck::Tuple<>,
int8_t,
PassThrough,
PassThrough,
DynamicUnaryOp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_dynamic_op_int8_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_dynamic_op_int8_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_dynamic_op_int8_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -6,6 +6,7 @@ set(GROUPED_CONV3D_BWD_WEIGHT ...@@ -6,6 +6,7 @@ set(GROUPED_CONV3D_BWD_WEIGHT
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instance.cpp xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instance.cpp
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvBwdWeightFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
# ONLY XDL_KERNELS
set(GROUPED_CONV3D_FWD_DYNAMIC_OP
xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp)
add_instance_library(device_grouped_conv3d_fwd_dynamic_op_instance ${GROUPED_CONV3D_FWD_DYNAMIC_OP})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
BF16,
BF16,
ck::Tuple<>,
BF16,
PassThrough,
PassThrough,
DynamicUnaryOp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_dynamic_op_bf16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_dynamic_op_bf16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_dynamic_op_bf16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment