Unverified Commit 24af0144 authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

Merge branch 'develop' into gemm_layernorm_welford

parents 961f5e9e b79bbbc2
...@@ -75,4 +75,63 @@ struct ThreadwiseWelford ...@@ -75,4 +75,63 @@ struct ThreadwiseWelford
int max_count_; int max_count_;
}; };
template <typename T,
typename SrcMeanVarCountThreadDesc_M_K,
typename DstMeanVarThreadDesc_M,
bool GetActualVariance = false>
struct ThreadwiseWelfordMerge
{
static constexpr auto src_thread_desc_m_k = SrcMeanVarCountThreadDesc_M_K{};
static constexpr auto dst_thread_desc_m = DstMeanVarThreadDesc_M{};
static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
__device__ static void
Merge(T& mean_a, T& var_a, int32_t& count_a, T mean_b, T var_b, int32_t count_b)
{
int count = count_a + count_b;
T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
T delta = mean_b - mean_a;
mean_a += delta * count_b_over_count;
var_a += var_b + delta * delta * count_a * count_b_over_count;
count_a = count;
}
template <typename SrcMeanBufferType,
typename SrcVarBufferType,
typename SrcCountBufferType,
typename DstMeanBufferType,
typename DstVarBufferType,
typename DstCountBufferType>
__device__ static void Run(const SrcMeanBufferType& src_mean_buf,
const SrcVarBufferType& src_var_buf,
const SrcCountBufferType& src_count_buf,
DstMeanBufferType& dst_mean_buf,
DstVarBufferType& dst_var_buf,
DstCountBufferType& dst_count_buf)
{
static_for<0, src_length_m, 1>{}([&](auto iM) {
static_for<0, src_length_k, 1>{}([&](auto iK) {
constexpr auto src_offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
Merge(dst_mean_buf(iM),
dst_var_buf(iM),
dst_count_buf(iM),
src_mean_buf[Number<src_offset>{}],
src_var_buf[Number<src_offset>{}],
src_count_buf[Number<src_offset>{}]);
});
if constexpr(GetActualVariance)
{
dst_var_buf(iM) = dst_var_buf[iM] / dst_count_buf[iM];
};
});
};
};
} // namespace ck } // namespace ck
...@@ -594,6 +594,7 @@ struct XdlopsGemm ...@@ -594,6 +594,7 @@ struct XdlopsGemm
static constexpr auto I5 = Number<5>{}; static constexpr auto I5 = Number<5>{};
using CIndex = MultiIndex<2>; using CIndex = MultiIndex<2>;
using CIndex4D = MultiIndex<4>;
__device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; } __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; }
...@@ -822,6 +823,16 @@ struct XdlopsGemm ...@@ -822,6 +823,16 @@ struct XdlopsGemm
return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset}; return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
} }
__device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
{
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
}
static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops>{}; static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops>{};
static constexpr auto mfma_instr = mfma.selected_mfma; static constexpr auto mfma_instr = mfma.selected_mfma;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
namespace ck {
namespace tensor_operation {
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
device::TensorSpecialization TensorSpec>
static auto MakeGridDescriptorPair(const std::vector<index_t>& gs_ms_ns_lengths_vec,
const std::vector<index_t>& gs_ms_ns_strides_vec)
{
if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
{
throw std::runtime_error("wrong! dimension must match input lengths");
}
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto gs_ms_ns_lengths =
to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto gs_ms_ns_strides =
to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for G0, G1, ...
constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
// dimension Ids for M0, M1, ...
constexpr auto mDimIds =
typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDimG + NumDimM, NumDimG + NumDimM + NumDimN, 1>::type{};
// lengths for G0, G1, ...
const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds);
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds);
// lengths for N0, N1, ...
const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds);
if constexpr(TensorSpec == device::TensorSpecialization::Packed)
{
auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
make_tuple(G, M, N),
make_tuple(gs_ms_ns_strides[Number<NumDimG - 1>{}],
gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor(
make_tuple(M, N),
make_tuple(gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
}
else
{
// naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
const auto grid_desc_gs_ms_ns =
make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides);
// transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
// Note: This does not require padding as it only provides G offset calculation. Technically
// descriptor for only G is needed. Here we opt for backward compatibility purpose to return
// G_M_N
const auto grid_desc_g_mraw_nraw =
transform_tensor_descriptor(grid_desc_gs_ms_ns,
make_tuple(make_merge_transform(gLengths),
make_merge_transform(mLengths),
make_merge_transform(nLengths)),
make_tuple(gDimIds, mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto c_ms_ns_lengths = to_tuple(
gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_ms_ns_strides = to_tuple(
gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
// transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
const auto grid_desc_mraw_nraw = transform_tensor_descriptor(
grid_desc_ms_ns,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds - Number<NumDimG>{}, nDimIds - Number<NumDimG>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
}
}
template <typename NumDims_G_M_N_K_O, // Sequence<>
typename PerBlock_M_N_K_O, // Sequence<>
device::GemmSpecialization GemmSpec,
device::TensorSpecialization ASpec,
device::TensorSpecialization B0Spec,
device::TensorSpecialization B1Spec,
device::TensorSpecialization CSpec>
struct TransformBatchedContractionContractionToBatchedGemmGemm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0);
static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1);
static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2);
static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3);
static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4);
static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0);
static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1);
static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2);
static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3);
static constexpr auto matrix_padder =
device::GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock, OPerBlock};
//
// A
//
static auto MakeAGridDescriptorPair(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>(a_gs_ms_ks_lengths_vec,
a_gs_ms_ks_strides_vec);
}
// TODO: rename to G_MRaw_KRaw
static auto MakeAGridDescriptor_G_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first;
}
static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
return matrix_padder.PadADescriptor_M_K(
MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second);
}
template <typename AGridDesc_M_K, typename Number>
__host__ __device__ static constexpr auto
MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
//
// B (alias of B0)
//
static auto MakeB0GridDescriptorPair(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>(b0_gs_ns_ks_lengths_vec,
b0_gs_ns_ks_strides_vec);
}
// TODO: rename to G_MRaw_NRaw
static auto MakeB0GridDescriptor_G_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
{
return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first;
}
static auto MakeB0GridDescriptor_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
{
// alias of matrix_padder.PadB0Descriptor_N_K
return matrix_padder.PadBDescriptor_N_K(
MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second);
}
template <typename BGridDesc_N_K, typename Number>
__host__ __device__ static constexpr auto
MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
//
// B1
//
static auto MakeB1GridDescriptorPair(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
const std::vector<index_t>& b1_gs_os_ns_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>(b1_gs_os_ns_lengths_vec,
b1_gs_os_ns_strides_vec);
}
// TODO: rename to G_NRaw_KRaw
static auto MakeB1GridDescriptor_G_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
const std::vector<index_t>& b1_gs_os_ns_strides_vec)
{
return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first;
}
static auto MakeB1GridDescriptor_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
const std::vector<index_t>& b1_gs_os_ns_strides_vec)
{
// alias of matrix_padder.PadB1Descriptor_O_N
return matrix_padder.PadB1Descriptor_N_K(
MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second);
}
template <typename B1GridDesc_N_K, typename Number>
__host__ __device__ static constexpr auto
MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1)
{
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
//
// C
//
static auto MakeCGridDescriptorPair(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
const std::vector<index_t>& c_gs_ms_os_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>(c_gs_ms_os_lengths_vec,
c_gs_ms_os_strides_vec);
}
// TODO: rename to G_MRaw_NRaw
static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
const std::vector<index_t>& c_gs_ms_os_strides_vec)
{
return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first;
}
static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
const std::vector<index_t>& c_gs_ms_os_strides_vec)
{
return matrix_padder.PadCDescriptor_M_N(
MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
}
};
} // namespace tensor_operation
} // namespace ck
...@@ -9,46 +9,61 @@ ...@@ -9,46 +9,61 @@
#include <algorithm> #include <algorithm>
#include <thread> #include <thread>
#include "ck/utility/math_v2.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp" #include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace host { namespace host {
template <typename InOutDataType, typename AccDataType> template <typename XDataType,
struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatchNormFwd<4, 3> typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp>
struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
: public device::DeviceBatchNormFwd<4, 3, YElementwiseOp>
{ {
struct Argument : public device::BaseArgument struct Argument : public device::BaseArgument
{ {
Argument(const std::array<index_t, 4> xyLengths, Argument(const std::array<index_t, 4> xyLengths,
const std::array<index_t, 4> xStrides, const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> yStrides, const std::array<index_t, 4> yStrides,
const std::array<int, 3> reduceDims,
const std::array<index_t, 1> bnScaleBiasMeanVarLengths, const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<index_t, 1> bnScaleBiasMeanVarStrides, const std::array<index_t, 1> bnScaleStrides,
const InOutDataType* p_x, const std::array<index_t, 1> bnBiasStrides,
const AccDataType* bnScale, const std::array<index_t, 1> bnMeanVarStrides,
const AccDataType* bnBias, const XDataType* p_x,
InOutDataType* p_y, const ScaleDataType* bnScale,
double exponentialAverageFactor, const BiasDataType* bnBias,
AccDataType* resultRunningMean,
AccDataType* resultRunningVariance,
double epsilon, double epsilon,
AccDataType* resultSaveMean, const YElementwiseOp y_elementwise_op,
AccDataType* resultSaveInvVariance) YDataType* p_y,
MeanVarDataType* resultSaveMean,
MeanVarDataType* resultSaveInvVariance,
double averageFactor,
MeanVarDataType* resultRunningMean,
MeanVarDataType* resultRunningVariance)
: p_x_(p_x), : p_x_(p_x),
bnScale_(bnScale), bnScale_(bnScale),
bnBias_(bnBias), bnBias_(bnBias),
y_elementwise_op_(y_elementwise_op),
p_y_(p_y), p_y_(p_y),
resultRunningMean_(resultRunningMean),
resultRunningVariance_(resultRunningVariance),
resultSaveMean_(resultSaveMean), resultSaveMean_(resultSaveMean),
resultSaveInvVariance_(resultSaveInvVariance), resultSaveInvVariance_(resultSaveInvVariance),
exponentialAverageFactor_(exponentialAverageFactor), resultRunningMean_(resultRunningMean),
epsilon_(epsilon) resultRunningVariance_(resultRunningVariance)
{ {
(void)xStrides; ignore = xStrides;
(void)yStrides; ignore = yStrides;
(void)bnScaleBiasMeanVarStrides; ignore = bnScaleStrides;
ignore = bnBiasStrides;
ignore = bnMeanVarStrides;
ignore = reduceDims;
if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 || if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
bnScaleBiasMeanVarLengths[0] != xyLengths[3]) bnScaleBiasMeanVarLengths[0] != xyLengths[3])
...@@ -59,26 +74,30 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -59,26 +74,30 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
w = xyLengths[2]; w = xyLengths[2];
c = xyLengths[3]; c = xyLengths[3];
epsilon_ = type_convert<AccDataType>(epsilon);
averageFactor_ = type_convert<AccDataType>(averageFactor);
resultSave = (resultSaveMean != nullptr && resultSaveInvVariance != nullptr); resultSave = (resultSaveMean != nullptr && resultSaveInvVariance != nullptr);
resultRunning = (resultRunningMean != nullptr && resultRunningVariance != nullptr); resultRunning = (resultRunningMean != nullptr && resultRunningVariance != nullptr);
} }
const InOutDataType* p_x_; const XDataType* p_x_;
const AccDataType* bnScale_; const ScaleDataType* bnScale_;
const AccDataType* bnBias_; const BiasDataType* bnBias_;
InOutDataType* p_y_; const YElementwiseOp y_elementwise_op_;
YDataType* p_y_;
AccDataType* resultRunningMean_; MeanVarDataType* resultSaveMean_;
AccDataType* resultRunningVariance_; MeanVarDataType* resultSaveInvVariance_;
AccDataType* resultSaveMean_; MeanVarDataType* resultRunningMean_;
AccDataType* resultSaveInvVariance_; MeanVarDataType* resultRunningVariance_;
bool resultSave, resultRunning; bool resultSave, resultRunning;
index_t n, h, w, c; index_t n, h, w, c;
double exponentialAverageFactor_; AccDataType averageFactor_;
double epsilon_; AccDataType epsilon_;
}; };
struct Invoker : public device::BaseInvoker struct Invoker : public device::BaseInvoker
...@@ -86,14 +105,12 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -86,14 +105,12 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
auto thread_reduce_func = [&](auto iC) { auto thread_reduce_func = [&](auto iC) {
AccDataType reduceSize = type_convert<AccDataType>(arg.n) *
type_convert<AccDataType>(arg.h) *
type_convert<AccDataType>(arg.w);
index_t offset_C = iC; index_t offset_C = iC;
AccDataType mean = type_convert<AccDataType>(0.0f); AccDataType mean = type_convert<AccDataType>(0.0f);
AccDataType meansquare = type_convert<AccDataType>(0.0f); AccDataType variance = type_convert<AccDataType>(0.0f);
int32_t curr_count = 0;
// compute mean, meanquare, variance, invVariance // compute mean, variance using welford method
for(index_t iN = 0; iN < arg.n; iN++) for(index_t iN = 0; iN < arg.n; iN++)
{ {
index_t offset_N = iN * arg.h * arg.w * arg.c; index_t offset_N = iN * arg.h * arg.w * arg.c;
...@@ -106,40 +123,46 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -106,40 +123,46 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
auto offset = offset_N + offset_H + offset_W + offset_C; auto offset = offset_N + offset_H + offset_W + offset_C;
curr_count++;
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]); AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
mean += x; AccDataType delta = x - mean;
meansquare += x * x;
mean += delta / curr_count;
AccDataType delta2 = x - mean;
variance += delta * delta2;
}; };
} }
}; };
mean = mean / reduceSize; // actual variance
meansquare = meansquare / reduceSize; variance = variance / curr_count;
AccDataType variance = meansquare - mean * mean;
AccDataType invVariance = AccDataType invVariance =
type_convert<AccDataType>(1.0f) / type_convert<AccDataType>(1.0f) / ck::math::sqrt(arg.epsilon_ + variance);
std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);
// save the mean/invVariance if required // save the mean/invVariance if required
if(arg.resultSave) if(arg.resultSave)
{ {
arg.resultSaveMean_[iC] = mean; arg.resultSaveMean_[iC] = type_convert<MeanVarDataType>(mean);
arg.resultSaveInvVariance_[iC] = invVariance; arg.resultSaveInvVariance_[iC] = type_convert<MeanVarDataType>(invVariance);
}; };
// update the moving average if required // update the moving average if required
if(arg.resultRunning) if(arg.resultRunning)
{ {
arg.resultRunningMean_[iC] = AccDataType oneMinusAverageFactor =
arg.resultRunningMean_[iC] * type_convert<AccDataType>(1.0) - arg.averageFactor_;
type_convert<AccDataType>(1.0 - arg.exponentialAverageFactor_) + arg.resultRunningMean_[iC] = type_convert<MeanVarDataType>(
mean * arg.exponentialAverageFactor_; type_convert<AccDataType>(arg.resultRunningMean_[iC]) *
arg.resultRunningVariance_[iC] = oneMinusAverageFactor +
arg.resultRunningVariance_[iC] * mean * arg.averageFactor_);
type_convert<AccDataType>(1.0 - arg.exponentialAverageFactor_) + arg.resultRunningVariance_[iC] = type_convert<MeanVarDataType>(
variance * arg.exponentialAverageFactor_; arg.resultRunningVariance_[iC] * oneMinusAverageFactor +
variance * arg.averageFactor_);
}; };
// Normalization // Normalization
...@@ -160,7 +183,7 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -160,7 +183,7 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
AccDataType norm_x = AccDataType norm_x =
arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC]; arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC];
arg.p_y_[offset] = type_convert<InOutDataType>(norm_x); arg.p_y_[offset] = type_convert<YDataType>(norm_x);
}; };
} }
}; };
...@@ -207,34 +230,42 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -207,34 +230,42 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
MakeArgumentPointer(const std::array<index_t, 4> xyLengths, MakeArgumentPointer(const std::array<index_t, 4> xyLengths,
const std::array<index_t, 4> xStrides, const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> yStrides, const std::array<index_t, 4> yStrides,
const std::array<int, 3> reduceDims,
const std::array<index_t, 1> bnScaleBiasMeanVarLengths, const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<index_t, 1> bnScaleBiasMeanVarStrides, const std::array<index_t, 1> bnScaleStrides,
const std::array<index_t, 1> bnBiasStrides,
const std::array<index_t, 1> bnMeanVarStrides,
const void* p_x, const void* p_x,
const void* bnScale, const void* bnScale,
const void* bnBias, const void* bnBias,
void* p_y,
double exponentialAverageFactor,
void* resultRunningMean,
void* resultRunningVariance,
double epsilon, double epsilon,
const YElementwiseOp y_elementwise_op,
void* p_y,
void* resultSaveMean, void* resultSaveMean,
void* resultSaveInvVariance) override void* resultSaveInvVariance,
double averageFactor,
void* resultRunningMean,
void* resultRunningVariance) override
{ {
return std::make_unique<Argument>(xyLengths, return std::make_unique<Argument>(xyLengths,
xStrides, xStrides,
yStrides, yStrides,
reduceDims,
bnScaleBiasMeanVarLengths, bnScaleBiasMeanVarLengths,
bnScaleBiasMeanVarStrides, bnScaleStrides,
static_cast<const InOutDataType*>(p_x), bnBiasStrides,
static_cast<const AccDataType*>(bnScale), bnMeanVarStrides,
static_cast<const AccDataType*>(bnBias), static_cast<const XDataType*>(p_x),
static_cast<InOutDataType*>(p_y), static_cast<const ScaleDataType*>(bnScale),
exponentialAverageFactor, static_cast<const BiasDataType*>(bnBias),
static_cast<AccDataType*>(resultRunningMean),
static_cast<AccDataType*>(resultRunningVariance),
epsilon, epsilon,
static_cast<AccDataType*>(resultSaveMean), y_elementwise_op,
static_cast<AccDataType*>(resultSaveInvVariance)); static_cast<YDataType*>(p_y),
static_cast<MeanVarDataType*>(resultSaveMean),
static_cast<MeanVarDataType*>(resultSaveInvVariance),
averageFactor,
static_cast<MeanVarDataType*>(resultRunningMean),
static_cast<MeanVarDataType*>(resultRunningVariance));
}; };
std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
......
...@@ -14,7 +14,12 @@ namespace ck { ...@@ -14,7 +14,12 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace host { namespace host {
template <typename InOutDataType, typename AccDataType> template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType>
struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBatchNormInfer<4, 3> struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBatchNormInfer<4, 3>
{ {
struct Argument : public device::BaseArgument struct Argument : public device::BaseArgument
...@@ -23,14 +28,16 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat ...@@ -23,14 +28,16 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
const std::array<index_t, 4> xStrides, const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> yStrides, const std::array<index_t, 4> yStrides,
const std::array<index_t, 1> bnScaleBiasMeanVarLengths, const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<index_t, 1> bnScaleBiasMeanVarStrides, const std::array<index_t, 1> bnScaleStrides,
const InOutDataType* p_x, const std::array<index_t, 1> bnBiasStrides,
const AccDataType* bnScale, const std::array<index_t, 1> bnMeanVarStrides,
const AccDataType* bnBias, const XDataType* p_x,
const ScaleDataType* bnScale,
const BiasDataType* bnBias,
double epsilon, double epsilon,
const AccDataType* estimatedMean, const MeanVarDataType* estimatedMean,
const AccDataType* estimatedVariance, const MeanVarDataType* estimatedVariance,
InOutDataType* p_y) YDataType* p_y)
: p_x_(p_x), : p_x_(p_x),
bnScale_(bnScale), bnScale_(bnScale),
bnBias_(bnBias), bnBias_(bnBias),
...@@ -39,32 +46,34 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat ...@@ -39,32 +46,34 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
estimatedVariance_(estimatedVariance), estimatedVariance_(estimatedVariance),
p_y_(p_y) p_y_(p_y)
{ {
(void)xStrides; ignore = xStrides;
(void)yStrides; ignore = yStrides;
(void)bnScaleBiasMeanVarStrides; ignore = bnScaleStrides;
ignore = bnBiasStrides;
ignore = bnMeanVarStrides;
if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 || if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
bnScaleBiasMeanVarLengths[0] != xyLengths[3]) bnScaleBiasMeanVarLengths[0] != xyLengths[3])
throw std::runtime_error("Invalid tensor dimensions!"); throw std::runtime_error("Invalid tensor dimensions!");
n = xyLengths[0]; n_ = xyLengths[0];
h = xyLengths[1]; h_ = xyLengths[1];
w = xyLengths[2]; w_ = xyLengths[2];
c = xyLengths[3]; c_ = xyLengths[3];
} }
const InOutDataType* p_x_; const XDataType* p_x_;
const AccDataType* bnScale_; const ScaleDataType* bnScale_;
const AccDataType* bnBias_; const BiasDataType* bnBias_;
double epsilon_; double epsilon_;
const AccDataType* estimatedMean_; const MeanVarDataType* estimatedMean_;
const AccDataType* estimatedVariance_; const MeanVarDataType* estimatedVariance_;
InOutDataType* p_y_; YDataType* p_y_;
index_t n, h, w, c; index_t n_, h_, w_, c_;
}; };
struct Invoker : public device::BaseInvoker struct Invoker : public device::BaseInvoker
...@@ -81,15 +90,15 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat ...@@ -81,15 +90,15 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance); std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);
// Normalization // Normalization
for(index_t iN = 0; iN < arg.n; iN++) for(index_t iN = 0; iN < arg.n_; iN++)
{ {
index_t offset_N = iN * arg.h * arg.w * arg.c; index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
for(index_t iH = 0; iH < arg.h; iH++) for(index_t iH = 0; iH < arg.h_; iH++)
{ {
index_t offset_H = iH * arg.w * arg.c; index_t offset_H = iH * arg.w_ * arg.c_;
for(index_t iW = 0; iW < arg.w; iW++) for(index_t iW = 0; iW < arg.w_; iW++)
{ {
index_t offset_W = iW * arg.c; index_t offset_W = iW * arg.c_;
auto offset = offset_N + offset_H + offset_W + offset_C; auto offset = offset_N + offset_H + offset_W + offset_C;
...@@ -98,21 +107,21 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat ...@@ -98,21 +107,21 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
AccDataType norm_x = AccDataType norm_x =
arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC]; arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC];
arg.p_y_[offset] = type_convert<InOutDataType>(norm_x); arg.p_y_[offset] = type_convert<YDataType>(norm_x);
}; };
} }
}; };
}; };
std::size_t num_thread = std::thread::hardware_concurrency(); std::size_t num_thread = std::thread::hardware_concurrency();
std::size_t work_per_thread = (arg.c + num_thread - 1) / num_thread; std::size_t work_per_thread = (arg.c_ + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread); std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it) for(std::size_t it = 0; it < num_thread; ++it)
{ {
std::size_t ic_begin = it * work_per_thread; std::size_t ic_begin = it * work_per_thread;
std::size_t ic_end = std::min(static_cast<int>((it + 1) * work_per_thread), arg.c); std::size_t ic_end = std::min(static_cast<int>((it + 1) * work_per_thread), arg.c_);
auto f = [=] { auto f = [=] {
for(std::size_t ic = ic_begin; ic < ic_end; ++ic) for(std::size_t ic = ic_begin; ic < ic_end; ++ic)
...@@ -146,7 +155,9 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat ...@@ -146,7 +155,9 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
const std::array<index_t, 4> xStrides, const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> yStrides, const std::array<index_t, 4> yStrides,
const std::array<index_t, 1> bnScaleBiasMeanVarLengths, const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<index_t, 1> bnScaleBiasMeanVarStrides, const std::array<index_t, 1> bnScaleStrides,
const std::array<index_t, 1> bnBiasStrides,
const std::array<index_t, 1> bnMeanVarStrides,
const void* p_x, const void* p_x,
const void* bnScale, const void* bnScale,
const void* bnBias, const void* bnBias,
...@@ -159,14 +170,16 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat ...@@ -159,14 +170,16 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
xStrides, xStrides,
yStrides, yStrides,
bnScaleBiasMeanVarLengths, bnScaleBiasMeanVarLengths,
bnScaleBiasMeanVarStrides, bnScaleStrides,
static_cast<const InOutDataType*>(p_x), bnBiasStrides,
static_cast<const AccDataType*>(bnScale), bnMeanVarStrides,
static_cast<const AccDataType*>(bnBias), static_cast<const XDataType*>(p_x),
static_cast<const ScaleDataType*>(bnScale),
static_cast<const BiasDataType*>(bnBias),
epsilon, epsilon,
static_cast<const AccDataType*>(estimatedMean), static_cast<const MeanVarDataType*>(estimatedMean),
static_cast<const AccDataType*>(estimatedVariance), static_cast<const MeanVarDataType*>(estimatedVariance),
static_cast<InOutDataType*>(p_y)); static_cast<YDataType*>(p_y));
}; };
std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
......
...@@ -131,17 +131,22 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -131,17 +131,22 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
else if constexpr(NDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
auto f_kcyx = [&](auto g, auto k, auto c, auto y, auto x) { auto f_kcyx = [&](auto g, auto k, auto c, auto y, auto x) {
std::size_t N = arg.output_.GetLengths()[1];
std::size_t Ho = arg.output_.GetLengths()[3];
std::size_t Wo = arg.output_.GetLengths()[4];
float v_acc = 0; float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n) for(std::size_t n = 0; n < N; ++n)
{ {
for(std::size_t ho = 0; ho < arg.output_.GetLengths()[3]; ++ho) for(std::size_t ho = 0; ho < Ho; ++ho)
{ {
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) + auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t wo = 0; wo < arg.output_.GetLengths()[4]; ++wo) for(std::size_t wo = 0; wo < Wo; ++wo)
{ {
auto wi = auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) + static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
......
...@@ -44,8 +44,8 @@ struct ReferenceGemmLayernorm : public device::BaseOperator ...@@ -44,8 +44,8 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
size_t M = acc.mDesc.GetLengths()[0]; size_t M = acc.mDesc.GetLengths()[0];
size_t N = acc.mDesc.GetLengths()[1]; size_t N = acc.mDesc.GetLengths()[1];
Tensor<ComputeDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M}))); Tensor<ComputeDataType> avg_acc_sq({M});
Tensor<ComputeDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M}))); Tensor<ComputeDataType> avg_acc({M});
Tensor<ComputeDataType> acc_layernorm(acc); Tensor<ComputeDataType> acc_layernorm(acc);
// reduce N dim // reduce N dim
......
...@@ -95,6 +95,7 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -95,6 +95,7 @@ struct ReferenceLayernorm : public device::BaseOperator
auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n)); auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
auto y_val = (x_val - mean(m)) / sqrt(var(m) + arg.epsilon_); auto y_val = (x_val - mean(m)) / sqrt(var(m) + arg.epsilon_);
y_val = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n); y_val = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n);
arg.acc_elementwise_op_(y_val, y_val);
arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val); arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
} }
} }
......
...@@ -60,6 +60,12 @@ struct ReferenceSoftmax : public device::BaseOperator ...@@ -60,6 +60,12 @@ struct ReferenceSoftmax : public device::BaseOperator
{ {
scalar_lengths.push_back(arg.in_.mDesc.GetLengths()[dim]); scalar_lengths.push_back(arg.in_.mDesc.GetLengths()[dim]);
} }
// max and sum reduction with final reduced values of dim=0 is a scalar so give it
// appropriate lengths of {1}
if(arg.sm_scalar_dims_.size() == 0)
{
scalar_lengths.push_back(1);
}
Tensor<AccDataType> reduce_max(scalar_lengths); Tensor<AccDataType> reduce_max(scalar_lengths);
reduce_max.GenerateTensorValue( reduce_max.GenerateTensorValue(
...@@ -67,6 +73,9 @@ struct ReferenceSoftmax : public device::BaseOperator ...@@ -67,6 +73,9 @@ struct ReferenceSoftmax : public device::BaseOperator
Tensor<AccDataType> reduce_sum(scalar_lengths); Tensor<AccDataType> reduce_sum(scalar_lengths);
reduce_sum.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0}); reduce_sum.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
// when final reduced values is of dim=0, the index will be transformed into empty
// std::vector which is actually a valid input for Tensor::operator(std::vector) and
// internally accesses 0'th element
auto to_sm_scalar_idx = [&](auto idx) { auto to_sm_scalar_idx = [&](auto idx) {
std::vector<size_t> sm_scalar_idx; std::vector<size_t> sm_scalar_idx;
for(index_t dim : arg.sm_scalar_dims_) for(index_t dim : arg.sm_scalar_dims_)
......
...@@ -3,7 +3,10 @@ ...@@ -3,7 +3,10 @@
#pragma once #pragma once
#include <cstdlib> #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -15,6 +18,8 @@ using F64 = double; ...@@ -15,6 +18,8 @@ using F64 = double;
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using I8 = int8_t;
using I32 = int32_t;
using Empty_Tuple = ck::Tuple<>; using Empty_Tuple = ck::Tuple<>;
...@@ -23,6 +28,8 @@ using F16_F16_Tuple = ck::Tuple<F16, F16>; ...@@ -23,6 +28,8 @@ using F16_F16_Tuple = ck::Tuple<F16, F16>;
using F32_Tuple = ck::Tuple<F32>; using F32_Tuple = ck::Tuple<F32>;
using I32_Tuple = ck::Tuple<I32>;
// GEMM layout // GEMM layout
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -70,13 +77,25 @@ using NWGK = ck::tensor_layout::convolution::NWGK; ...@@ -70,13 +77,25 @@ using NWGK = ck::tensor_layout::convolution::NWGK;
using NHWGK = ck::tensor_layout::convolution::NHWGK; using NHWGK = ck::tensor_layout::convolution::NHWGK;
using NDHWGK = ck::tensor_layout::convolution::NDHWGK; using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
//
using GK = ck::tensor_layout::convolution::G_K;
using GK_TUPLE = ck::Tuple<GK>;
// pointwise functor // pointwise functor
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
template <typename DeviceOp> template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
template <typename Activation>
using Add_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>;
template <typename DeviceOp, typename Tag = void>
struct DeviceOperationInstanceFactory; struct DeviceOperationInstanceFactory;
} // namespace instance } // namespace instance
......
...@@ -28,9 +28,26 @@ void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_g ...@@ -28,9 +28,26 @@ void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_g
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); false>>>& instances);
void add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemm<Row,
Col,
Row,
Row,
F16,
F16,
F16,
F16,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
true>>>& instances);
template <typename ALayout, template <typename ALayout,
typename B0Layout, typename B0Layout,
...@@ -39,7 +56,8 @@ template <typename ALayout, ...@@ -39,7 +56,8 @@ template <typename ALayout,
typename ADataType, typename ADataType,
typename B0DataType, typename B0DataType,
typename B1DataType, typename B1DataType,
typename CDataType> typename CDataType,
bool MaskOutUpperTriangle>
struct DeviceOperationInstanceFactory< struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm<ALayout, ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm<ALayout,
B0Layout, B0Layout,
...@@ -51,9 +69,10 @@ struct DeviceOperationInstanceFactory< ...@@ -51,9 +69,10 @@ struct DeviceOperationInstanceFactory<
CDataType, CDataType,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>> MaskOutUpperTriangle>>
{ {
using DeviceOp = DeviceBatchedGemmSoftmaxGemm<ALayout, using DeviceOp = DeviceBatchedGemmSoftmaxGemm<ALayout,
B0Layout, B0Layout,
...@@ -65,9 +84,10 @@ struct DeviceOperationInstanceFactory< ...@@ -65,9 +84,10 @@ struct DeviceOperationInstanceFactory<
CDataType, CDataType,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>; MaskOutUpperTriangle>;
static auto GetInstances() static auto GetInstances()
{ {
...@@ -78,11 +98,19 @@ struct DeviceOperationInstanceFactory< ...@@ -78,11 +98,19 @@ struct DeviceOperationInstanceFactory<
{ {
if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> && if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>) is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
{
if constexpr(MaskOutUpperTriangle)
{
add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs);
}
else
{ {
add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs); op_ptrs);
} }
} }
}
return op_ptrs; return op_ptrs;
} }
}; };
......
...@@ -17,63 +17,89 @@ namespace tensor_operation { ...@@ -17,63 +17,89 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
template <ck::index_t... Is> void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
using S = ck::Sequence<Is...>; std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
using CPermuteNumDims_G_M_O = 1,
S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O 1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<Row, std::vector<
Col, std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
Row, 1,
CPermuteNumDims_G_M_O, 1,
1,
1,
F16, F16,
F16, F16,
F16, F16,
F16, F16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale, Scale,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
template <typename ALayout, template <typename ADataType,
typename B0Layout,
typename B1Layout,
typename CPermuteNumDims_G_M_Gemm1N,
typename ADataType,
typename B0DataType, typename B0DataType,
typename B1DataType, typename B1DataType,
typename CDataType> typename CDataType,
MaskingSpecialization MaskingSpec>
struct DeviceOperationInstanceFactory< struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout, ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2,
B0Layout, 1,
B1Layout, 1,
CPermuteNumDims_G_M_Gemm1N, 1,
1,
ADataType, ADataType,
B0DataType, B0DataType,
B1DataType, B1DataType,
CDataType, CDataType,
ck::Tuple<>,
ck::Tuple<>,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale, Scale,
PassThrough, PassThrough,
PassThrough>> PassThrough,
MaskingSpec>>
{ {
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<ALayout, using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
B0Layout, 1,
B1Layout, 1,
CPermuteNumDims_G_M_Gemm1N, 1,
1,
ADataType, ADataType,
B0DataType, B0DataType,
B1DataType, B1DataType,
CDataType, CDataType,
ck::Tuple<>,
ck::Tuple<>,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale, Scale,
PassThrough, PassThrough,
PassThrough>; PassThrough,
MaskingSpec>;
static auto GetInstances() static auto GetInstances()
{ {
...@@ -82,11 +108,14 @@ struct DeviceOperationInstanceFactory< ...@@ -82,11 +108,14 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> && if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>) is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
{ {
if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> && if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
is_same_v<B1Layout, Row> && {
is_same_v<CPermuteNumDims_G_M_Gemm1N, CPermuteNumDims_G_M_O>) add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
{ {
add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs); op_ptrs);
} }
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv1d backward weight
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
NWC,
KXC,
NWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
NWC,
KXC,
NWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
NWC,
KXC,
NWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv2d backward weight
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
NHWC,
KYXC,
NHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
NHWC,
KYXC,
NHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
NHWC,
KYXC,
NHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv3d backward weight
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
NDHWC,
KZYXC,
NDHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
NDHWC,
KZYXC,
NDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
NDHWC,
KZYXC,
NDHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBwdWeight<
NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceConvBwdWeight<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> && is_same_v<WeiLayout, KXC> &&
is_same_v<OutLayout, NWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// FP16
void add_device_elementwise_normalization_rank_2_1_f16_instances(
std::vector<std::unique_ptr<DeviceElementwiseNormalization<ck::Tuple<F16, F16>,
F16,
F16,
F32,
F16,
element_wise::Add,
PassThrough,
2,
1>>>&);
template <typename InDataTypeTuple,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElementwiseNormalization<
InDataTypeTuple,
GammaDataType,
BetaDataType,
F32,
YDataType,
ck::tensor_operation::element_wise::Add,
ck::tensor_operation::element_wise::PassThrough,
Rank,
NumReduceDim>>
{
using DeviceOp = DeviceElementwiseNormalization<InDataTypeTuple,
GammaDataType,
BetaDataType,
F32,
YDataType,
ck::tensor_operation::element_wise::Add,
ck::tensor_operation::element_wise::PassThrough,
Rank,
NumReduceDim>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<GammaDataType, F16> && is_same_v<BetaDataType, F16> &&
is_same_v<YDataType, F16>)
{
if constexpr(Rank == 2 && NumReduceDim == 1)
{
add_device_elementwise_normalization_rank_2_1_f16_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv2d backward data
void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename OutLayout,
typename WeiLayout,
typename InLayout,
typename OutDataType,
typename WeiDataType,
typename InDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<
NumDimSpatial,
OutLayout,
WeiLayout,
Empty_Tuple,
InLayout,
OutDataType,
WeiDataType,
Empty_Tuple,
InDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp =
DeviceGroupedConvBwdDataMultipleD<NumDimSpatial,
OutLayout,
WeiLayout,
Empty_Tuple,
InLayout,
OutDataType,
WeiDataType,
Empty_Tuple,
InDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv1d backward weight
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv2d backward weight
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv3d backward weight
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvBwdWeight<
NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceGroupedConvBwdWeight<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, GNWC> &&
is_same_v<WeiLayout, GKXC> && is_same_v<OutLayout, GNWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_conv2d_bias_perlayer_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_TUPLE,
GNHWK,
int8_t,
int8_t,
I32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Activation_Mul_Clamp<PassThrough>>>>&
instances);
void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_TUPLE,
GNHWK,
int8_t,
int8_t,
I32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Activation_Mul_Clamp<Relu>>>>&
instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename DsDataType,
typename OutDataType,
typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Activation_Mul_Clamp<Activation>>>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Activation_Mul_Clamp<Activation>>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_TUPLE> &&
is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_Tuple> && is_same_v<OutDataType, int8_t>)
{
if constexpr(is_same_v<Activation, PassThrough>)
add_device_conv2d_bias_perlayer_quantization_int8_instances(op_ptrs);
else if constexpr(is_same_v<Activation, Relu>)
add_device_conv2d_bias_relu_perlayer_quantization_int8_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
#pragma once #pragma once
#include <cstdlib> #include <vector>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwd<2,
GNHWC,
GKYXC,
GNHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwd<2,
GNHWC,
GKYXC,
GNHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwd<2,
GNHWC,
GKYXC,
GNHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwd<
NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceGroupedConvFwd<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment