Unverified Commit b164b0ef authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Merge branch 'develop' into grouped_gemm_multi_abd_fixed_nk_example

parents e786932a 3696fe1c
......@@ -16,9 +16,11 @@ template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename SaveMeanInvStdDataType,
typename ComputeDataType,
typename YElementwiseOperation,
typename GridDesc_M_K,
typename GridDesc_M,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
......@@ -32,6 +34,7 @@ template <typename XDataType,
index_t BetaSrcVectorSize,
index_t YDstVectorDim,
index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize,
bool SweepOnce>
struct GridwiseNormalizationWelfordVariance_mk_to_mk
{
......@@ -43,6 +46,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
(YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!");
static_assert(XSrcVectorSize == YDstVectorSize);
static_assert(XSrcVectorSize == GammaSrcVectorSize);
static_assert(XSrcVectorSize == BetaSrcVectorSize);
......@@ -64,6 +71,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
using ThreadReduceDstDesc_M =
......@@ -77,6 +88,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -114,17 +127,18 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
const GridDesc_M_K& gamma_grid_desc_m_k,
const GridDesc_M_K& beta_grid_desc_m_k,
const GridDesc_M_K& y_grid_desc_m_k,
const GridDesc_M& save_mean_grid_desc_m,
const GridDesc_M& save_inv_std_grid_desc_m,
index_t num_k_block_tile_iteration,
ComputeDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
const YElementwiseOperation y_elementwise_op)
{
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto x_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
......@@ -150,6 +164,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
var_thread_buf;
auto& inv_std_thread_buf = var_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
......@@ -226,6 +241,42 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
thread_k_cluster_id * YDstVectorSize),
y_elementwise_op);
auto threadwise_mean_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_mean_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_inv_std_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_inv_std_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
constexpr auto thread_copy_bwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
......@@ -239,6 +290,15 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);
......@@ -279,10 +339,33 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
inv_std_thread_buf(I) = type_convert<ComputeDataType>(1.0f) /
ck::math::sqrt(var_thread_buf(I) + epsilon);
});
// save mean and inverse std for backward (optional)
if(thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
// normalization
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
......@@ -291,7 +374,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
// normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor;
inv_std_thread_buf(iM);
// gamma & beta
y_thread_buf(iK0)(Number<offset_m_k>{}) =
......@@ -360,8 +443,29 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon);
});
if(thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
auto thread_copy_tail_m_k =
(num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
......@@ -393,7 +497,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
......@@ -402,7 +505,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
// normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor;
inv_std_thread_buf(iM);
// gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) =
......
......@@ -462,7 +462,6 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
}
};
#if defined CK_ENABLE_FP8
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
{
......@@ -506,9 +505,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
#endif
#if defined CK_ENABLE_BF8
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8bf8>
{
......@@ -552,9 +549,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8bf8>
intrin_mfma_f32_16x16x32bf8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8bf8>
{
......@@ -598,9 +593,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8>
intrin_mfma_f32_16x16x32f8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8f8>
{
......@@ -644,7 +637,6 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
intrin_mfma_f32_16x16x32bf8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
#endif
template <typename base_type,
index_t MPerXdlops,
......@@ -792,7 +784,6 @@ struct MfmaSelector
}
#endif
#if defined CK_ENABLE_FP8
template <>
static constexpr auto GetMfma<f8_t, 32, 32>()
{
......@@ -804,9 +795,7 @@ struct MfmaSelector
{
return MfmaInstr::mfma_f32_16x16x32f8f8;
}
#endif
#if defined CK_ENABLE_BF8
template <>
static constexpr auto GetMfma<bf8_t, 32, 32>()
{
......@@ -818,9 +807,7 @@ struct MfmaSelector
{
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
}
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <>
static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
{
......@@ -832,9 +819,7 @@ struct MfmaSelector
{
return MfmaInstr::mfma_f32_16x16x32f8bf8;
}
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <>
static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
{
......@@ -846,7 +831,6 @@ struct MfmaSelector
{
return MfmaInstr::mfma_f32_16x16x32bf8f8;
}
#endif
static constexpr auto selected_mfma =
mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
......@@ -1051,18 +1035,10 @@ struct XdlopsGemm
static_assert(
is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
is_same<base_type, int8_t>::value
#if defined CK_ENABLE_FP8
|| is_same<base_type, f8_t>::value
#endif
#if defined CK_ENABLE_BF8
|| is_same<base_type, bf8_t>::value
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|| (is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value) ||
(is_same<base_type, bf8_t>::value && is_same<additional_type, f8_t>::value)
#endif
,
is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value ||
is_same<base_type, bf8_t>::value ||
(is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value) ||
(is_same<base_type, bf8_t>::value && is_same<additional_type, f8_t>::value),
"base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP
#include "data_type.hpp"
#pragma once
namespace ck {
......@@ -355,7 +352,6 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
}
};
#if defined CK_ENABLE_FP8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8;
......@@ -418,9 +414,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif
}
};
#endif
#if defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8bf8;
......@@ -483,9 +477,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
#endif
}
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8bf8;
......@@ -548,9 +540,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
#endif
}
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8f8;
......@@ -613,6 +603,5 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
#endif
}
};
#endif
} // namespace ck
#endif
......@@ -9,15 +9,9 @@ namespace ck {
using bhalf_t = ushort;
using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4);
#endif
#if defined CK_ENABLE_FP8
using f8_t = _BitInt(8);
#endif
#if defined CK_ENABLE_BF8
using bf8_t = unsigned _BitInt(8);
#endif
using int4_t = _BitInt(4);
using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8);
// vector_type
template <typename T, index_t N>
......@@ -148,23 +142,19 @@ struct scalar_type<int4_t>
};
#endif
#if defined CK_ENABLE_FP8
template <>
struct scalar_type<f8_t>
{
using type = f8_t;
static constexpr index_t vector_size = 1;
};
#endif
#if defined CK_ENABLE_BF8
template <>
struct scalar_type<bf8_t>
{
using type = bf8_t;
static constexpr index_t vector_size = 1;
};
#endif
template <typename T>
struct vector_type<T, 1>
......@@ -968,24 +958,20 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
// f8
#if defined CK_ENABLE_FP8
using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type;
#endif
// bf8
#if defined CK_ENABLE_BF8
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
#endif
template <typename T>
struct NumericLimits
......@@ -1033,7 +1019,6 @@ struct NumericLimits<int4_t>
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8
template <>
struct NumericLimits<f8_t>
{
......@@ -1056,9 +1041,7 @@ struct NumericLimits<f8_t>
__host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
};
#endif
#if defined CK_ENABLE_BF8
template <>
struct NumericLimits<bf8_t>
{
......@@ -1081,7 +1064,6 @@ struct NumericLimits<bf8_t>
__host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
};
#endif
template <typename T>
struct NumericUtils
......@@ -1120,22 +1102,18 @@ struct NumericUtils<half_t>
using bitwise_type = uint16_t;
};
#if defined CK_ENABLE_FP8
template <>
struct NumericUtils<f8_t>
{
static constexpr int exp = 4;
static constexpr int mant = 3;
};
#endif
#if defined CK_ENABLE_BF8
template <>
struct NumericUtils<bf8_t>
{
static constexpr int exp = 5;
static constexpr int mant = 2;
};
#endif
//
} // namespace ck
......@@ -6,8 +6,6 @@
#include "ck/utility/data_type.hpp"
// these conversions are disabled if native conversions available
#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace ck {
// fp8 rounding modes
......@@ -244,5 +242,3 @@ __host__ __device__ Y cast_from_f8(X x)
}
} // namespace ck::utils
#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
......@@ -95,7 +95,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32);
}
#if defined CK_ENABLE_FP8
// convert fp32 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
......@@ -146,7 +145,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return type_convert<f8_t>(type_convert<float>(x));
#else
#elif 0
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
......@@ -154,6 +153,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#else
return type_convert<f8_t>(type_convert<float>(x));
#endif
}
......@@ -164,14 +165,14 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
#elif 0
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
#else
return type_convert<half_t>(type_convert<float>(x));
#endif
}
#endif
#if defined CK_ENABLE_BF8
// convert fp32 to bf8
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
......@@ -222,7 +223,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return type_convert<bf8_t>(type_convert<float>(x));
#else
#elif 0
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
......@@ -230,6 +231,8 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#else
return type_convert<bf8_t>(type_convert<float>(x));
#endif
}
......@@ -240,12 +243,13 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
#elif 0
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
#else
return type_convert<half_t>(type_convert<float>(x));
#endif
}
#endif
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
......@@ -308,7 +312,6 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);
#if defined CK_ENABLE_FP8
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
......@@ -354,12 +357,10 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#else
return type_convert<f8_t>(type_convert<float>(x));
return f8_convert_sr<f8_t>(type_convert<float>(x));
#endif
}
#endif
#if defined CK_ENABLE_BF8
// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
......@@ -406,9 +407,8 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#else
return type_convert<bf8_t>(type_convert<float>(x));
return f8_convert_sr<bf8_t>(type_convert<float>(x));
#endif
}
#endif
} // namespace ck
......@@ -20,8 +20,9 @@ template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename AccDataType,
typename AccElementwiseOperation>
typename SaveMeanInvStdDataType,
typename ComputeDataType,
typename YElementwiseOperation>
struct ReferenceGroupnorm : public device::BaseOperator
{
// x = [N, H, W, G, C]
......@@ -35,14 +36,18 @@ struct ReferenceGroupnorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma,
const Tensor<BetaDataType>& beta,
Tensor<YDataType>& y,
AccElementwiseOperation acc_elementwise_op,
Tensor<SaveMeanInvStdDataType>& save_mean,
Tensor<SaveMeanInvStdDataType>& save_inv_std,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths,
AccDataType epsilon)
ComputeDataType epsilon)
: x_(x),
gamma_(gamma),
beta_(beta),
y_(y),
acc_elementwise_op_(acc_elementwise_op),
save_mean_(save_mean),
save_inv_std_(save_inv_std),
y_elementwise_op_(y_elementwise_op),
lengths_(lengths),
epsilon_(epsilon)
{
......@@ -52,9 +57,11 @@ struct ReferenceGroupnorm : public device::BaseOperator
const Tensor<XDataType> gamma_;
const Tensor<XDataType> beta_;
Tensor<YDataType>& y_;
AccElementwiseOperation acc_elementwise_op_;
Tensor<SaveMeanInvStdDataType>& save_mean_;
Tensor<SaveMeanInvStdDataType>& save_inv_std_;
YElementwiseOperation y_elementwise_op_;
std::vector<index_t> lengths_;
AccDataType epsilon_;
ComputeDataType epsilon_;
};
// Invoker
......@@ -68,8 +75,8 @@ struct ReferenceGroupnorm : public device::BaseOperator
int G = arg.lengths_[3];
int C = arg.lengths_[4];
Tensor<AccDataType> mean({N, G});
Tensor<AccDataType> var({N, G});
Tensor<ComputeDataType> mean({N, G});
Tensor<ComputeDataType> var({N, G});
// Compute mean & var in [H, W, C] by Welford Algorithm
// TODO - parallel for each HWC
......@@ -78,9 +85,9 @@ struct ReferenceGroupnorm : public device::BaseOperator
{
for(int g = 0; g < G; ++g)
{
AccDataType mean_val = type_convert<AccDataType>(0.0f);
AccDataType var_val = type_convert<AccDataType>(0.0f);
int32_t curr_count = 0;
ComputeDataType mean_val = type_convert<ComputeDataType>(0.0f);
ComputeDataType var_val = type_convert<ComputeDataType>(0.0f);
int32_t curr_count = 0;
for(int h = 0; h < H; ++h)
{
......@@ -89,10 +96,11 @@ struct ReferenceGroupnorm : public device::BaseOperator
for(int c = 0; c < C; ++c)
{
curr_count++;
AccDataType x = type_convert<AccDataType>(arg.x_(n, h, w, g, c));
AccDataType delta = x - mean_val;
ComputeDataType x =
type_convert<ComputeDataType>(arg.x_(n, h, w, g, c));
ComputeDataType delta = x - mean_val;
mean_val += delta / curr_count;
AccDataType delta2 = x - mean_val;
ComputeDataType delta2 = x - mean_val;
var_val += delta * delta2;
}
}
......@@ -100,6 +108,12 @@ struct ReferenceGroupnorm : public device::BaseOperator
mean(n, g) = mean_val;
var(n, g) = var_val / curr_count;
arg.save_mean_(n, g) = ck::type_convert<SaveMeanInvStdDataType>(mean(n, g));
ComputeDataType divisor =
static_cast<ComputeDataType>(1) / ck::math::sqrt(var(n, g) + arg.epsilon_);
arg.save_inv_std_(n, g) = ck::type_convert<SaveMeanInvStdDataType>(divisor);
}
}
......@@ -114,15 +128,19 @@ struct ReferenceGroupnorm : public device::BaseOperator
{
for(int c = 0; c < C; ++c)
{
AccDataType x = type_convert<AccDataType>(arg.x_(n, h, w, g, c));
AccDataType gamma = type_convert<AccDataType>(arg.gamma_(g, c));
AccDataType beta = type_convert<AccDataType>(arg.beta_(g, c));
AccDataType mean_val = type_convert<AccDataType>(mean(n, g));
AccDataType var_val = type_convert<AccDataType>(var(n, g));
AccDataType y = gamma * (x - mean_val) /
ck::math::sqrt(arg.epsilon_ + var_val) +
beta;
arg.acc_elementwise_op_(y, y);
ComputeDataType x =
type_convert<ComputeDataType>(arg.x_(n, h, w, g, c));
ComputeDataType gamma =
type_convert<ComputeDataType>(arg.gamma_(g, c));
ComputeDataType beta =
type_convert<ComputeDataType>(arg.beta_(g, c));
ComputeDataType mean_val =
type_convert<ComputeDataType>(mean(n, g));
ComputeDataType var_val = type_convert<ComputeDataType>(var(n, g));
ComputeDataType y = gamma * (x - mean_val) /
ck::math::sqrt(arg.epsilon_ + var_val) +
beta;
arg.y_elementwise_op_(y, y);
arg.y_(n, h, w, g, c) = type_convert<YDataType>(y);
}
}
......@@ -159,11 +177,14 @@ struct ReferenceGroupnorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma,
const Tensor<BetaDataType>& beta,
Tensor<YDataType>& y,
AccElementwiseOperation acc_elementwise_op,
Tensor<SaveMeanInvStdDataType>& save_mean,
Tensor<SaveMeanInvStdDataType>& save_inv_std,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths,
AccDataType epsilon)
ComputeDataType epsilon)
{
return Argument{x, gamma, beta, y, acc_elementwise_op, lengths, epsilon};
return Argument{
x, gamma, beta, y, save_mean, save_inv_std, y_elementwise_op, lengths, epsilon};
}
static auto MakeInvoker() { return Invoker{}; }
......
......@@ -20,8 +20,9 @@ template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename AccDataType,
typename AccElementwiseOperation,
typename SaveMeanInvStdDataType,
typename ComputeDataType,
typename YElementwiseOperation,
index_t Rank,
index_t NumReduceDim>
struct ReferenceLayernorm : public device::BaseOperator
......@@ -36,15 +37,19 @@ struct ReferenceLayernorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma_n,
const Tensor<BetaDataType>& beta_n,
Tensor<YDataType>& y_m_n,
AccElementwiseOperation acc_elementwise_op,
Tensor<SaveMeanInvStdDataType>& save_mean_m,
Tensor<SaveMeanInvStdDataType>& save_inv_std_m,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths,
const std::vector<index_t> reduceDims,
AccDataType epsilon)
ComputeDataType epsilon)
: x_m_n_(x_m_n),
gamma_n_(gamma_n),
beta_n_(beta_n),
y_m_n_(y_m_n),
acc_elementwise_op_(acc_elementwise_op),
save_mean_m_(save_mean_m),
save_inv_std_m_(save_inv_std_m),
y_elementwise_op_(y_elementwise_op),
lengths_(lengths),
reduceDims_(reduceDims),
epsilon_(epsilon)
......@@ -55,10 +60,12 @@ struct ReferenceLayernorm : public device::BaseOperator
const Tensor<XDataType> gamma_n_;
const Tensor<XDataType> beta_n_;
Tensor<YDataType>& y_m_n_;
AccElementwiseOperation acc_elementwise_op_;
Tensor<SaveMeanInvStdDataType>& save_mean_m_;
Tensor<SaveMeanInvStdDataType>& save_inv_std_m_;
YElementwiseOperation y_elementwise_op_;
std::vector<index_t> lengths_;
std::vector<index_t> reduceDims_;
AccDataType epsilon_;
ComputeDataType epsilon_;
};
// Invoker
......@@ -69,8 +76,8 @@ struct ReferenceLayernorm : public device::BaseOperator
int M = arg.lengths_[0];
int N = arg.lengths_[1];
Tensor<AccDataType> mean({M});
Tensor<AccDataType> var({M});
Tensor<ComputeDataType> mean({M});
Tensor<ComputeDataType> var({M});
for(int m = 0; m < M; ++m)
{
......@@ -79,7 +86,7 @@ struct ReferenceLayernorm : public device::BaseOperator
for(int n = 0; n < N; ++n)
{
auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
auto x_val = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
mean(m) += x_val;
var(m) += x_val * x_val;
}
......@@ -90,17 +97,21 @@ struct ReferenceLayernorm : public device::BaseOperator
for(int m = 0; m < M; ++m)
{
AccDataType divisor =
static_cast<AccDataType>(1) / ck::math::sqrt(var(m) + arg.epsilon_);
ComputeDataType divisor =
static_cast<ComputeDataType>(1) / ck::math::sqrt(var(m) + arg.epsilon_);
for(int n = 0; n < N; ++n)
{
auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
auto y_val = (x_val - mean(m)) * divisor;
y_val = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n);
arg.acc_elementwise_op_(y_val, y_val);
auto x_val = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
auto gamma_val = ck::type_convert<ComputeDataType>(arg.gamma_n_(n));
auto beta_val = ck::type_convert<ComputeDataType>(arg.beta_n_(n));
auto y_val = (x_val - mean(m)) * divisor;
y_val = (y_val * gamma_val) + beta_val;
arg.y_elementwise_op_(y_val, y_val);
arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
}
arg.save_mean_m_(m) = ck::type_convert<SaveMeanInvStdDataType>(mean(m));
arg.save_inv_std_m_(m) = ck::type_convert<SaveMeanInvStdDataType>(divisor);
}
return 0;
......@@ -140,13 +151,23 @@ struct ReferenceLayernorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma_n,
const Tensor<BetaDataType>& beta_n,
Tensor<YDataType>& y_m_n,
AccElementwiseOperation acc_elementwise_op,
Tensor<SaveMeanInvStdDataType>& save_mean_m,
Tensor<SaveMeanInvStdDataType>& save_inv_std_m,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths,
const std::vector<index_t> reduceDims,
AccDataType epsilon)
ComputeDataType epsilon)
{
return Argument{
x_m_n, gamma_n, beta_n, y_m_n, acc_elementwise_op, lengths, reduceDims, epsilon};
return Argument{x_m_n,
gamma_n,
beta_n,
y_m_n,
save_mean_m,
save_inv_std_m,
y_elementwise_op,
lengths,
reduceDims,
epsilon};
}
static auto MakeInvoker() { return Invoker{}; }
......
......@@ -20,12 +20,8 @@ using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using I32 = int32_t;
#if defined CK_ENABLE_FP8
using F8 = ck::f8_t;
#endif
#if defined CK_ENABLE_BF8
using BF8 = ck::bf8_t;
#endif
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
using Empty_Tuple = ck::Tuple<>;
......
......@@ -240,11 +240,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> && is_same_v<WeiLayout, KXC> &&
is_same_v<OutLayout, NWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
......@@ -267,17 +269,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
}
#endif
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
}
#endif
#if defined(DL_KERNELS) && defined(CK_ENABLE_FP32)
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
......@@ -306,14 +314,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
}
#endif
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
......
......@@ -98,30 +98,31 @@ struct DeviceOperationInstanceFactory<
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
}
......
......@@ -98,6 +98,26 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
instances);
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
instances);
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
instances);
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
instances);
#endif
template <typename ADataType,
......@@ -105,7 +125,8 @@ template <typename ADataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
typename CLayout,
typename ComputeType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmSplitK<ALayout,
BLayout,
......@@ -115,7 +136,8 @@ struct DeviceOperationInstanceFactory<
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
ck::tensor_operation::element_wise::PassThrough,
ComputeType>>
{
using DeviceOp = DeviceGemmSplitK<ALayout,
BLayout,
......@@ -125,14 +147,15 @@ struct DeviceOperationInstanceFactory<
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
ck::tensor_operation::element_wise::PassThrough,
ComputeType>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<CDataType, float>)
is_same_v<CDataType, float> && is_same_v<ComputeType, float>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
......@@ -157,8 +180,8 @@ struct DeviceOperationInstanceFactory<
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t> && is_same_v<ComputeType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
......@@ -183,8 +206,8 @@ struct DeviceOperationInstanceFactory<
}
#endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t> && is_same_v<ComputeType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
......@@ -207,8 +230,8 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, half_t>)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, half_t> && is_same_v<ComputeType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
......@@ -231,6 +254,31 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t> && is_same_v<ComputeType, f8_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_nk_mn_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
......
......@@ -6,8 +6,6 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvBwdWeightDefault =
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
template <index_t NDSpatial,
typename ALayout,
typename BLayout,
typename CLayout,
ConvolutionBackwardWeightSpecialization ConvSpec>
using device_grouped_conv_bwd_weight_wmma_f16_instances =
std::tuple<
// clang-format off
//#####################################| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector|
//#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
// blocksize=256
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 8, 8, 16, 16, 8, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 2>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 256, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 16>, 4>,
// blocksize=128
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 256, 8, 8, 16, 16, 1, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 256, 32, 8, 8, 16, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// blocksize=64
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 8, 8, 16, 16, 4, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 8, 8, 16, 16, 1, 4, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 8, 8, 16, 16, 2, 4, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 8, 8, 16, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 8, 8, 16, 16, 1, 8, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
// blocksize=32
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 32, 8, 8, 16, 16, 1, 2, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 64, 8, 8, 16, 16, 1, 4, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 32, 64, 8, 8, 16, 16, 2, 4, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 32, 32, 8, 8, 16, 16, 2, 2, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 64, 32, 8, 8, 16, 16, 4, 2, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 64, 16, 8, 8, 16, 16, 4, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 32, 16, 8, 8, 16, 16, 2, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>
// clang-format on
>;
template <index_t NDSpatial,
typename ALayout,
typename BLayout,
typename CLayout,
ConvolutionBackwardWeightSpecialization ConvSpec>
using device_grouped_conv_bwd_weight_wmma_i8_instances =
std::tuple<
// clang-format off
//#####################################| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector|
//#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
// blocksize=256
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 256, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>,
// blocksize=128
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 256, 8, 8, 16, 16, 4, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 256, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 256, 8, 8, 16, 16, 1, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 256, 32, 8, 8, 16, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 256, 64, 8, 8, 16, 16, 8, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 2>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 256, 128, 8, 8, 16, 16, 8, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// blocksize=64
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 8, 8, 16, 16, 1, 8, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 128, 8, 8, 16, 16, 2, 8, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 64, 8, 8, 16, 16, 8, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 8, 8, 16, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
// blocksize=32
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 64, 8, 8, 16, 16, 1, 4, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 64, 64, 8, 8, 16, 16, 4, 4, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 32, 32, 8, 8, 16, 16, 2, 2, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 64, 16, 8, 8, 16, 16, 4, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -163,6 +163,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
......@@ -177,6 +201,31 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
......@@ -202,6 +251,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
......@@ -231,6 +304,31 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_
BF8,
F8>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef DL_KERNELS
// dl
......@@ -529,8 +627,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
......@@ -539,9 +637,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
......@@ -552,8 +649,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
}
else if constexpr(is_same_v<InLayout, NWGC> && is_same_v<WeiLayout, GKXC> &&
is_same_v<OutLayout, NWGK>)
if constexpr(is_same_v<InLayout, NWGC> && is_same_v<WeiLayout, GKXC> &&
is_same_v<OutLayout, NWGK>)
{
#ifdef DL_KERNELS
#ifdef CK_ENABLE_FP32
......@@ -564,16 +661,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances(
op_ptrs);
......@@ -582,7 +678,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
}
}
else if constexpr(NumDimSpatial == 2)
if constexpr(NumDimSpatial == 2)
{
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, GNHWK>)
......@@ -600,8 +696,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances(
......@@ -612,9 +708,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
......@@ -625,8 +720,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
}
else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
......@@ -641,8 +736,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances(
......@@ -653,9 +748,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
......@@ -667,7 +761,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
}
}
else if constexpr(NumDimSpatial == 3)
if constexpr(NumDimSpatial == 3)
{
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, GNDHWK>)
......@@ -685,8 +779,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances(
......@@ -694,12 +788,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
......@@ -708,10 +805,20 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
}
else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
......@@ -726,10 +833,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> &&
is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
......@@ -737,12 +843,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
......@@ -752,10 +861,20 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> &&
is_same_v<ComputeTypeA, bf8_t> && is_same_v<ComputeTypeB, f8_t>)
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, bf8_t> &&
is_same_v<ComputeTypeB, f8_t>)
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
op_ptrs);
......
......@@ -19,13 +19,13 @@ namespace instance {
#ifdef CK_ENABLE_FP16
// FP16
void add_device_normalization_rank_2_1_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 2, 1>>>&);
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, PassThrough, 2, 1>>>&);
void add_device_normalization_rank_4_3_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 4, 3>>>&);
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, PassThrough, 4, 3>>>&);
void add_device_normalization_rank_5_3_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 5, 3>>>&);
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, PassThrough, 5, 3>>>&);
#endif
#ifdef CK_ENABLE_FP32
// FP32
......@@ -42,14 +42,15 @@ template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename SaveMeanInvStdDataType,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormalization<
XDataType,
GammaDataType,
BetaDataType,
F32,
YDataType,
SaveMeanInvStdDataType,
ck::tensor_operation::element_wise::PassThrough,
Rank,
NumReduceDim>>
......@@ -57,8 +58,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
using DeviceOp = DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
F32,
YDataType,
SaveMeanInvStdDataType,
ck::tensor_operation::element_wise::PassThrough,
Rank,
NumReduceDim>;
......@@ -68,7 +69,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16>)
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<SaveMeanInvStdDataType, F32>)
{
if constexpr(Rank == 2 && NumReduceDim == 1)
{
......@@ -86,7 +88,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<SaveMeanInvStdDataType, F32>)
{
if constexpr(Rank == 2 && NumReduceDim == 1)
{
......
......@@ -19,7 +19,7 @@ namespace instance {
// FP16
void add_device_normalization_rank_5_3_swish_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Swish, 5, 3>>>&);
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, Swish, 5, 3>>>&);
// FP32
void add_device_normalization_rank_5_3_swish_f32_instances(
......@@ -27,20 +27,21 @@ void add_device_normalization_rank_5_3_swish_f32_instances(
// [x, gamma, beta, y] = [f16, f32, f32, f16]
void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Swish, 5, 3>>>&);
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F16, F32, Swish, 5, 3>>>&);
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename SaveMeanInvStdDataType,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
F32,
YDataType,
SaveMeanInvStdDataType,
ck::tensor_operation::element_wise::Swish,
Rank,
NumReduceDim>>
......@@ -48,8 +49,8 @@ struct DeviceOperationInstanceFactory<
using DeviceOp = DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
F32,
YDataType,
SaveMeanInvStdDataType,
ck::tensor_operation::element_wise::Swish,
Rank,
NumReduceDim>;
......@@ -59,7 +60,8 @@ struct DeviceOperationInstanceFactory<
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16>)
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<SaveMeanInvStdDataType, F32>)
{
if constexpr(Rank == 5 && NumReduceDim == 3)
{
......@@ -67,7 +69,8 @@ struct DeviceOperationInstanceFactory<
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<SaveMeanInvStdDataType, F32>)
{
if constexpr(Rank == 5 && NumReduceDim == 3)
{
......@@ -75,7 +78,8 @@ struct DeviceOperationInstanceFactory<
}
}
else if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F32> &&
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F16>)
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F16> &&
is_same_v<SaveMeanInvStdDataType, F32>)
{
if constexpr(Rank == 5 && NumReduceDim == 3)
{
......
......@@ -230,7 +230,6 @@ check_err(const Range& out,
return res;
}
#if defined CK_ENABLE_FP8
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, f8_t>),
......@@ -275,9 +274,7 @@ check_err(const Range& out,
}
return res;
}
#endif
#if defined CK_ENABLE_BF8
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
......@@ -322,7 +319,6 @@ check_err(const Range& out,
}
return res;
}
#endif
} // namespace utils
} // namespace ck
......@@ -2,44 +2,44 @@ function(add_instance_library INSTANCE_NAME)
message("adding instance ${INSTANCE_NAME}")
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS ARGN)
set(test 0)
foreach(type IN LISTS DTYPES)
foreach(source IN LISTS ARGN)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
set(type1 "_i8")
endif()
#make an exception for reduction kernels
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance" OR ${source} MATCHES "device_image_to_column")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
endforeach()
if(test EQUAL 1)
message("removing instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
endif()
endforeach()
endif()
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
message("removing dl instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
......@@ -49,8 +49,10 @@ function(add_instance_library INSTANCE_NAME)
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(${INSTANCE_NAME})
set(result 0)
message("add_instance_library ${INSTANCE_NAME}")
else()
message("skip_instance_libary ${INSTANCE_NAME}")
endif()
#message("add_instance_library returns ${result}")
set(result ${result} PARENT_SCOPE)
endfunction(add_instance_library INSTANCE_NAME)
......@@ -58,65 +60,70 @@ endfunction(add_instance_library INSTANCE_NAME)
file(GLOB dir_list LIST_DIRECTORIES true *)
set(CK_DEVICE_INSTANCES)
FOREACH(subdir_path ${dir_list})
set(target_dir)
IF(IS_DIRECTORY "${subdir_path}")
set(cmake_instance)
file(READ "${subdir_path}/CMakeLists.txt" cmake_instance)
set(add_inst 0)
if(("${cmake_instance}" MATCHES "_fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
set(target_dir)
IF(IS_DIRECTORY "${subdir_path}")
set(cmake_instance)
file(READ "${subdir_path}/CMakeLists.txt" cmake_instance)
set(add_inst 0)
if(("${cmake_instance}" MATCHES "_fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
message("fp8 instance found!")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
endif()
if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
message("fp16 instance found!")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
endif()
if(("${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
message("fp32 instance found!")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
endif()
if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
message("fp64 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "_bf16" AND DTYPES MATCHES "bf16")
endif()
if("${cmake_instance}" MATCHES "_bf16" AND DTYPES MATCHES "bf16")
message("bf16 instance found!")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
endif()
if(("${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
message("int8 instance found!")
set(add_inst 1)
endif()
if(NOT "${cmake_instance}" MATCHES "_fp8" OR
NOT "${cmake_instance}" MATCHES "_f8" OR
NOT "${cmake_instance}" MATCHES "_fp16" OR
NOT "${cmake_instance}" MATCHES "_f16" OR
NOT "${cmake_instance}" MATCHES "_fp32" OR
NOT "${cmake_instance}" MATCHES "_f32" OR
NOT "${cmake_instance}" MATCHES "_fp64" OR
NOT "${cmake_instance}" MATCHES "_f64" OR
NOT "${cmake_instance}" MATCHES "_bf16" OR
NOT "${cmake_instance}" MATCHES "_int8" OR
NOT "${cmake_instance}" MATCHES "_i8" OR
NOT "${cmake_instance}" MATCHES "_int4" OR
NOT DEFINED DTYPES)
message("instance should be built for all types!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "quantization" AND DEFINED DTYPES AND NOT DTYPES MATCHES "int8")
message("quantization instances will not be built!")
set(add_inst 0)
endif()
if("${cmake_instance}" MATCHES "ONLY DL_KERNELS" AND NOT DEFINED DL_KERNELS)
message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
endif()
if(NOT ("${cmake_instance}" MATCHES "_fp8" OR
"${cmake_instance}" MATCHES "_f8" OR
"${cmake_instance}" MATCHES "_fp16" OR
"${cmake_instance}" MATCHES "_f16" OR
"${cmake_instance}" MATCHES "_fp32" OR
"${cmake_instance}" MATCHES "_f32" OR
"${cmake_instance}" MATCHES "_fp64" OR
"${cmake_instance}" MATCHES "_f64" OR
"${cmake_instance}" MATCHES "_bf16" OR
"${cmake_instance}" MATCHES "_int8" OR
"${cmake_instance}" MATCHES "_i8" OR
"${cmake_instance}" MATCHES "_int4"))
message("instance should be built for all types!")
set(add_inst 1)
endif()
if(NOT DEFINED DTYPES)
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "quantization") AND (DEFINED DTYPES) AND (NOT DTYPES MATCHES "int8"))
message("quantization instances will not be built!")
set(add_inst 0)
endif()
if(add_inst EQUAL 1)
get_filename_component(target_dir ${subdir_path} NAME)
add_subdirectory(${target_dir})
list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
endif()
ENDIF()
endif()
if(("${cmake_instance}" MATCHES "ONLY DL_KERNELS") AND (NOT DEFINED DL_KERNELS))
message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
set(add_inst 0)
endif()
if((add_inst EQUAL 1))
get_filename_component(target_dir ${subdir_path} NAME)
add_subdirectory(${target_dir})
list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
message("add_instance_directory ${subdir_path}")
else()
message("skip_instance_directory ${subdir_path}")
endif()
ENDIF()
ENDFOREACH()
add_library(device_operations STATIC ${CK_DEVICE_INSTANCES})
......@@ -158,11 +165,11 @@ target_compile_options(device_operations PRIVATE
# install(TARGETS device_operations LIBRARY DESTINATION lib)
rocm_install(TARGETS device_operations
EXPORT device_operationsTargets)
EXPORT device_operationsTargets)
rocm_install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck)
rocm_install(EXPORT device_operationsTargets
FILE composable_kerneldevice_operationsTargets.cmake
NAMESPACE composable_kernel::
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
FILE composable_kerneldevice_operationsTargets.cmake
NAMESPACE composable_kernel::
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
)
......@@ -15,6 +15,10 @@ list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_in
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp)
device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp)
add_instance_library(device_gemm_splitk_instance ${GEMM_SPLITK_INSTANCES})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment