Commit c5138aa1 authored by Artur Wojcik's avatar Artur Wojcik
Browse files

Merge branch 'develop' into uif2-initial

parents 7830272f 82f3a835
......@@ -9,15 +9,9 @@ namespace ck {
using bhalf_t = ushort;
using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4);
#endif
#if defined CK_ENABLE_FP8
using f8_t = _BitInt(8);
#endif
#if defined CK_ENABLE_BF8
using bf8_t = unsigned _BitInt(8);
#endif
using int4_t = _BitInt(4);
using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8);
// vector_type
template <typename T, index_t N>
......@@ -148,23 +142,19 @@ struct scalar_type<int4_t>
};
#endif
#if defined CK_ENABLE_FP8
template <>
struct scalar_type<f8_t>
{
using type = f8_t;
static constexpr index_t vector_size = 1;
};
#endif
#if defined CK_ENABLE_BF8
template <>
struct scalar_type<bf8_t>
{
using type = bf8_t;
static constexpr index_t vector_size = 1;
};
#endif
template <typename T>
struct vector_type<T, 1>
......@@ -968,24 +958,20 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
// f8
#if defined CK_ENABLE_FP8
using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type;
#endif
// bf8
#if defined CK_ENABLE_BF8
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
#endif
template <typename T>
struct NumericLimits
......@@ -1033,7 +1019,6 @@ struct NumericLimits<int4_t>
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8
template <>
struct NumericLimits<f8_t>
{
......@@ -1056,9 +1041,7 @@ struct NumericLimits<f8_t>
__host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
};
#endif
#if defined CK_ENABLE_BF8
template <>
struct NumericLimits<bf8_t>
{
......@@ -1081,7 +1064,6 @@ struct NumericLimits<bf8_t>
__host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
};
#endif
template <typename T>
struct NumericUtils
......@@ -1120,22 +1102,18 @@ struct NumericUtils<half_t>
using bitwise_type = uint16_t;
};
#if defined CK_ENABLE_FP8
template <>
struct NumericUtils<f8_t>
{
static constexpr int exp = 4;
static constexpr int mant = 3;
};
#endif
#if defined CK_ENABLE_BF8
template <>
struct NumericUtils<bf8_t>
{
static constexpr int exp = 5;
static constexpr int mant = 2;
};
#endif
//
} // namespace ck
......@@ -6,8 +6,6 @@
#include "ck/utility/data_type.hpp"
// these conversions are disabled if native conversions available
#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace ck {
// fp8 rounding modes
......@@ -244,5 +242,3 @@ __host__ __device__ Y cast_from_f8(X x)
}
} // namespace ck::utils
#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
......@@ -150,28 +150,6 @@ __host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T&
return min(max(x, lowerbound), upperbound);
}
// disallow implicit type casting
template <typename T>
__device__ T exp(T x);
// TODO: add f16 support using v_exp_f16
template <>
__device__ float exp<float>(float x)
{
return __expf(x);
}
template <>
__device__ double exp<double>(double x)
{
return exp(x);
}
static inline __host__ float exp(float x) { return std::expf(x); }
static inline __host__ double exp(double x) { return std::exp(x); }
// greatest common divisor, aka highest common factor
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
{
......
......@@ -9,6 +9,7 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/type_convert.hpp"
namespace ck {
namespace math {
......@@ -92,14 +93,96 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static inline __host__ double sqrt(double x) { return std::sqrt(x); };
static inline __host__ half_t tanh(half_t x)
template <typename T>
inline __host__ T tanh(T x)
{
return static_cast<half_t>(std::tanh(static_cast<float>(x)));
return ck::type_convert<T>(std::tanhf(ck::type_convert<float>(x)));
};
static inline __host__ float tanh(float x) { return std::tanh(x); };
template <>
inline __host__ float tanh<float>(float x)
{
return std::tanhf(x);
};
template <>
inline __host__ double tanh<double>(double x)
{
return std::tanh(x);
};
template <typename T>
inline __host__ T exp(T x)
{
return ck::type_convert<T>(std::expf(ck::type_convert<float>(x)));
}
template <>
inline __host__ float exp<float>(float x)
{
return std::expf(x);
}
static inline __host__ double tanh(double x) { return std::tanh(x); };
template <>
inline __host__ double exp<double>(double x)
{
return std::exp(x);
}
template <typename T>
inline __host__ T log(T x)
{
return ck::type_convert<T>(std::logf(ck::type_convert<float>(x)));
}
template <>
inline __host__ float log<float>(float x)
{
return std::logf(x);
}
template <>
inline __host__ double log<double>(double x)
{
return std::log(x);
}
template <typename T>
inline __host__ T pow(T x, T gamma)
{
return ck::type_convert<T>(
std::powf(ck::type_convert<float>(x), ck::type_convert<float>(gamma)));
}
template <>
inline __host__ float pow<float>(float x, float gamma)
{
return std::powf(x, gamma);
}
template <>
inline __host__ double pow<double>(double x, double gamma)
{
return std::pow(x, gamma);
}
template <typename T>
inline __host__ T expm1(T x)
{
return ck::type_convert<T>(std::expm1f(ck::type_convert<float>(x)));
}
template <>
inline __host__ float expm1<float>(float x)
{
return std::expm1f(x);
}
template <>
inline __host__ double expm1<double>(double x)
{
return std::expm1(x);
}
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
......@@ -181,14 +264,107 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x);
static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
static inline __device__ half_t tanh(half_t x)
template <typename T>
inline __device__ T tanh(T x)
{
return ck::type_convert<T>(::tanhf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float tanh<float>(float x)
{
return static_cast<half_t>(::tanhf(static_cast<float>(x)));
return ::tanhf(x);
};
static inline __device__ float tanh(float x) { return ::tanhf(x); };
template <>
inline __device__ double tanh<double>(double x)
{
return ::tanh(x);
};
template <typename T>
inline __device__ T exp(T x)
{
return ck::type_convert<T>(__expf(ck::type_convert<float>(x)));
};
template <>
inline __device__ half_t exp<half_t>(half_t x)
{
return hexp(x);
};
template <>
inline __device__ float exp<float>(float x)
{
return __expf(x);
};
static inline __device__ double tanh(double x) { return ::tanh(x); };
template <>
inline __device__ double exp<double>(double x)
{
return exp(x);
};
template <typename T>
inline __device__ T log(T x)
{
return ck::type_convert<T>(__logf(ck::type_convert<float>(x)));
};
template <>
inline __device__ half_t log<half_t>(half_t x)
{
return hlog(x);
};
template <>
inline __device__ float log<float>(float x)
{
return __logf(x);
};
template <>
inline __device__ double log<double>(double x)
{
return log(x);
};
template <typename T>
inline __device__ T pow(T x, T gamma)
{
return ck::type_convert<T>(powf(ck::type_convert<float>(x), ck::type_convert<float>(gamma)));
};
template <>
inline __device__ float pow<float>(float x, float gamma)
{
return powf(x, gamma);
};
template <>
inline __device__ double pow<double>(double x, double gamma)
{
return pow(x, gamma);
};
template <typename T>
inline __device__ T expm1(T x)
{
return ck::type_convert<T>(expm1f(ck::type_convert<float>(x)));
};
template <>
inline __device__ float expm1<float>(float x)
{
return expm1f(x);
};
template <>
inline __device__ double expm1<double>(double x)
{
return expm1(x);
};
} // namespace math
} // namespace ck
......@@ -5,6 +5,7 @@
#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
#include "common_header.hpp"
#include "ck/utility/math_v2.hpp"
namespace ck {
......
......@@ -95,7 +95,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32);
}
#if defined CK_ENABLE_FP8
// convert fp32 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
......@@ -173,9 +172,7 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
return type_convert<half_t>(type_convert<float>(x));
#endif
}
#endif
#if defined CK_ENABLE_BF8
// convert fp32 to bf8
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
......@@ -253,7 +250,6 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
return type_convert<half_t>(type_convert<float>(x));
#endif
}
#endif
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
......@@ -316,7 +312,6 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);
#if defined CK_ENABLE_FP8
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
......@@ -365,9 +360,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
return f8_convert_sr<f8_t>(type_convert<float>(x));
#endif
}
#endif
#if defined CK_ENABLE_BF8
// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
......@@ -417,6 +410,5 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
return f8_convert_sr<bf8_t>(type_convert<float>(x));
#endif
}
#endif
} // namespace ck
......@@ -128,11 +128,9 @@ struct ReferenceConvFwd : public device::BaseOperator
}
}
float v_out;
arg.out_element_op_(v_out, v_acc);
arg.output_(g, n, k, wo) = ck::type_convert<OutDataType>(v_out);
OutDataType v_out;
arg.out_element_op_(v_out, ck::type_convert<OutDataType>(v_acc));
arg.output_(g, n, k, wo) = v_out;
};
make_ParallelTensorFunctor(func,
......@@ -184,11 +182,9 @@ struct ReferenceConvFwd : public device::BaseOperator
}
}
float v_out;
arg.out_element_op_(v_out, v_acc);
arg.output_(g, n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
OutDataType v_out;
arg.out_element_op_(v_out, ck::type_convert<OutDataType>(v_acc));
arg.output_(g, n, k, ho, wo) = v_out;
};
make_ParallelTensorFunctor(func,
......@@ -253,11 +249,9 @@ struct ReferenceConvFwd : public device::BaseOperator
}
}
float v_out;
arg.out_element_op_(v_out, v_acc);
arg.output_(g, n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
OutDataType v_out;
arg.out_element_op_(v_out, ck::type_convert<OutDataType>(v_acc));
arg.output_(g, n, k, d_o, ho, wo) = v_out;
};
make_ParallelTensorFunctor(func,
......
......@@ -20,8 +20,9 @@ template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename AccDataType,
typename AccElementwiseOperation>
typename SaveMeanInvStdDataType,
typename ComputeDataType,
typename YElementwiseOperation>
struct ReferenceGroupnorm : public device::BaseOperator
{
// x = [N, H, W, G, C]
......@@ -35,14 +36,18 @@ struct ReferenceGroupnorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma,
const Tensor<BetaDataType>& beta,
Tensor<YDataType>& y,
AccElementwiseOperation acc_elementwise_op,
Tensor<SaveMeanInvStdDataType>& save_mean,
Tensor<SaveMeanInvStdDataType>& save_inv_std,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths,
AccDataType epsilon)
ComputeDataType epsilon)
: x_(x),
gamma_(gamma),
beta_(beta),
y_(y),
acc_elementwise_op_(acc_elementwise_op),
save_mean_(save_mean),
save_inv_std_(save_inv_std),
y_elementwise_op_(y_elementwise_op),
lengths_(lengths),
epsilon_(epsilon)
{
......@@ -52,9 +57,11 @@ struct ReferenceGroupnorm : public device::BaseOperator
const Tensor<XDataType> gamma_;
const Tensor<XDataType> beta_;
Tensor<YDataType>& y_;
AccElementwiseOperation acc_elementwise_op_;
Tensor<SaveMeanInvStdDataType>& save_mean_;
Tensor<SaveMeanInvStdDataType>& save_inv_std_;
YElementwiseOperation y_elementwise_op_;
std::vector<index_t> lengths_;
AccDataType epsilon_;
ComputeDataType epsilon_;
};
// Invoker
......@@ -68,8 +75,8 @@ struct ReferenceGroupnorm : public device::BaseOperator
int G = arg.lengths_[3];
int C = arg.lengths_[4];
Tensor<AccDataType> mean({N, G});
Tensor<AccDataType> var({N, G});
Tensor<ComputeDataType> mean({N, G});
Tensor<ComputeDataType> var({N, G});
// Compute mean & var in [H, W, C] by Welford Algorithm
// TODO - parallel for each HWC
......@@ -78,9 +85,9 @@ struct ReferenceGroupnorm : public device::BaseOperator
{
for(int g = 0; g < G; ++g)
{
AccDataType mean_val = type_convert<AccDataType>(0.0f);
AccDataType var_val = type_convert<AccDataType>(0.0f);
int32_t curr_count = 0;
ComputeDataType mean_val = type_convert<ComputeDataType>(0.0f);
ComputeDataType var_val = type_convert<ComputeDataType>(0.0f);
int32_t curr_count = 0;
for(int h = 0; h < H; ++h)
{
......@@ -89,10 +96,11 @@ struct ReferenceGroupnorm : public device::BaseOperator
for(int c = 0; c < C; ++c)
{
curr_count++;
AccDataType x = type_convert<AccDataType>(arg.x_(n, h, w, g, c));
AccDataType delta = x - mean_val;
ComputeDataType x =
type_convert<ComputeDataType>(arg.x_(n, h, w, g, c));
ComputeDataType delta = x - mean_val;
mean_val += delta / curr_count;
AccDataType delta2 = x - mean_val;
ComputeDataType delta2 = x - mean_val;
var_val += delta * delta2;
}
}
......@@ -100,6 +108,12 @@ struct ReferenceGroupnorm : public device::BaseOperator
mean(n, g) = mean_val;
var(n, g) = var_val / curr_count;
arg.save_mean_(n, g) = ck::type_convert<SaveMeanInvStdDataType>(mean(n, g));
ComputeDataType divisor =
static_cast<ComputeDataType>(1) / ck::math::sqrt(var(n, g) + arg.epsilon_);
arg.save_inv_std_(n, g) = ck::type_convert<SaveMeanInvStdDataType>(divisor);
}
}
......@@ -114,15 +128,19 @@ struct ReferenceGroupnorm : public device::BaseOperator
{
for(int c = 0; c < C; ++c)
{
AccDataType x = type_convert<AccDataType>(arg.x_(n, h, w, g, c));
AccDataType gamma = type_convert<AccDataType>(arg.gamma_(g, c));
AccDataType beta = type_convert<AccDataType>(arg.beta_(g, c));
AccDataType mean_val = type_convert<AccDataType>(mean(n, g));
AccDataType var_val = type_convert<AccDataType>(var(n, g));
AccDataType y = gamma * (x - mean_val) /
ck::math::sqrt(arg.epsilon_ + var_val) +
beta;
arg.acc_elementwise_op_(y, y);
ComputeDataType x =
type_convert<ComputeDataType>(arg.x_(n, h, w, g, c));
ComputeDataType gamma =
type_convert<ComputeDataType>(arg.gamma_(g, c));
ComputeDataType beta =
type_convert<ComputeDataType>(arg.beta_(g, c));
ComputeDataType mean_val =
type_convert<ComputeDataType>(mean(n, g));
ComputeDataType var_val = type_convert<ComputeDataType>(var(n, g));
ComputeDataType y = gamma * (x - mean_val) /
ck::math::sqrt(arg.epsilon_ + var_val) +
beta;
arg.y_elementwise_op_(y, y);
arg.y_(n, h, w, g, c) = type_convert<YDataType>(y);
}
}
......@@ -159,11 +177,14 @@ struct ReferenceGroupnorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma,
const Tensor<BetaDataType>& beta,
Tensor<YDataType>& y,
AccElementwiseOperation acc_elementwise_op,
Tensor<SaveMeanInvStdDataType>& save_mean,
Tensor<SaveMeanInvStdDataType>& save_inv_std,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths,
AccDataType epsilon)
ComputeDataType epsilon)
{
return Argument{x, gamma, beta, y, acc_elementwise_op, lengths, epsilon};
return Argument{
x, gamma, beta, y, save_mean, save_inv_std, y_elementwise_op, lengths, epsilon};
}
static auto MakeInvoker() { return Invoker{}; }
......
......@@ -20,8 +20,9 @@ template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename AccDataType,
typename AccElementwiseOperation,
typename SaveMeanInvStdDataType,
typename ComputeDataType,
typename YElementwiseOperation,
index_t Rank,
index_t NumReduceDim>
struct ReferenceLayernorm : public device::BaseOperator
......@@ -36,15 +37,19 @@ struct ReferenceLayernorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma_n,
const Tensor<BetaDataType>& beta_n,
Tensor<YDataType>& y_m_n,
AccElementwiseOperation acc_elementwise_op,
Tensor<SaveMeanInvStdDataType>& save_mean_m,
Tensor<SaveMeanInvStdDataType>& save_inv_std_m,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths,
const std::vector<index_t> reduceDims,
AccDataType epsilon)
ComputeDataType epsilon)
: x_m_n_(x_m_n),
gamma_n_(gamma_n),
beta_n_(beta_n),
y_m_n_(y_m_n),
acc_elementwise_op_(acc_elementwise_op),
save_mean_m_(save_mean_m),
save_inv_std_m_(save_inv_std_m),
y_elementwise_op_(y_elementwise_op),
lengths_(lengths),
reduceDims_(reduceDims),
epsilon_(epsilon)
......@@ -55,10 +60,12 @@ struct ReferenceLayernorm : public device::BaseOperator
const Tensor<XDataType> gamma_n_;
const Tensor<XDataType> beta_n_;
Tensor<YDataType>& y_m_n_;
AccElementwiseOperation acc_elementwise_op_;
Tensor<SaveMeanInvStdDataType>& save_mean_m_;
Tensor<SaveMeanInvStdDataType>& save_inv_std_m_;
YElementwiseOperation y_elementwise_op_;
std::vector<index_t> lengths_;
std::vector<index_t> reduceDims_;
AccDataType epsilon_;
ComputeDataType epsilon_;
};
// Invoker
......@@ -69,8 +76,8 @@ struct ReferenceLayernorm : public device::BaseOperator
int M = arg.lengths_[0];
int N = arg.lengths_[1];
Tensor<AccDataType> mean({M});
Tensor<AccDataType> var({M});
Tensor<ComputeDataType> mean({M});
Tensor<ComputeDataType> var({M});
for(int m = 0; m < M; ++m)
{
......@@ -79,7 +86,7 @@ struct ReferenceLayernorm : public device::BaseOperator
for(int n = 0; n < N; ++n)
{
auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
auto x_val = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
mean(m) += x_val;
var(m) += x_val * x_val;
}
......@@ -90,17 +97,21 @@ struct ReferenceLayernorm : public device::BaseOperator
for(int m = 0; m < M; ++m)
{
AccDataType divisor =
static_cast<AccDataType>(1) / ck::math::sqrt(var(m) + arg.epsilon_);
ComputeDataType divisor =
static_cast<ComputeDataType>(1) / ck::math::sqrt(var(m) + arg.epsilon_);
for(int n = 0; n < N; ++n)
{
auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
auto y_val = (x_val - mean(m)) * divisor;
y_val = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n);
arg.acc_elementwise_op_(y_val, y_val);
auto x_val = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
auto gamma_val = ck::type_convert<ComputeDataType>(arg.gamma_n_(n));
auto beta_val = ck::type_convert<ComputeDataType>(arg.beta_n_(n));
auto y_val = (x_val - mean(m)) * divisor;
y_val = (y_val * gamma_val) + beta_val;
arg.y_elementwise_op_(y_val, y_val);
arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
}
arg.save_mean_m_(m) = ck::type_convert<SaveMeanInvStdDataType>(mean(m));
arg.save_inv_std_m_(m) = ck::type_convert<SaveMeanInvStdDataType>(divisor);
}
return 0;
......@@ -140,13 +151,23 @@ struct ReferenceLayernorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma_n,
const Tensor<BetaDataType>& beta_n,
Tensor<YDataType>& y_m_n,
AccElementwiseOperation acc_elementwise_op,
Tensor<SaveMeanInvStdDataType>& save_mean_m,
Tensor<SaveMeanInvStdDataType>& save_inv_std_m,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths,
const std::vector<index_t> reduceDims,
AccDataType epsilon)
ComputeDataType epsilon)
{
return Argument{
x_m_n, gamma_n, beta_n, y_m_n, acc_elementwise_op, lengths, reduceDims, epsilon};
return Argument{x_m_n,
gamma_n,
beta_n,
y_m_n,
save_mean_m,
save_inv_std_m,
y_elementwise_op,
lengths,
reduceDims,
epsilon};
}
static auto MakeInvoker() { return Invoker{}; }
......
......@@ -20,12 +20,8 @@ using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using I32 = int32_t;
#if defined CK_ENABLE_FP8
using F8 = ck::f8_t;
#endif
#if defined CK_ENABLE_BF8
using BF8 = ck::bf8_t;
#endif
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
using Empty_Tuple = ck::Tuple<>;
......
......@@ -240,11 +240,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> && is_same_v<WeiLayout, KXC> &&
is_same_v<OutLayout, NWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
......@@ -267,17 +269,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
}
#endif
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
}
#endif
#if defined(DL_KERNELS) && defined(CK_ENABLE_FP32)
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
......@@ -306,14 +314,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
}
#endif
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
......
......@@ -98,30 +98,31 @@ struct DeviceOperationInstanceFactory<
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
}
......
......@@ -155,7 +155,7 @@ struct DeviceOperationInstanceFactory<
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<CDataType, float>)
is_same_v<CDataType, float> && is_same_v<ComputeType, float>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
......@@ -180,8 +180,8 @@ struct DeviceOperationInstanceFactory<
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t> && is_same_v<ComputeType, half_t>)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t> && is_same_v<ComputeType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
......@@ -206,8 +206,8 @@ struct DeviceOperationInstanceFactory<
}
#endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t> && is_same_v<ComputeType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
......@@ -230,8 +230,8 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, half_t>)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, half_t> && is_same_v<ComputeType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
......
......@@ -627,8 +627,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
......@@ -637,9 +637,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
......@@ -650,8 +649,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
}
else if constexpr(is_same_v<InLayout, NWGC> && is_same_v<WeiLayout, GKXC> &&
is_same_v<OutLayout, NWGK>)
if constexpr(is_same_v<InLayout, NWGC> && is_same_v<WeiLayout, GKXC> &&
is_same_v<OutLayout, NWGK>)
{
#ifdef DL_KERNELS
#ifdef CK_ENABLE_FP32
......@@ -662,16 +661,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances(
op_ptrs);
......@@ -680,7 +678,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
}
}
else if constexpr(NumDimSpatial == 2)
if constexpr(NumDimSpatial == 2)
{
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, GNHWK>)
......@@ -698,8 +696,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances(
......@@ -710,9 +708,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
......@@ -723,8 +720,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
}
else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
......@@ -739,8 +736,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances(
......@@ -751,9 +748,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
......@@ -765,7 +761,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
}
}
else if constexpr(NumDimSpatial == 3)
if constexpr(NumDimSpatial == 3)
{
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, GNDHWK>)
......@@ -783,8 +779,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances(
......@@ -799,9 +795,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
......@@ -822,8 +817,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
}
else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
......@@ -838,10 +833,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> &&
is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
......@@ -856,9 +850,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
......@@ -879,9 +872,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> &&
is_same_v<ComputeTypeA, bf8_t> && is_same_v<ComputeTypeB, f8_t>)
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, bf8_t> &&
is_same_v<ComputeTypeB, f8_t>)
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
op_ptrs);
......
......@@ -19,13 +19,13 @@ namespace instance {
#ifdef CK_ENABLE_FP16
// FP16
void add_device_normalization_rank_2_1_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 2, 1>>>&);
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, PassThrough, 2, 1>>>&);
void add_device_normalization_rank_4_3_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 4, 3>>>&);
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, PassThrough, 4, 3>>>&);
void add_device_normalization_rank_5_3_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 5, 3>>>&);
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, PassThrough, 5, 3>>>&);
#endif
#ifdef CK_ENABLE_FP32
// FP32
......@@ -42,14 +42,15 @@ template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename SaveMeanInvStdDataType,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormalization<
XDataType,
GammaDataType,
BetaDataType,
F32,
YDataType,
SaveMeanInvStdDataType,
ck::tensor_operation::element_wise::PassThrough,
Rank,
NumReduceDim>>
......@@ -57,8 +58,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
using DeviceOp = DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
F32,
YDataType,
SaveMeanInvStdDataType,
ck::tensor_operation::element_wise::PassThrough,
Rank,
NumReduceDim>;
......@@ -68,7 +69,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16>)
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<SaveMeanInvStdDataType, F32>)
{
if constexpr(Rank == 2 && NumReduceDim == 1)
{
......@@ -86,7 +88,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<SaveMeanInvStdDataType, F32>)
{
if constexpr(Rank == 2 && NumReduceDim == 1)
{
......
......@@ -19,7 +19,7 @@ namespace instance {
// FP16
void add_device_normalization_rank_5_3_swish_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Swish, 5, 3>>>&);
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, Swish, 5, 3>>>&);
// FP32
void add_device_normalization_rank_5_3_swish_f32_instances(
......@@ -27,20 +27,21 @@ void add_device_normalization_rank_5_3_swish_f32_instances(
// [x, gamma, beta, y] = [f16, f32, f32, f16]
void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Swish, 5, 3>>>&);
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F16, F32, Swish, 5, 3>>>&);
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename SaveMeanInvStdDataType,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
F32,
YDataType,
SaveMeanInvStdDataType,
ck::tensor_operation::element_wise::Swish,
Rank,
NumReduceDim>>
......@@ -48,8 +49,8 @@ struct DeviceOperationInstanceFactory<
using DeviceOp = DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
F32,
YDataType,
SaveMeanInvStdDataType,
ck::tensor_operation::element_wise::Swish,
Rank,
NumReduceDim>;
......@@ -59,7 +60,8 @@ struct DeviceOperationInstanceFactory<
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16>)
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<SaveMeanInvStdDataType, F32>)
{
if constexpr(Rank == 5 && NumReduceDim == 3)
{
......@@ -67,7 +69,8 @@ struct DeviceOperationInstanceFactory<
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<SaveMeanInvStdDataType, F32>)
{
if constexpr(Rank == 5 && NumReduceDim == 3)
{
......@@ -75,7 +78,8 @@ struct DeviceOperationInstanceFactory<
}
}
else if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F32> &&
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F16>)
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F16> &&
is_same_v<SaveMeanInvStdDataType, F32>)
{
if constexpr(Rank == 5 && NumReduceDim == 3)
{
......
......@@ -230,7 +230,6 @@ check_err(const Range& out,
return res;
}
#if defined CK_ENABLE_FP8
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, f8_t>),
......@@ -275,9 +274,7 @@ check_err(const Range& out,
}
return res;
}
#endif
#if defined CK_ENABLE_BF8
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
......@@ -322,7 +319,6 @@ check_err(const Range& out,
}
return res;
}
#endif
} // namespace utils
} // namespace ck
......@@ -5,44 +5,44 @@ function(add_instance_library INSTANCE_NAME)
message("adding instance ${INSTANCE_NAME}")
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS ARGN)
set(test 0)
foreach(type IN LISTS DTYPES)
foreach(source IN LISTS ARGN)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
set(type1 "_i8")
endif()
#make an exception for reduction kernels
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance" OR ${source} MATCHES "device_image_to_column")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
endforeach()
if(test EQUAL 1)
message("removing instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
endif()
endforeach()
endif()
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
message("removing dl instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
......@@ -52,8 +52,10 @@ function(add_instance_library INSTANCE_NAME)
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(${INSTANCE_NAME})
set(result 0)
message("add_instance_library ${INSTANCE_NAME}")
else()
message("skip_instance_libary ${INSTANCE_NAME}")
endif()
#message("add_instance_library returns ${result}")
set(result ${result} PARENT_SCOPE)
endfunction(add_instance_library INSTANCE_NAME)
......@@ -61,65 +63,70 @@ endfunction(add_instance_library INSTANCE_NAME)
file(GLOB dir_list LIST_DIRECTORIES true *)
set(CK_DEVICE_INSTANCES)
FOREACH(subdir_path ${dir_list})
set(target_dir)
IF(IS_DIRECTORY "${subdir_path}")
set(cmake_instance)
file(READ "${subdir_path}/CMakeLists.txt" cmake_instance)
set(add_inst 0)
if(("${cmake_instance}" MATCHES "_fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
set(target_dir)
IF(IS_DIRECTORY "${subdir_path}")
set(cmake_instance)
file(READ "${subdir_path}/CMakeLists.txt" cmake_instance)
set(add_inst 0)
if(("${cmake_instance}" MATCHES "_fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
message("fp8 instance found!")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
endif()
if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
message("fp16 instance found!")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
endif()
if(("${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
message("fp32 instance found!")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
endif()
if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
message("fp64 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "_bf16" AND DTYPES MATCHES "bf16")
endif()
if("${cmake_instance}" MATCHES "_bf16" AND DTYPES MATCHES "bf16")
message("bf16 instance found!")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
endif()
if(("${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
message("int8 instance found!")
set(add_inst 1)
endif()
if(NOT "${cmake_instance}" MATCHES "_fp8" OR
NOT "${cmake_instance}" MATCHES "_f8" OR
NOT "${cmake_instance}" MATCHES "_fp16" OR
NOT "${cmake_instance}" MATCHES "_f16" OR
NOT "${cmake_instance}" MATCHES "_fp32" OR
NOT "${cmake_instance}" MATCHES "_f32" OR
NOT "${cmake_instance}" MATCHES "_fp64" OR
NOT "${cmake_instance}" MATCHES "_f64" OR
NOT "${cmake_instance}" MATCHES "_bf16" OR
NOT "${cmake_instance}" MATCHES "_int8" OR
NOT "${cmake_instance}" MATCHES "_i8" OR
NOT "${cmake_instance}" MATCHES "_int4" OR
NOT DEFINED DTYPES)
message("instance should be built for all types!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "quantization" AND DEFINED DTYPES AND NOT DTYPES MATCHES "int8")
message("quantization instances will not be built!")
set(add_inst 0)
endif()
if("${cmake_instance}" MATCHES "ONLY DL_KERNELS" AND NOT DEFINED DL_KERNELS)
message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
endif()
if(NOT ("${cmake_instance}" MATCHES "_fp8" OR
"${cmake_instance}" MATCHES "_f8" OR
"${cmake_instance}" MATCHES "_fp16" OR
"${cmake_instance}" MATCHES "_f16" OR
"${cmake_instance}" MATCHES "_fp32" OR
"${cmake_instance}" MATCHES "_f32" OR
"${cmake_instance}" MATCHES "_fp64" OR
"${cmake_instance}" MATCHES "_f64" OR
"${cmake_instance}" MATCHES "_bf16" OR
"${cmake_instance}" MATCHES "_int8" OR
"${cmake_instance}" MATCHES "_i8" OR
"${cmake_instance}" MATCHES "_int4"))
message("instance should be built for all types!")
set(add_inst 1)
endif()
if(NOT DEFINED DTYPES)
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "quantization") AND (DEFINED DTYPES) AND (NOT DTYPES MATCHES "int8"))
message("quantization instances will not be built!")
set(add_inst 0)
endif()
if(add_inst EQUAL 1)
get_filename_component(target_dir ${subdir_path} NAME)
add_subdirectory(${target_dir})
list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
endif()
ENDIF()
endif()
if(("${cmake_instance}" MATCHES "ONLY DL_KERNELS") AND (NOT DEFINED DL_KERNELS))
message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
set(add_inst 0)
endif()
if((add_inst EQUAL 1))
get_filename_component(target_dir ${subdir_path} NAME)
add_subdirectory(${target_dir})
list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
message("add_instance_directory ${subdir_path}")
else()
message("skip_instance_directory ${subdir_path}")
endif()
ENDIF()
ENDFOREACH()
add_library(device_operations STATIC ${CK_DEVICE_INSTANCES})
......@@ -162,11 +169,11 @@ target_compile_options(device_operations PRIVATE
# install(TARGETS device_operations LIBRARY DESTINATION lib)
rocm_install(TARGETS device_operations
EXPORT device_operationsTargets)
EXPORT device_operationsTargets)
rocm_install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck)
rocm_install(EXPORT device_operationsTargets
FILE composable_kerneldevice_operationsTargets.cmake
NAMESPACE composable_kernel::
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
FILE composable_kerneldevice_operationsTargets.cmake
NAMESPACE composable_kernel::
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
)
......@@ -96,13 +96,9 @@ list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instan
add_instance_library(device_gemm_instance ${GEMM_INSTANCES})
set(ENABLE_PIPELINE_V2_OPT OFF)
set(ENABLE_PIPELINE_V2_OPT)
if (ENABLE_PIPELINE_V2_OPT)
set(MAX_ILP_OPTS
-mllvm
-amdgpu-enable-max-ilp-scheduling-strategy
)
set(WAVES_PER_EU_DEFS
CK_USE_WAVES_PER_EU=1
CK_MIN_WAVES_PER_EU=1
......@@ -118,7 +114,7 @@ if (ENABLE_PIPELINE_V2_OPT)
COMPILE_DEFINITIONS "${WAVES_PER_EU_DEFS};${IGLP_OPT_DEFS}")
# layout=NN
set_source_files_properties(device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp PROPERTIES
COMPILE_OPTIONS "${MAX_ILP_OPTS}"
COMPILE_OPTIONS ";;"
COMPILE_DEFINITIONS "${WAVES_PER_EU_DEFS};${IGLP_OPT_DEFS}")
# layout=TT
set_source_files_properties(device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp PROPERTIES
......@@ -126,7 +122,7 @@ if (ENABLE_PIPELINE_V2_OPT)
COMPILE_DEFINITIONS "${WAVES_PER_EU_DEFS};${IGLP_OPT_DEFS}")
# layout=TN
set_source_files_properties(device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp PROPERTIES
COMPILE_OPTIONS "${MAX_ILP_OPTS}"
COMPILE_OPTIONS ";;"
COMPILE_DEFINITIONS "${WAVES_PER_EU_DEFS};${IGLP_OPT_DEFS}")
endif(ENABLE_PIPELINE_V2_OPT)
add_instance_library(device_grouped_conv3d_bwd_data_instance
set(GROUPED_CONV3D_BWD_DATA
xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp
wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp
......@@ -13,5 +12,11 @@ add_instance_library(device_grouped_conv3d_bwd_data_instance
wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp
wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp
wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp
wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp
)
wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp)
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
list(APPEND GROUPED_CONV3D_BWD_DATA
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp)
endif()
add_instance_library(device_grouped_conv3d_bwd_data_instance ${GROUPED_CONV3D_BWD_DATA})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment