Commit dc0bae32 authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge branch 'develop' into aosewski/wavelet_omniperf

parents 68474822 ba40c2ce
......@@ -17,33 +17,24 @@ template <typename GridwiseSparseEmbedding,
typename BetaDataType,
typename AccDataType,
typename OutType,
typename OutGridDesc>
typename OutGridDesc,
typename EmbElementwiseOperation,
ck::index_t NumEmbeddings>
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
__global__ void kernel_sparse_embedding3_forward_layernorm(OutType* p_out,
const EmbType* p_emb_a,
const EmbType* p_emb_b,
const EmbType* p_emb_c,
const IndexType* p_index_a,
const IndexType* p_index_b,
const IndexType* p_index_c,
__global__ void kernel_sparse_embeddings_forward_layernorm(
OutType* p_out,
const ck::Array<EmbType*, NumEmbeddings> p_embs,
const ck::Array<IndexType*, NumEmbeddings> p_indexes,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
const OutGridDesc out_grid_desc,
const AccDataType epsilon)
const AccDataType epsilon,
const EmbElementwiseOperation emb_elementwise_op)
{
GridwiseSparseEmbedding::Run(p_out,
p_emb_a,
p_emb_b,
p_emb_c,
p_index_a,
p_index_b,
p_index_c,
p_gamma,
p_beta,
out_grid_desc,
epsilon);
GridwiseSparseEmbedding::Run(
p_out, p_embs, p_indexes, p_gamma, p_beta, out_grid_desc, epsilon, emb_elementwise_op);
}
template <typename EmbType,
......@@ -53,14 +44,16 @@ template <typename EmbType,
typename AccDataType,
typename OutType,
typename OutGridDesc,
typename EmbElementwiseOperation,
ck::index_t BlockSize,
ck::index_t DimClusterSize,
ck::index_t RowClusterSize,
ck::index_t DimPerBlock, // Row x Dim, along Dim
ck::index_t RowPerBlock, // Row x Dim, along Row
ck::index_t DimThreadSize, // this is actually not vector, but number of registers
ck::index_t RowVectorSize>
struct GridwiseSparseEmbedding3ForwardLayernorm
ck::index_t RowVectorSize,
ck::index_t NumEmbeddings>
struct GridwiseSparseEmbeddingsForwardLayernorm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -97,23 +90,17 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
BlockwiseWelford<AccDataType, BlockSize, ThreadClusterLength, Sequence<0, 1>>;
__device__ static void Run(OutType* p_out,
const EmbType* p_emb_a,
const EmbType* p_emb_b,
const EmbType* p_emb_c,
const IndexType* p_index_a,
const IndexType* p_index_b,
const IndexType* p_index_c,
const ck::Array<EmbType*, NumEmbeddings> p_embs,
const ck::Array<IndexType*, NumEmbeddings> p_indexes,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
const OutGridDesc,
const AccDataType epsilon)
const AccDataType epsilon,
const EmbElementwiseOperation emb_elementwise_op)
{
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
// const auto index_length = out_grid_desc.GetLength(I0);
// const auto emb_dim = out_grid_desc.GetLength(I1);
constexpr auto thread_cluster_desc =
make_cluster_descriptor(Sequence<DimClusterSize, RowClusterSize>{}, Sequence<0, 1>{});
......@@ -141,13 +128,11 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
constexpr auto gamma_beta_buf_desc =
make_naive_tensor_descriptor_packed(make_tuple(RowSubBlocks, RowVectorSize));
StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_a;
StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_b;
StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_c;
StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_a;
StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_b;
StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_c;
ck::Array<StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, thread_buf_size, true>,
NumEmbeddings>
in_thread_bufs;
ck::Array<StaticBuffer<AddressSpaceEnum::Vgpr, IndexType, DimPerBlock, true>, NumEmbeddings>
index_bufs;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, thread_buf_size, true> acc_thread_buf;
......@@ -160,42 +145,31 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, mean_var_buf_size, true> var_thread_buf;
auto load_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
vector_type_maker_t<EmbType, RowVectorSize> emb_vector_a;
vector_type_maker_t<EmbType, RowVectorSize> emb_vector_b;
vector_type_maker_t<EmbType, RowVectorSize> emb_vector_c;
using src_vector_t = typename decltype(emb_vector_a)::type;
ck::Array<vector_type_maker_t<EmbType, RowVectorSize>, NumEmbeddings> emb_vectors;
auto emb_a = emb_vectors[0];
using src_vector_t = typename decltype(emb_a)::type;
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
constexpr auto current_dim = i_dim_sub_ * DimPerSubBlock + i_dim_vec_;
IndexType index_a = index_buf_a[Number<current_dim>{}];
IndexType index_b = index_buf_b[Number<current_dim>{}];
IndexType index_c = index_buf_c[Number<current_dim>{}];
auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
sizeof(EmbType) * RowVectorSize;
static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
IndexType index = index_bufs[i_embedding_][Number<current_dim>{}];
int32x4_t emb_res_a =
make_wave_buffer_resource_with_default_range(p_emb_a + index_a * RowPerBlock);
int32x4_t emb_res_b =
make_wave_buffer_resource_with_default_range(p_emb_b + index_b * RowPerBlock);
int32x4_t emb_res_c =
make_wave_buffer_resource_with_default_range(p_emb_c + index_c * RowPerBlock);
emb_vector_a.template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_a, thread_offset, 0);
emb_vector_b.template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_b, thread_offset, 0);
emb_vector_c.template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_c, thread_offset, 0);
int32x4_t emb_res = make_wave_buffer_resource_with_default_range(
p_embs[i_embedding_] + index * RowPerBlock);
emb_vectors(i_embedding_).template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res, thread_offset, 0);
});
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
in_thread_buf_a(Number<register_offset>{}) =
emb_vector_a.template AsType<EmbType>()[i_row_vec_];
in_thread_buf_b(Number<register_offset>{}) =
emb_vector_b.template AsType<EmbType>()[i_row_vec_];
in_thread_buf_c(Number<register_offset>{}) =
emb_vector_c.template AsType<EmbType>()[i_row_vec_];
static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
in_thread_bufs(i_embedding_)(Number<register_offset>{}) =
ck::type_convert<AccDataType>(
emb_vectors[i_embedding_].template AsType<EmbType>()[i_row_vec_]);
});
});
});
};
......@@ -205,14 +179,15 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
AccDataType va =
ck::type_convert<AccDataType>(in_thread_buf_a(Number<register_offset>{}));
AccDataType vb =
ck::type_convert<AccDataType>(in_thread_buf_b(Number<register_offset>{}));
AccDataType vc =
ck::type_convert<AccDataType>(in_thread_buf_c(Number<register_offset>{}));
acc_thread_buf(Number<register_offset>{}) += va + vb + vc;
auto in_data_refs = generate_tie(
[&](auto i_embedding_) -> const auto& {
return in_thread_bufs(i_embedding_)(Number<register_offset>{});
},
Number<NumEmbeddings>{});
auto out_data_refs = generate_tie(
[&](auto) -> auto& { return acc_thread_buf(Number<register_offset>{}); },
Number<1>{});
unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
});
});
};
......@@ -242,7 +217,8 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
constexpr auto mean_var_offset =
mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
auto divisor =
1 / __builtin_amdgcn_sqrtf(var_thread_buf(Number<mean_var_offset>{}) + epsilon);
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
......@@ -250,8 +226,7 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));
auto acc_val = acc_thread_buf[Number<register_offset>{}];
acc_val = (acc_val - mean_thread_buf(Number<mean_var_offset>{})) /
sqrt(var_thread_buf(Number<mean_var_offset>{}) + epsilon);
acc_val = (acc_val - mean_thread_buf(Number<mean_var_offset>{})) * divisor;
acc_val = acc_val * gamma_thread_buf[Number<gamma_beta_offset>{}] +
beta_thread_buf[Number<gamma_beta_offset>{}];
......@@ -273,9 +248,10 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
// first load index
ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) {
// prefer use s_load
index_buf_a(i_idx_) = p_index_a[index_start + i_idx_.value];
index_buf_b(i_idx_) = p_index_b[index_start + i_idx_.value];
index_buf_c(i_idx_) = p_index_c[index_start + i_idx_.value];
ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
index_bufs(i_embedding_)(i_idx_) =
p_indexes[i_embedding_][index_start + i_idx_.value];
});
});
// load gamma/beta
......@@ -329,7 +305,6 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
static_for<0, mean_var_buf_size, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseWelford::Run(
mean_thread_buf(I), var_thread_buf(I), threadwise_welford.cur_count_);
});
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/amd_wmma.hpp"
namespace ck {
enum struct WmmaInstr
{
wmma_f32_16x16x16_f16 = 0,
wmma_f32_16x16x16_bf16,
wmma_f16_16x16x16_f16,
wmma_bf16_16x16x16_bf16,
wmma_i32_16x16x16_iu8,
wmma_i32_16x16x16_iu4
};
/*
* WMMA Wave Tile Always MxNxK = 16x16x16
* WAVE32
-----------------------------------
|RC0| | | | | | | | | | | | | | | | SubGroup 0
|RC1| | | | | | | | | | | | | | | |
|RC2| | | | | | | | | | | | | | | |
|RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
|RC4|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
|RC5|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
|RC6| | | | | | | | | | | | | | | |
|RC7| | | | | | | | | | | | | | | |
-----------------------------------
| | | | | | | | | | | | | | | | | SubGroup 1
| | | | | | | | | | | | | | | | |
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
| 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
| 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
| | | | | | | | | | | | | | | | |
| | | | | | | | | | | | | | | | |
| | | | | | | | | | | | | | | | |
-----------------------------------
* WAVE64
-----------------------------------
|RC0|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 0
|RC1|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
|RC2|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
|RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
-----------------------------------
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 1
| 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
| 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
| | | | | | | | | | | | | | | | |
-----------------------------------
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 2
| 3 |3|3|3|3|3|3|3|4|4|4|4|4|4|4|4|
| 2 |3|4|5|6|7|8|9|0|1|2|3|4|5|6|7|
| | | | | | | | | | | | | | | | |
-----------------------------------
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 3
| 4 |4|5|5|5|5|5|5|5|5|5|5|6|6|6|6|
| 8 |9|0|1|2|3|4|5|6|7|8|9|0|1|2|3|
| | | | | | | | | | | | | | | | |
-----------------------------------
* RC = Register for storing accumalted result
* T = Thread ID
*/
template <WmmaInstr Instr, index_t WaveSize, typename = void>
struct wmma_type
{
};
// A-swizzled
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
// * Data Pixel
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
// * Fixed in Navi3x, Will be wave mode dependent on Navi4x
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
// * num_acc_vgprs_per_wave alone M direction
// * num_subgroups alone M direction
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{
intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_f32_16x16x16_f16_w64<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
}
};
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{
intrin_wmma_f32_16x16x16_bf16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_f32_16x16x16_bf16_w64<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
}
};
#ifdef CK_UNPACKED_ACC_DESC_LOGIC
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 2;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma,
index_t NPerWmma,
index_t Opsel,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{
intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
}
}
};
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 2;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma,
index_t NPerWmma,
index_t Opsel,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{
intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
}
}
};
#endif
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma,
index_t NPerWmma,
bool neg_a,
bool neg_b,
bool clamp,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{
intrin_wmma_i32_16x16x16_iu8_w32<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_i32_16x16x16_iu8_w64<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
a, b, reg_c);
}
}
};
template <typename src_type_a,
typename src_type_b,
typename dst_type,
index_t MPerWmma,
index_t NPerWmma>
struct WmmaSelector
{
template <typename src_type_a_,
typename src_type_b_,
typename dst_type_,
index_t MPerWmma_,
index_t NPerWmma_>
static constexpr auto GetWmma();
template <>
static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
{
return WmmaInstr::wmma_f32_16x16x16_f16;
}
template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
{
return WmmaInstr::wmma_f32_16x16x16_bf16;
}
template <>
static constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
{
return WmmaInstr::wmma_f16_16x16x16_f16;
}
template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
{
return WmmaInstr::wmma_bf16_16x16x16_bf16;
}
template <>
static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
{
return WmmaInstr::wmma_i32_16x16x16_iu8;
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
static constexpr auto GetWmma<int4_t, int, 16, 16>()
{
return WmmaInstr::wmma_i32_16x16x16_iu4;
}
#endif
// get_warp_size do not return the correct wavesize, hardcode to 32 as workaround
static constexpr auto selected_wmma =
wmma_type<GetWmma<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>(), Number<32>{}>{};
__host__ __device__ constexpr WmmaSelector()
{
static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
selected_wmma.acc_data_size ==
selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
"WRONG! Invalid Number of Accumulator Register");
}
};
template <typename src_type_a,
typename src_type_b,
typename dst_type,
index_t MPerWmma,
index_t NPerWmma,
index_t KPack,
bool TransposeC = false>
struct WmmaGemm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
using CIndex = MultiIndex<2>;
using CIndex4D = MultiIndex<4>;
__host__ __device__ constexpr WmmaGemm()
{
static_assert(NPerWmma == 16 && MPerWmma == 16,
"Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma");
}
// WMMA output supporting C = A * B
// Vector Write
// MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
__host__ __device__ static constexpr auto
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
{
const auto MBlockxRepeat =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
const auto NBlockxRepeat =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
const auto MWave =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
const auto NWave =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
return transform_tensor_descriptor(
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
make_tuple(
make_pass_through_transform(MBlockxRepeat),
make_pass_through_transform(MWave),
make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
Number<wmma_instr.num_acc_vgprs_per_wave>{})),
make_pass_through_transform(NBlockxRepeat),
make_pass_through_transform(NWave),
make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{})),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2, 6>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}));
}
__device__ static constexpr index_t GetRegSizePerWmma()
{
return wmma_instr.num_acc_vgprs_per_wave;
}
__device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
static_assert(
(is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
is_same<dst_type, float>::value) ||
(is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
is_same<dst_type, float>::value) ||
(is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
is_same<dst_type, half_t>::value) ||
(is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
is_same<dst_type, bhalf_t>::value) ||
(is_same<src_type_a, int8_t>::value && is_same<src_type_b, int8_t>::value &&
is_same<dst_type, int32_t>::value)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| (is_same<src_type_a, int4_t>::value && is_same<src_type_b, int4_t>::value &&
is_same<dst_type, int32_t>::value)
#endif
,
"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
"(int8, int32) or (int4, int32)!");
if constexpr(!TransposeC)
{
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave, p_b_wave, p_c_thread);
}
else
{
wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave, p_a_wave, p_c_thread);
}
}
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }
__device__ static auto GetSubGroupId()
{
return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
}
__device__ static auto GetLaneIdUnderSubGroup()
{
return GetLaneId() % wmma_instr.num_thread_per_subgroups;
}
__device__ static auto GetSwizzledLaneIdLow()
{
return ((GetLaneIdUnderSubGroup() & 1) << 3) | (GetLaneIdUnderSubGroup() >> 1);
}
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
return GetSwizzledLaneIdLow();
}
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
{
return GetLaneIdUnderSubGroup();
}
__device__ static CIndex GetBeginOfThreadBlk()
{
index_t n_offset = GetLaneIdUnderSubGroup();
index_t m_offset = GetSubGroupId() * wmma_instr.num_acc_vgprs_per_wave;
return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
}
static constexpr auto wmma =
WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
static constexpr auto wmma_instr = wmma.selected_wmma;
__host__ __device__ static constexpr auto
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
{
return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
}
};
} // namespace ck
......@@ -355,5 +355,11 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
c3);
}
// Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
{
asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_WMMA_HPP
#define CK_AMD_WMMA_HPP
#include "ck/utility/amd_inline_asm.hpp"
#include "data_type.hpp"
// TODO: Add arch limitation
namespace ck {
/********************************WAVE32 MODE***********************************************/
// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32;
template <>
struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
// * Inline assembly need to elimate the duplicated data load, compiler won't help you
// delete them.
amd_assembly_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
// reg_c.template AsType<float8_t>()(Number<0>{}) =
// __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template
// AsType<float8_t>()[Number<0>{}]);
}
};
// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32;
template <>
struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
{
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
}
};
// src: fp16, dst: fp16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32;
template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
}
};
// src: bf16, dst: bf16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32;
template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
{
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
}
};
// src: iu8, dst: i32
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32;
template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
{
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
neg_a,
bit_cast<int32x4_t>(reg_a),
neg_b,
bit_cast<int32x4_t>(reg_b),
reg_c.template AsType<int32x8_t>()[Number<0>{}],
clamp);
}
};
/********************************WAVE64 MODE***********************************************/
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w64;
template <>
struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
}
};
// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w64;
template <>
struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
{
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
}
};
// src: fp16, dst: fp16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64;
template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
}
};
// src: bf16, dst: bf16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w64;
template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
{
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
}
};
// src: iu8, dst: i32
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64;
template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
{
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
neg_a,
bit_cast<int32x4_t>(reg_a),
neg_b,
bit_cast<int32x4_t>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}],
clamp);
}
};
} // namespace ck
#endif
......@@ -3,7 +3,9 @@
#pragma once
#ifndef __HIP_DEVICE_COMPILE__
#include <cmath>
#endif
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
......@@ -114,7 +116,16 @@ static inline __device__ int4_t abs(int4_t x)
};
#endif
static inline __device__ half_t abs(half_t x) { return ::__habs(x); };
static inline __device__ half_t abs(half_t x)
{
uint16_t xx = ck::bit_cast<uint16_t>(x);
uint16_t abs_xx = xx & 0x7fff;
half_t abs_x = ck::bit_cast<half_t>(abs_xx);
return abs_x;
};
static inline __device__ bool isnan(float x) { return ::isnan(x); };
......@@ -140,7 +151,12 @@ static inline __device__ bool isnan(int4_t x)
};
#endif
static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); };
static inline __device__ bool isnan(half_t x)
{
uint16_t xx = ck::bit_cast<uint16_t>(x);
return (xx & 0x7FFF) > 0x7C00;
};
static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
......
......@@ -251,27 +251,27 @@ constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum o
};
template <InMemoryDataOperationEnum Operation, typename DataType>
struct InMemoryDataOperatonSupportedOnDataType
struct InMemoryDataOperationSupportedOnDataType
{
static constexpr bool value = false;
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value;
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicMax, DataType>
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::AtomicMax, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value;
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Set, DataType>
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Set, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value ||
......@@ -280,7 +280,7 @@ struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Set, D
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Add, DataType>
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Add, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value ||
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <array>
#include <algorithm>
#include <thread>
#include "ck/utility/math_v2.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename XDataType,
typename DxDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
struct ReferenceBatchNormBwd : public device::DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
Rank,
NumBatchNormReduceDim>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
struct Argument : public device::BaseArgument
{
Argument(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> dxStrides,
const std::array<index_t, Rank> dyStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, NumInvariantDim> bnScaleStrides,
const std::array<index_t, NumInvariantDim> bnDscaleDbiasStrides,
const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
const XDataType* p_x,
const DyDataType* p_dy,
const ScaleDataType* p_scale,
const MeanVarDataType* p_savedMean,
const MeanVarDataType* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
DxDataType* p_dx,
DscaleDbiasDataType* p_dscale,
DscaleDbiasDataType* p_dbias)
: reduceDims_(reduceDims),
bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
bnScaleStrides_(bnScaleStrides),
bnDscaleDbiasStrides_(bnDscaleDbiasStrides),
bnMeanVarStrides_(bnMeanVarStrides),
p_x_(p_x),
p_dy_(p_dy),
p_scale_(p_scale),
p_savedMean_(p_savedMean),
p_savedInvVar_(p_savedInvVar),
dy_elementwise_op_(dy_elementwise_op),
p_dx_(p_dx),
p_dscale_(p_dscale),
p_dbias_(p_dbias)
{
using ck::host_common::get_index_set;
if(std::any_of(
reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; }))
throw std::runtime_error("Invalid reduce dimensions!");
// get invariant_dims[] and invariant_lengths[]
for(int dim = 0, i = 0; dim < Rank; dim++)
if(std::none_of(
reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; }))
{
invariantDims_[i] = dim;
invariant_lengths_[i] = xyLengths[dim];
i++;
};
// get reduce_lengths_[]
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
{
int dim = reduceDims[j];
reduce_lengths_[i++] = xyLengths[dim];
};
for(int i = 0; i < NumInvariantDim; i++)
if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i])
throw std::runtime_error("Invalid lengths parameters!");
for(int j = 0, i = 0; j < NumInvariantDim; j++)
{
int dim = invariantDims_[j];
x_invariant_strides_[i] = xStrides[dim];
dy_invariant_strides_[i] = dyStrides[dim];
dx_invariant_strides_[i] = dxStrides[dim];
i++;
};
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
{
int dim = reduceDims_[j];
x_reduce_strides_[i] = xStrides[dim];
dy_reduce_strides_[i] = dyStrides[dim];
dx_reduce_strides_[i] = dxStrides[dim];
i++;
};
reduceSize_ = std::accumulate(
reduce_lengths_.begin(), reduce_lengths_.end(), 1, std::multiplies<size_t>{});
invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
reduce_index_set_ = get_index_set<NumBatchNormReduceDim>(reduce_lengths_);
epsilon_ = type_convert<AccDataType>(epsilon);
haveSavedMeanInvVar_ = (p_savedMean != nullptr && p_savedInvVar != nullptr);
}
std::array<int, NumBatchNormReduceDim> reduceDims_;
std::array<int, NumInvariantDim> invariantDims_;
std::array<index_t, NumInvariantDim> invariant_lengths_;
std::array<index_t, NumBatchNormReduceDim> reduce_lengths_;
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths_;
const std::array<index_t, NumInvariantDim> bnScaleStrides_;
const std::array<index_t, NumInvariantDim> bnDscaleDbiasStrides_;
const std::array<index_t, NumInvariantDim> bnMeanVarStrides_;
std::array<index_t, NumInvariantDim> x_invariant_strides_;
std::array<index_t, NumInvariantDim> dy_invariant_strides_;
std::array<index_t, NumInvariantDim> dx_invariant_strides_;
std::array<index_t, NumBatchNormReduceDim> x_reduce_strides_;
std::array<index_t, NumBatchNormReduceDim> dy_reduce_strides_;
std::array<index_t, NumBatchNormReduceDim> dx_reduce_strides_;
const XDataType* p_x_;
const DyDataType* p_dy_;
const ScaleDataType* p_scale_;
const MeanVarDataType* p_savedMean_;
const MeanVarDataType* p_savedInvVar_;
const DyElementwiseOp dy_elementwise_op_;
DxDataType* p_dx_;
DscaleDbiasDataType* p_dscale_;
DscaleDbiasDataType* p_dbias_;
bool haveSavedMeanInvVar_;
std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
std::vector<std::array<index_t, NumBatchNormReduceDim>> reduce_index_set_;
AccDataType epsilon_;
size_t reduceSize_;
};
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg)
{
using ck::host_common::get_offset_from_index;
auto thread_reduce_func = [&](auto invariant_index) {
size_t x_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.x_invariant_strides_, invariant_index);
size_t dy_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.dy_invariant_strides_, invariant_index);
size_t dx_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.dx_invariant_strides_, invariant_index);
AccDataType mean = type_convert<AccDataType>(0.0f);
AccDataType variance = type_convert<AccDataType>(0.0f);
AccDataType invVar;
int32_t curr_count = 0;
if(arg.haveSavedMeanInvVar_)
{
size_t mean_invVar_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.bnMeanVarStrides_, invariant_index);
mean =
type_convert<AccDataType>(arg.p_savedMean_[mean_invVar_invariant_offset]);
invVar =
type_convert<AccDataType>(arg.p_savedInvVar_[mean_invVar_invariant_offset]);
}
else
{
// compute mean, variance using welford method
for(const auto& reduce_index : arg.reduce_index_set_)
{
size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.x_reduce_strides_, reduce_index);
auto x_offset = x_invariant_offset + x_reduce_offset;
curr_count++;
AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
AccDataType delta = x - mean;
mean += delta / curr_count;
AccDataType delta2 = x - mean;
variance += delta * delta2;
};
// actual variance
variance = variance / curr_count;
// inv-variance defined as 1/sqrt(epsilon+variance)
invVar =
type_convert<AccDataType>(1.0f) / ck::math::sqrt(arg.epsilon_ + variance);
};
AccDataType dbias =
type_convert<AccDataType>(0.0f); // Sum on reduced dimensions of dy
AccDataType dscale =
type_convert<AccDataType>(0.0f); // Sum on reduced dimensions of dy * norm_x
// 1) calculate dy * (x - mean) * inv-variance
// 2) calculate sum(dy) on reduced dimensions
// 3) calculate sum(dy * norm_x) on reduced dimensions
for(const auto& reduce_index : arg.reduce_index_set_)
{
size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.x_reduce_strides_, reduce_index);
size_t dy_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.dy_reduce_strides_, reduce_index);
auto x_offset = x_invariant_offset + x_reduce_offset;
auto dy_offset = dy_invariant_offset + dy_reduce_offset;
AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
AccDataType norm_x = (x - mean) * invVar;
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[dy_offset]);
arg.dy_elementwise_op_(dy, dy);
dbias += dy;
dscale += norm_x * dy;
};
size_t dscale_offset = get_offset_from_index<NumInvariantDim>(
arg.bnDscaleDbiasStrides_, invariant_index);
size_t dbias_offset = get_offset_from_index<NumInvariantDim>(
arg.bnDscaleDbiasStrides_, invariant_index);
arg.p_dscale_[dscale_offset] = type_convert<DscaleDbiasDataType>(dscale);
arg.p_dbias_[dbias_offset] = type_convert<DscaleDbiasDataType>(dbias);
size_t scale_offset =
get_offset_from_index<NumInvariantDim>(arg.bnScaleStrides_, invariant_index);
AccDataType scale = type_convert<AccDataType>(arg.p_scale_[scale_offset]);
AccDataType multiplier = type_convert<AccDataType>(1.0f) /
type_convert<AccDataType>(arg.reduceSize_) * invVar *
scale;
// 1) calculate tmp = dscale * (x - mean) * inv-variance
// 2) calculate dx = 1/reduceSize * inv-variance * scale * (reduceSize * dy - dbias
// - tmp)
for(const auto& reduce_index : arg.reduce_index_set_)
{
size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.x_reduce_strides_, reduce_index);
size_t dy_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.dy_reduce_strides_, reduce_index);
size_t dx_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.dx_reduce_strides_, reduce_index);
auto x_offset = x_invariant_offset + x_reduce_offset;
auto dy_offset = dy_invariant_offset + dy_reduce_offset;
auto dx_offset = dx_invariant_offset + dx_reduce_offset;
AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
AccDataType norm_x = (x - mean) * invVar;
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[dy_offset]);
arg.dy_elementwise_op_(dy, dy);
AccDataType tmpVal = norm_x * dscale;
AccDataType dx = multiplier * (type_convert<AccDataType>(arg.reduceSize_) * dy -
dbias - tmpVal);
arg.p_dx_[dx_offset] = type_convert<DxDataType>(dx);
};
};
std::size_t num_thread = std::thread::hardware_concurrency();
std::size_t work_per_thread =
(arg.invariant_index_set_.size() + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t i_begin = it * work_per_thread;
std::size_t i_end = std::min(static_cast<size_t>((it + 1) * work_per_thread),
arg.invariant_index_set_.size());
auto f = [=] {
for(std::size_t i = i_begin; i < i_end; ++i)
{
thread_reduce_func(arg.invariant_index_set_[i]);
}
};
threads[it] = joinable_thread(f);
}
return (0.0f);
};
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /*stream_config*/ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
};
};
bool IsSupportedArgument(const device::BaseArgument* p_arg) override
{
(void)p_arg;
return (true);
};
std::unique_ptr<device::BaseArgument>
MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> dxStrides,
const std::array<index_t, Rank> dyStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, NumInvariantDim> bnScaleStrides,
const std::array<index_t, NumInvariantDim> bnDscaleDbiasStrides,
const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
const void* p_x,
const void* p_dy,
const void* p_scale,
const void* p_savedMean,
const void* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
void* p_dx,
void* p_dscale,
void* p_dbias) override
{
return std::make_unique<Argument>(xyLengths,
xStrides,
dxStrides,
dyStrides,
reduceDims,
bnScaleBiasMeanVarLengths,
bnScaleStrides,
bnDscaleDbiasStrides,
bnMeanVarStrides,
static_cast<const XDataType*>(p_x),
static_cast<const DyDataType*>(p_dy),
static_cast<const ScaleDataType*>(p_scale),
static_cast<const MeanVarDataType*>(p_savedMean),
static_cast<const MeanVarDataType*>(p_savedInvVar),
epsilon,
dy_elementwise_op,
static_cast<DxDataType*>(p_dx),
static_cast<DscaleDbiasDataType*>(p_dscale),
static_cast<DscaleDbiasDataType*>(p_dbias));
};
std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "Reference_BatchNorm_Backward" << std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
......@@ -4,13 +4,13 @@
#pragma once
#include <iostream>
#include <vector>
#include <array>
#include <algorithm>
#include <thread>
#include "ck/utility/math_v2.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
namespace ck {
......@@ -23,20 +23,33 @@ template <typename XDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp>
struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
: public device::DeviceBatchNormFwd<4, 3, YElementwiseOp>
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
struct ReferenceBatchNormFwd : public device::DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumBatchNormReduceDim>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
struct Argument : public device::BaseArgument
{
Argument(const std::array<index_t, 4> xyLengths,
const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> yStrides,
const std::array<int, 3> reduceDims,
const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<index_t, 1> bnScaleStrides,
const std::array<index_t, 1> bnBiasStrides,
const std::array<index_t, 1> bnMeanVarStrides,
Argument(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, NumInvariantDim> bnScaleStrides,
const std::array<index_t, NumInvariantDim> bnBiasStrides,
const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
const XDataType* p_x,
const ScaleDataType* bnScale,
const BiasDataType* bnBias,
......@@ -48,7 +61,12 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
double averageFactor,
MeanVarDataType* resultRunningMean,
MeanVarDataType* resultRunningVariance)
: p_x_(p_x),
: reduceDims_(reduceDims),
bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
bnScaleStrides_(bnScaleStrides),
bnBiasStrides_(bnBiasStrides),
bnMeanVarStrides_(bnMeanVarStrides),
p_x_(p_x),
bnScale_(bnScale),
bnBias_(bnBias),
y_elementwise_op_(y_elementwise_op),
......@@ -58,21 +76,51 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
resultRunningMean_(resultRunningMean),
resultRunningVariance_(resultRunningVariance)
{
ignore = xStrides;
ignore = yStrides;
ignore = bnScaleStrides;
ignore = bnBiasStrides;
ignore = bnMeanVarStrides;
ignore = reduceDims;
if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
bnScaleBiasMeanVarLengths[0] != xyLengths[3])
throw std::runtime_error("Invalid tensor dimensions!");
n = xyLengths[0];
h = xyLengths[1];
w = xyLengths[2];
c = xyLengths[3];
using ck::host_common::get_index_set;
if(std::any_of(
reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; }))
throw std::runtime_error("Invalid reduce dimensions!");
// get invariant_dims[] and invariant_lengths[]
for(int dim = 0, i = 0; dim < Rank; dim++)
if(std::none_of(
reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; }))
{
invariantDims_[i] = dim;
invariant_lengths_[i] = xyLengths[dim];
i++;
};
// get reduce_lengths_[]
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
{
int dim = reduceDims[j];
reduce_lengths_[i++] = xyLengths[dim];
};
for(int i = 0; i < NumInvariantDim; i++)
if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i])
throw std::runtime_error("Invalid lengths parameters!");
for(int j = 0, i = 0; j < NumInvariantDim; j++)
{
int dim = invariantDims_[j];
x_invariant_strides_[i] = xStrides[dim];
y_invariant_strides_[i] = yStrides[dim];
i++;
};
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
{
int dim = reduceDims_[j];
x_reduce_strides_[i] = xStrides[dim];
y_reduce_strides_[i] = yStrides[dim];
i++;
};
invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
reduce_index_set_ = get_index_set<NumBatchNormReduceDim>(reduce_lengths_);
epsilon_ = type_convert<AccDataType>(epsilon);
averageFactor_ = type_convert<AccDataType>(averageFactor);
......@@ -81,6 +129,21 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
resultRunning = (resultRunningMean != nullptr && resultRunningVariance != nullptr);
}
std::array<int, NumBatchNormReduceDim> reduceDims_;
std::array<int, NumInvariantDim> invariantDims_;
std::array<index_t, NumInvariantDim> invariant_lengths_;
std::array<index_t, NumBatchNormReduceDim> reduce_lengths_;
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths_;
const std::array<index_t, NumInvariantDim> bnScaleStrides_;
const std::array<index_t, NumInvariantDim> bnBiasStrides_;
const std::array<index_t, NumInvariantDim> bnMeanVarStrides_;
std::array<index_t, NumInvariantDim> x_invariant_strides_;
std::array<index_t, NumInvariantDim> y_invariant_strides_;
std::array<index_t, NumBatchNormReduceDim> x_reduce_strides_;
std::array<index_t, NumBatchNormReduceDim> y_reduce_strides_;
const XDataType* p_x_;
const ScaleDataType* bnScale_;
const BiasDataType* bnBias_;
......@@ -94,7 +157,8 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
bool resultSave, resultRunning;
index_t n, h, w, c;
std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
std::vector<std::array<index_t, NumBatchNormReduceDim>> reduce_index_set_;
AccDataType averageFactor_;
AccDataType epsilon_;
......@@ -104,28 +168,28 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
{
float Run(const Argument& arg)
{
auto thread_reduce_func = [&](auto iC) {
index_t offset_C = iC;
using ck::host_common::get_offset_from_index;
auto thread_reduce_func = [&](auto invariant_index) {
size_t x_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.x_invariant_strides_, invariant_index);
size_t y_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.y_invariant_strides_, invariant_index);
AccDataType mean = type_convert<AccDataType>(0.0f);
AccDataType variance = type_convert<AccDataType>(0.0f);
int32_t curr_count = 0;
// compute mean, variance using welford method
for(index_t iN = 0; iN < arg.n; iN++)
{
index_t offset_N = iN * arg.h * arg.w * arg.c;
for(index_t iH = 0; iH < arg.h; iH++)
{
index_t offset_H = iH * arg.w * arg.c;
for(index_t iW = 0; iW < arg.w; iW++)
for(const auto& reduce_index : arg.reduce_index_set_)
{
index_t offset_W = iW * arg.c;
size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.x_reduce_strides_, reduce_index);
auto offset = offset_N + offset_H + offset_W + offset_C;
auto x_offset = x_invariant_offset + x_reduce_offset;
curr_count++;
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
AccDataType delta = x - mean;
......@@ -135,74 +199,88 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
variance += delta * delta2;
};
}
};
// actual variance
variance = variance / curr_count;
// inv-variance defined as 1/sqrt(epsilon+variance)
AccDataType invVariance =
type_convert<AccDataType>(1.0f) / ck::math::sqrt(arg.epsilon_ + variance);
// save the mean/invVariance if required
// save the mean/inv-variance if required
if(arg.resultSave)
{
arg.resultSaveMean_[iC] = type_convert<MeanVarDataType>(mean);
arg.resultSaveInvVariance_[iC] = type_convert<MeanVarDataType>(invVariance);
size_t offset = get_offset_from_index<NumInvariantDim>(arg.bnMeanVarStrides_,
invariant_index);
arg.resultSaveMean_[offset] = type_convert<MeanVarDataType>(mean);
arg.resultSaveInvVariance_[offset] = type_convert<MeanVarDataType>(invVariance);
};
// update the moving average if required
if(arg.resultRunning)
{
size_t offset = get_offset_from_index<NumInvariantDim>(arg.bnMeanVarStrides_,
invariant_index);
AccDataType oneMinusAverageFactor =
type_convert<AccDataType>(1.0) - arg.averageFactor_;
arg.resultRunningMean_[iC] = type_convert<MeanVarDataType>(
type_convert<AccDataType>(arg.resultRunningMean_[iC]) *
arg.resultRunningMean_[offset] = type_convert<MeanVarDataType>(
type_convert<AccDataType>(arg.resultRunningMean_[offset]) *
oneMinusAverageFactor +
mean * arg.averageFactor_);
arg.resultRunningVariance_[iC] = type_convert<MeanVarDataType>(
arg.resultRunningVariance_[iC] * oneMinusAverageFactor +
arg.resultRunningVariance_[offset] = type_convert<MeanVarDataType>(
arg.resultRunningVariance_[offset] * oneMinusAverageFactor +
variance * arg.averageFactor_);
};
size_t scale_offset =
get_offset_from_index<NumInvariantDim>(arg.bnScaleStrides_, invariant_index);
size_t bias_offset =
get_offset_from_index<NumInvariantDim>(arg.bnBiasStrides_, invariant_index);
AccDataType scale = type_convert<AccDataType>(arg.bnScale_[scale_offset]);
AccDataType bias = type_convert<AccDataType>(arg.bnBias_[bias_offset]);
// Normalization
for(index_t iN = 0; iN < arg.n; iN++)
{
index_t offset_N = iN * arg.h * arg.w * arg.c;
for(index_t iH = 0; iH < arg.h; iH++)
for(const auto& reduce_index : arg.reduce_index_set_)
{
index_t offset_H = iH * arg.w * arg.c;
for(index_t iW = 0; iW < arg.w; iW++)
{
index_t offset_W = iW * arg.c;
size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.x_reduce_strides_, reduce_index);
size_t y_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.y_reduce_strides_, reduce_index);
auto offset = offset_N + offset_H + offset_W + offset_C;
auto x_offset = x_invariant_offset + x_reduce_offset;
auto y_offset = y_invariant_offset + y_reduce_offset;
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
AccDataType norm_x =
arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC];
AccDataType norm_x = (x - mean) * invVariance;
arg.p_y_[offset] = type_convert<YDataType>(norm_x);
};
}
AccDataType y = scale * norm_x + bias;
arg.y_elementwise_op_(y, y);
arg.p_y_[y_offset] = type_convert<YDataType>(y);
};
};
std::size_t num_thread = std::thread::hardware_concurrency();
std::size_t work_per_thread = (arg.c + num_thread - 1) / num_thread;
std::size_t work_per_thread =
(arg.invariant_index_set_.size() + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t ic_begin = it * work_per_thread;
std::size_t ic_end = std::min(static_cast<int>((it + 1) * work_per_thread), arg.c);
std::size_t i_begin = it * work_per_thread;
std::size_t i_end = std::min(static_cast<size_t>((it + 1) * work_per_thread),
arg.invariant_index_set_.size());
auto f = [=] {
for(std::size_t ic = ic_begin; ic < ic_end; ++ic)
for(std::size_t i = i_begin; i < i_end; ++i)
{
thread_reduce_func(ic);
thread_reduce_func(arg.invariant_index_set_[i]);
}
};
......@@ -278,7 +356,7 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
auto str = std::stringstream();
// clang-format off
str << "Reference_BatchNorm_Forward_NHWC_C<" << std::endl;
str << "Reference_BatchNorm_Forward" << std::endl;
// clang-format on
return str.str();
......
......@@ -8,6 +8,7 @@
#include <array>
#include <algorithm>
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp"
namespace ck {
......@@ -19,114 +20,205 @@ template <typename XDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType>
struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBatchNormInfer<4, 3>
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
struct ReferenceBatchNormInfer : public device::DeviceBatchNormInfer<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumBatchNormReduceDim>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
struct Argument : public device::BaseArgument
{
Argument(const std::array<index_t, 4> xyLengths,
const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> yStrides,
const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<index_t, 1> bnScaleStrides,
const std::array<index_t, 1> bnBiasStrides,
const std::array<index_t, 1> bnMeanVarStrides,
Argument(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, NumInvariantDim> bnScaleStrides,
const std::array<index_t, NumInvariantDim> bnBiasStrides,
const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
const XDataType* p_x,
const ScaleDataType* bnScale,
const BiasDataType* bnBias,
double epsilon,
const YElementwiseOp y_elementwise_op,
const MeanVarDataType* estimatedMean,
const MeanVarDataType* estimatedVariance,
YDataType* p_y)
: p_x_(p_x),
: reduceDims_(reduceDims),
bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
bnScaleStrides_(bnScaleStrides),
bnBiasStrides_(bnBiasStrides),
bnMeanVarStrides_(bnMeanVarStrides),
p_x_(p_x),
bnScale_(bnScale),
bnBias_(bnBias),
epsilon_(epsilon),
y_elementwise_op_(y_elementwise_op),
estimatedMean_(estimatedMean),
estimatedVariance_(estimatedVariance),
p_y_(p_y)
{
ignore = xStrides;
ignore = yStrides;
ignore = bnScaleStrides;
ignore = bnBiasStrides;
ignore = bnMeanVarStrides;
using ck::host_common::get_index_set;
if(std::any_of(
reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; }))
throw std::runtime_error("Invalid reduce dimensions!");
// get invariant_dims[] and invariant_lengths[]
for(int dim = 0, i = 0; dim < Rank; dim++)
if(std::none_of(
reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; }))
{
invariantDims_[i] = dim;
invariant_lengths_[i] = xyLengths[dim];
i++;
};
// get reduce_lengths_[]
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
{
int dim = reduceDims[j];
reduce_lengths_[i++] = xyLengths[dim];
};
// check invariant_lengths_ and bnScaleBiasMeanVarLengths
for(int i = 0; i < NumInvariantDim; i++)
if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i])
throw std::runtime_error("Invalid lengths parameters!");
for(int j = 0, i = 0; j < NumInvariantDim; j++)
{
int dim = invariantDims_[j];
x_invariant_strides_[i] = xStrides[dim];
y_invariant_strides_[i] = yStrides[dim];
i++;
};
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
{
int dim = reduceDims_[j];
x_reduce_strides_[i] = xStrides[dim];
y_reduce_strides_[i] = yStrides[dim];
i++;
};
if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
bnScaleBiasMeanVarLengths[0] != xyLengths[3])
throw std::runtime_error("Invalid tensor dimensions!");
invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
reduce_index_set_ = get_index_set<NumBatchNormReduceDim>(reduce_lengths_);
n_ = xyLengths[0];
h_ = xyLengths[1];
w_ = xyLengths[2];
c_ = xyLengths[3];
epsilon_ = type_convert<AccDataType>(epsilon);
}
std::array<int, NumBatchNormReduceDim> reduceDims_;
std::array<int, NumInvariantDim> invariantDims_;
std::array<index_t, NumInvariantDim> invariant_lengths_;
std::array<index_t, NumBatchNormReduceDim> reduce_lengths_;
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths_;
const std::array<index_t, NumInvariantDim> bnScaleStrides_;
const std::array<index_t, NumInvariantDim> bnBiasStrides_;
const std::array<index_t, NumInvariantDim> bnMeanVarStrides_;
std::array<index_t, NumInvariantDim> x_invariant_strides_;
std::array<index_t, NumInvariantDim> y_invariant_strides_;
std::array<index_t, NumBatchNormReduceDim> x_reduce_strides_;
std::array<index_t, NumBatchNormReduceDim> y_reduce_strides_;
const XDataType* p_x_;
const ScaleDataType* bnScale_;
const BiasDataType* bnBias_;
double epsilon_;
const YElementwiseOp y_elementwise_op_;
const MeanVarDataType* estimatedMean_;
const MeanVarDataType* estimatedVariance_;
YDataType* p_y_;
index_t n_, h_, w_, c_;
std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
std::vector<std::array<index_t, NumBatchNormReduceDim>> reduce_index_set_;
AccDataType epsilon_;
};
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg)
{
auto thread_reduce_func = [&](auto iC) {
index_t offset_C = iC;
AccDataType mean = arg.estimatedMean_[offset_C];
AccDataType variance = arg.estimatedVariance_[offset_C];
using ck::host_common::get_offset_from_index;
auto thread_reduce_func = [&](auto invariant_index) {
size_t x_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.x_invariant_strides_, invariant_index);
size_t y_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.y_invariant_strides_, invariant_index);
size_t mean_variance_offset =
get_offset_from_index<NumInvariantDim>(arg.bnMeanVarStrides_, invariant_index);
AccDataType mean = arg.estimatedMean_[mean_variance_offset];
AccDataType variance = arg.estimatedVariance_[mean_variance_offset];
// inv-variance defined as 1/sqrt(epsilon+variance)
AccDataType invVariance =
type_convert<AccDataType>(1.0f) /
std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);
type_convert<AccDataType>(1.0f) / std::sqrt(arg.epsilon_ + variance);
// Normalization
for(index_t iN = 0; iN < arg.n_; iN++)
{
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
for(index_t iH = 0; iH < arg.h_; iH++)
{
index_t offset_H = iH * arg.w_ * arg.c_;
for(index_t iW = 0; iW < arg.w_; iW++)
size_t scale_offset =
get_offset_from_index<NumInvariantDim>(arg.bnScaleStrides_, invariant_index);
size_t bias_offset =
get_offset_from_index<NumInvariantDim>(arg.bnBiasStrides_, invariant_index);
AccDataType scale = type_convert<AccDataType>(arg.bnScale_[scale_offset]);
AccDataType bias = type_convert<AccDataType>(arg.bnBias_[bias_offset]);
// normalization
for(const auto& reduce_index : arg.reduce_index_set_)
{
index_t offset_W = iW * arg.c_;
size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.x_reduce_strides_, reduce_index);
size_t y_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.y_reduce_strides_, reduce_index);
auto offset = offset_N + offset_H + offset_W + offset_C;
auto x_offset = x_invariant_offset + x_reduce_offset;
auto y_offset = y_invariant_offset + y_reduce_offset;
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
AccDataType norm_x =
arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC];
AccDataType norm_x = (x - mean) * invVariance;
arg.p_y_[offset] = type_convert<YDataType>(norm_x);
};
}
AccDataType y = scale * norm_x + bias;
arg.y_elementwise_op_(y, y);
arg.p_y_[y_offset] = type_convert<YDataType>(y);
};
};
std::size_t num_thread = std::thread::hardware_concurrency();
std::size_t work_per_thread = (arg.c_ + num_thread - 1) / num_thread;
std::size_t work_per_thread =
(arg.invariant_index_set_.size() + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t ic_begin = it * work_per_thread;
std::size_t ic_end = std::min(static_cast<int>((it + 1) * work_per_thread), arg.c_);
std::size_t i_begin = it * work_per_thread;
std::size_t i_end = std::min(static_cast<size_t>((it + 1) * work_per_thread),
arg.invariant_index_set_.size());
auto f = [=] {
for(std::size_t ic = ic_begin; ic < ic_end; ++ic)
for(std::size_t i = i_begin; i < i_end; ++i)
{
thread_reduce_func(ic);
thread_reduce_func(arg.invariant_index_set_[i]);
}
};
......@@ -151,17 +243,19 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
};
std::unique_ptr<device::BaseArgument>
MakeArgumentPointer(const std::array<index_t, 4> xyLengths,
const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> yStrides,
const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<index_t, 1> bnScaleStrides,
const std::array<index_t, 1> bnBiasStrides,
const std::array<index_t, 1> bnMeanVarStrides,
MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, NumInvariantDim> bnScaleStrides,
const std::array<index_t, NumInvariantDim> bnBiasStrides,
const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
const void* p_x,
const void* bnScale,
const void* bnBias,
double epsilon,
const YElementwiseOp y_elementwise_op,
const void* estimatedMean,
const void* estimatedVariance,
void* p_y) override
......@@ -169,6 +263,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
return std::make_unique<Argument>(xyLengths,
xStrides,
yStrides,
reduceDims,
bnScaleBiasMeanVarLengths,
bnScaleStrides,
bnBiasStrides,
......@@ -177,6 +272,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
static_cast<const ScaleDataType*>(bnScale),
static_cast<const BiasDataType*>(bnBias),
epsilon,
y_elementwise_op,
static_cast<const MeanVarDataType*>(estimatedMean),
static_cast<const MeanVarDataType*>(estimatedVariance),
static_cast<YDataType*>(p_y));
......@@ -192,7 +288,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
auto str = std::stringstream();
// clang-format off
str << "Reference_BatchNorm_Forward_NHWC_C<" << std::endl;
str << "Reference_BatchNorm_Infer<" << std::endl;
// clang-format on
return str.str();
......
......@@ -90,10 +90,13 @@ struct ReferenceLayernorm : public device::BaseOperator
for(int m = 0; m < M; ++m)
{
AccDataType divisor =
static_cast<AccDataType>(1) / ck::math::sqrt(var(m) + arg.epsilon_);
for(int n = 0; n < N; ++n)
{
auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
auto y_val = (x_val - mean(m)) / sqrt(var(m) + arg.epsilon_);
auto y_val = (x_val - mean(m)) * divisor;
y_val = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n);
arg.acc_elementwise_op_(y_val, y_val);
arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <array>
#include <algorithm>
#include <thread>
#include "ck/ck.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
bool OutputIndex>
struct ReferenceReduce : public device::DeviceReduce<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
OutputIndex>
{
using IndexDataType = int32_t;
static constexpr int NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t NumSrcDim = Rank;
static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
struct Argument : public device::BaseArgument
{
Argument(const std::array<index_t, Rank> inLengths,
const std::array<index_t, Rank> inStrides,
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
double alpha,
double beta,
const InDataType* in_host,
OutDataType* out_host,
IndexDataType* out_index_host,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op)
: reduceDims_(reduceDims),
outLengths_(outLengths),
outStrides_(outStrides),
in_host_(in_host),
out_host_(out_host),
out_index_host_(out_index_host),
in_elementwise_op_(in_elementwise_op),
acc_elementwise_op_(acc_elementwise_op)
{
using ck::host_common::get_index_set;
if(std::any_of(
reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; }))
throw std::runtime_error("Invalid reduce dimensions!");
if constexpr(NumInvariantDim > 0)
{
// get invariant_dims[] and invariant_lengths[]
for(int dim = 0, i = 0; dim < Rank; dim++)
if(std::none_of(
reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; }))
{
invariantDims_[i] = dim;
invariant_lengths_[i] = inLengths[dim];
i++;
};
};
// get reduce_lengths_[]
for(int j = 0, i = 0; j < NumReduceDim; j++)
{
int dim = reduceDims[j];
reduce_lengths_[i++] = inLengths[dim];
};
if constexpr(NumInvariantDim > 0)
{
// check invariant_lengths_ and outLengths
for(int i = 0; i < NumInvariantDim; i++)
if(invariant_lengths_[i] != outLengths_[i])
throw std::runtime_error("Invalid lengths parameters!");
}
if constexpr(NumInvariantDim > 0)
{
for(int j = 0, i = 0; j < NumInvariantDim; j++)
{
int dim = invariantDims_[j];
in_invariant_strides_[i] = inStrides[dim];
i++;
};
};
for(int j = 0, i = 0; j < NumReduceDim; j++)
{
int dim = reduceDims_[j];
in_reduce_strides_[i] = inStrides[dim];
i++;
};
if constexpr(NumInvariantDim > 0)
invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
reduce_index_set_ = get_index_set<NumReduceDim>(reduce_lengths_);
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);
};
const std::array<int, NumReduceDim> reduceDims_;
std::array<int, NumInvariantDim> invariantDims_;
std::array<index_t, NumInvariantDim> invariant_lengths_;
std::array<index_t, NumReduceDim> reduce_lengths_;
const std::array<index_t, NumDstDim> outLengths_;
const std::array<index_t, NumDstDim> outStrides_;
std::array<index_t, NumInvariantDim> in_invariant_strides_;
std::array<index_t, NumReduceDim> in_reduce_strides_;
const InDataType* in_host_;
OutDataType* out_host_;
IndexDataType* out_index_host_;
const InElementwiseOperation in_elementwise_op_;
const AccElementwiseOperation acc_elementwise_op_;
AccDataType alpha_;
AccDataType beta_;
std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
std::vector<std::array<index_t, NumReduceDim>> reduce_index_set_;
};
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
ignore = stream_config;
using ck::float_equal_one;
using ck::float_equal_zero;
using ck::type_convert;
using ck::host_common::get_index_set;
using ck::host_common::get_offset_from_index;
if constexpr(OutputIndex)
{
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
if constexpr(NumInvariantDim == 0)
{
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
IndexDataType accuIndex = 0;
for(std::size_t i = 0; i < arg.reduce_index_set_.size(); i++)
{
auto in_offset = get_offset_from_index<NumReduceDim>(
arg.in_reduce_strides_, arg.reduce_index_set_[i]);
auto currVal = type_convert<AccDataType>(arg.in_host_[in_offset]);
arg.in_elementwise_op_(currVal, currVal);
auto currIndex = static_cast<IndexDataType>(i);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
};
arg.acc_elementwise_op_(accuVal, accuVal);
if(!float_equal_one{}(arg.alpha_))
accuVal *= type_convert<AccDataType>(arg.alpha_);
if(!float_equal_zero{}(arg.beta_))
accuVal += type_convert<AccDataType>(arg.out_host_[0]) *
type_convert<AccDataType>(arg.beta_);
arg.out_host_[0] = type_convert<OutDataType>(accuVal);
arg.out_index_host_[0] = accuIndex;
}
else
{
auto thread_reduce_func = [&](auto invariant_index) {
AccDataType accuVal =
ReduceOperation::template GetIdentityValue<AccDataType>();
IndexDataType accuIndex = 0;
auto in_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.in_invariant_strides_, invariant_index);
for(std::size_t i = 0; i < arg.reduce_index_set_.size(); i++)
{
auto in_reduce_offset = get_offset_from_index<NumReduceDim>(
arg.in_reduce_strides_, arg.reduce_index_set_[i]);
auto currVal = type_convert<AccDataType>(
arg.in_host_[in_invariant_offset + in_reduce_offset]);
arg.in_elementwise_op_(currVal, currVal);
auto currIndex = static_cast<IndexDataType>(i);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
};
arg.acc_elementwise_op_(accuVal, accuVal);
if(!float_equal_one{}(arg.alpha_))
accuVal *= type_convert<AccDataType>(arg.alpha_);
auto dst_offset = get_offset_from_index<NumInvariantDim>(arg.outStrides_,
invariant_index);
if(!float_equal_zero{}(arg.beta_))
accuVal += type_convert<AccDataType>(arg.out_host_[dst_offset]) *
type_convert<AccDataType>(arg.beta_);
arg.out_host_[dst_offset] = type_convert<OutDataType>(accuVal);
arg.out_index_host_[dst_offset] = accuIndex;
};
std::size_t num_thread = std::thread::hardware_concurrency();
std::size_t work_per_thread =
(arg.invariant_index_set_.size() + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t i_begin = it * work_per_thread;
std::size_t i_end =
std::min((it + 1) * work_per_thread, arg.invariant_index_set_.size());
auto f = [=] {
for(std::size_t i = i_begin; i < i_end; i++)
{
thread_reduce_func(arg.invariant_index_set_[i]);
}
};
threads[it] = joinable_thread(f);
}
};
}
else
{
using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
if constexpr(NumInvariantDim == 0)
{
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
for(const auto& reduce_index : arg.reduce_index_set_)
{
auto in_offset = get_offset_from_index<NumReduceDim>(arg.in_reduce_strides_,
reduce_index);
auto currVal = type_convert<AccDataType>(arg.in_host_[in_offset]);
arg.in_elementwise_op_(currVal, currVal);
Accumulation::Calculate(accuVal, currVal);
};
arg.acc_elementwise_op_(accuVal, accuVal);
if(!float_equal_one{}(arg.alpha_))
accuVal *= type_convert<AccDataType>(arg.alpha_);
if(!float_equal_zero{}(arg.beta_))
accuVal += type_convert<AccDataType>(arg.out_host_[0]) *
type_convert<AccDataType>(arg.beta_);
arg.out_host_[0] = type_convert<OutDataType>(accuVal);
}
else
{
auto thread_reduce_func = [&](auto invariant_index) {
AccDataType accuVal =
ReduceOperation::template GetIdentityValue<AccDataType>();
auto in_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.in_invariant_strides_, invariant_index);
for(const auto& reduce_index : arg.reduce_index_set_)
{
auto in_reduce_offset = get_offset_from_index<NumReduceDim>(
arg.in_reduce_strides_, reduce_index);
auto currVal = type_convert<AccDataType>(
arg.in_host_[in_invariant_offset + in_reduce_offset]);
arg.in_elementwise_op_(currVal, currVal);
Accumulation::Calculate(accuVal, currVal);
};
arg.acc_elementwise_op_(accuVal, accuVal);
if(!float_equal_one{}(arg.alpha_))
accuVal *= type_convert<AccDataType>(arg.alpha_);
auto dst_offset = get_offset_from_index<NumInvariantDim>(arg.outStrides_,
invariant_index);
if(!float_equal_zero{}(arg.beta_))
accuVal += type_convert<AccDataType>(arg.out_host_[dst_offset]) *
type_convert<AccDataType>(arg.beta_);
arg.out_host_[dst_offset] = type_convert<OutDataType>(accuVal);
};
std::size_t num_thread = std::thread::hardware_concurrency();
std::size_t work_per_thread =
(arg.invariant_index_set_.size() + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t i_begin = it * work_per_thread;
std::size_t i_end =
std::min((it + 1) * work_per_thread, arg.invariant_index_set_.size());
auto f = [=] {
for(std::size_t i = i_begin; i < i_end; i++)
{
thread_reduce_func(arg.invariant_index_set_[i]);
}
};
threads[it] = joinable_thread(f);
}
};
};
return (0.0f);
};
float Run(const device::BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
bool IsSupportedArgument(const device::BaseArgument* p_arg) override
{
ignore = p_arg;
return true;
};
std::unique_ptr<device::BaseArgument>
MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::array<index_t, Rank> inStrides,
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
double alpha,
double beta,
const void* in_host,
const void* in_index_host,
void* out_host,
void* out_index_host,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override
{
ignore = in_index_host;
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_host),
static_cast<OutDataType*>(out_host),
static_cast<IndexDataType*>(out_index_host),
in_elementwise_op,
acc_elementwise_op);
};
std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "Reference_Reduce<" << std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
......@@ -24,11 +24,14 @@ struct ReferenceSoftmax : public device::BaseOperator
{
Argument(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
AccDataType alpha,
AccDataType beta,
double alpha,
double beta,
const std::vector<index_t> sm_reduce_dims)
: in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims)
: in_(in), out_(out), sm_reduce_dims_(sm_reduce_dims)
{
alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<AccDataType>(beta);
// std::cout << "debug: scalar dims: ";
for(size_t i = 0; i < in.mDesc.GetNumOfDimension(); i++)
{
......@@ -143,8 +146,8 @@ struct ReferenceSoftmax : public device::BaseOperator
static auto MakeArgument(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
AccDataType alpha,
AccDataType beta,
double alpha,
double beta,
const std::vector<index_t> sm_reduce_dims)
{
return Argument{in, out, alpha, beta, sm_reduce_dims};
......
......@@ -27,8 +27,8 @@ using F16_Tuple = ck::Tuple<F16>;
using F16_F16_Tuple = ck::Tuple<F16, F16>;
using F32_Tuple = ck::Tuple<F32>;
using I32_Tuple = ck::Tuple<I32>;
using I32_F32_Tuple = ck::Tuple<I32, F32>;
// GEMM layout
using Row = ck::tensor_layout::gemm::RowMajor;
......@@ -79,7 +79,8 @@ using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
//
using GK = ck::tensor_layout::convolution::G_K;
using GK_TUPLE = ck::Tuple<GK>;
using GK_Tuple = ck::Tuple<GK>;
using GK_GK_Tuple = ck::Tuple<GK, GK>;
// pointwise functor
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
......@@ -87,6 +88,9 @@ using Relu = ck::tensor_operation::element_wise::Relu;
using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
......@@ -95,6 +99,13 @@ template <typename Activation>
using Add_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>;
template <typename Activation>
using Activation_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<Activation>;
template <typename Activation>
using Add_Activation_Mul2_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<Activation>;
template <typename DeviceOp, typename Tag = void>
struct DeviceOperationInstanceFactory;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_batched_contraction_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_mnnm_instance(
std::vector<std::unique_ptr<
DeviceBatchedContractionMultipleD<1,
2,
3,
1,
F16,
F16,
F16_Tuple,
F16,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Add>>>& instances);
// Contraction + add
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
typename ADataType,
typename BDataType,
typename DDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedContractionMultipleD<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Add>>
{
using DeviceOp =
DeviceBatchedContractionMultipleD<NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Add>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, ck::half_t> && is_same_v<BDataType, ck::half_t> &&
is_same_v<DDataType, ck::half_t> && is_same_v<EDataType, ck::half_t>)
{
if constexpr(NumDimG == 1 && NumDimM == 2 && NumDimN == 3 && NumDimK == 1)
{
add_device_batched_contraction_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_mnnm_instance(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// FP16
void add_device_batchnorm_backward_rank_4_3_f16_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F16, F32, F32, F32, F16, F32, F32, PassThrough, 4, 3>>>&);
// FP32
void add_device_batchnorm_backward_rank_4_3_f32_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F32, F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
// BF16
void add_device_batchnorm_backward_rank_4_3_bf16_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<BF16, F32, F32, F32, BF16, F32, F32, PassThrough, 4, 3>>>&);
// FP64
void add_device_batchnorm_backward_rank_4_3_f64_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F64, F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
template <typename XDataType,
typename DxDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
Rank,
NumReduceDim>>
{
using DeviceOp = DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
Rank,
NumReduceDim>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<XDataType, F16> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F16> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> &&
is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_f64_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// FP16
void add_device_batchnorm_forward_rank_4_3_f16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F16, F16, F32, F16, F16, F32, PassThrough, 4, 3>>>&);
// FP32
void add_device_batchnorm_forward_rank_4_3_f32_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
// BF16
void add_device_batchnorm_forward_rank_4_3_bf16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>&);
// FP64
void add_device_batchnorm_forward_rank_4_3_f64_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumReduceDim>>
{
using DeviceOp = DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
Rank,
NumReduceDim>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F16> &&
is_same_v<BiasDataType, F16> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_f64_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// FP16
void add_device_batchnorm_infer_rank_4_f16_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<F16, F32, F32, F16, F16>,
ck::Tuple<F16>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
// FP32
void add_device_batchnorm_infer_rank_4_f32_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<F32, F32, F32, F32, F32>,
ck::Tuple<F32>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
// BF16
void add_device_batchnorm_infer_rank_4_bf16_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<BF16, F32, F32, BF16, BF16>,
ck::Tuple<BF16>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
// FP64
void add_device_batchnorm_infer_rank_4_f64_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<F64, F64, F64, F64, F64>,
ck::Tuple<F64>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
template <typename XDataType,
typename YDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
index_t Rank>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<XDataType, MeanVarDataType, MeanVarDataType, ScaleDataType, BiasDataType>,
ck::Tuple<YDataType>,
ck::tensor_operation::element_wise::NormalizeInInfer,
Rank>>
{
using DeviceOp = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<XDataType, MeanVarDataType, MeanVarDataType, ScaleDataType, BiasDataType>,
ck::Tuple<YDataType>,
ck::tensor_operation::element_wise::NormalizeInInfer,
Rank>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<ScaleDataType, F16> && is_same_v<BiasDataType, F16> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<BiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<ScaleDataType, BF16> && is_same_v<BiasDataType, BF16> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<BiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_f64_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,7 +7,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
......@@ -18,11 +18,8 @@ namespace device {
namespace instance {
using Normalize = ck::tensor_operation::element_wise::Normalize;
using DeviceNormalizeFromMeanMeanSquarePtr = ck::tensor_operation::device::DeviceElementwiseBasePtr<
Tuple<half_t, float, float, half_t, half_t>,
Tuple<half_t>,
Normalize,
2>;
using DeviceNormalizeFromMeanMeanSquarePtr = ck::tensor_operation::device::
DeviceElementwisePtr<Tuple<half_t, float, float, half_t, half_t>, Tuple<half_t>, Normalize, 2>;
void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(
std::vector<DeviceNormalizeFromMeanMeanSquarePtr>& instances);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>&);
void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>&);
void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>&);
void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>&);
// GEMM + Add + FastGelu
template <typename ALayout,
typename BLayout,
typename D0Layout,
typename ELayout,
typename ADataType,
typename BDataType,
typename D0DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleD<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddFastGelu>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddFastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddMultiply>>>&);
void add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddMultiply>>>&);
void add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddMultiply>>>&);
void add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddMultiply>>>&);
// GEMM + Add + Multiply
template <typename ALayout,
typename BLayout,
typename D0Layout,
typename D1Layout,
typename ELayout,
typename ADataType,
typename BDataType,
typename D0DataType,
typename D1DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddMultiply>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddMultiply>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<D0DataType, half_t> && is_same_v<D1DataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment