Commit 2c1ed8b2 authored by Anthony Chang

Merge remote-tracking branch 'upstream/develop' into gemm-layernorm-4

parents b86b318b 56adf7e9
#ifndef CK_ENABLE_IF_HPP
#define CK_ENABLE_IF_HPP
#pragma once
namespace ck {
......@@ -10,4 +9,3 @@ template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type;
} // namespace ck
#endif
......@@ -28,6 +28,7 @@
#include "config.hpp"
#include "data_type.hpp"
#include "type.hpp"
namespace ck {
......@@ -54,21 +55,30 @@ namespace reduce {
// accumulated index also needs to be
// changed.
template <class T>
struct Add
{
using dataType = T;
__host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return type_convert<T>(0.0f);
};
__device__ static constexpr bool
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
return operation == InMemoryDataOperationEnum::AtomicAdd ||
operation == InMemoryDataOperationEnum::Set;
};
__host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; }
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
"The data type is not supported by the Add accumulator!");
a = a + b;
}
};
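
GetIdentityValue and operator() are now templated on the accumulated type instead of the enclosing struct, so ck::reduce::Add is used without template arguments (see the ReduceSum aliases further down in this diff). A minimal sketch of the resulting call pattern; the values and loop are illustrative, not part of this change:

// Sketch: accumulate three floats with the reworked, non-templated Add.
float acc = ck::reduce::Add::GetIdentityValue<float>(); // 0.0f
ck::reduce::Add add_op{};
for(float v : {1.0f, 2.0f, 3.0f})
    add_op(acc, v); // deduces T = float; the static_assert guards the type
// acc == 6.0f
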
template <class T>
......@@ -76,7 +86,7 @@ struct SquaredAdd
{
using dataType = T;
__host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
__host__ __device__ static constexpr T GetIdentityValue() { return type_convert<T>(0.0f); };
__device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
......@@ -85,50 +95,78 @@ struct SquaredAdd
operation == InMemoryDataOperationEnum::Set;
};
__host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b * b; }
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Max accumulator!");
a = a + b * b;
}
};
template <class T>
struct Mul
{
using dataType = T;
__host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(1.0f); };
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return type_convert<T>(1.0f);
};
__device__ static constexpr bool
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
return operation == InMemoryDataOperationEnum::Set;
};
__host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; }
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
"The data type is not supported by the Mul accumulator!");
a = a * b;
}
};
template <class T>
struct Max
{
using dataType = T;
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return NumericLimits<T>::Lowest();
};
__device__ static constexpr bool
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
// ToChange: atomic_max to be added
return operation == InMemoryDataOperationEnum::Set;
};
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Max accumulator!");
if(a < b)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Max accumulator!");
if(a < b)
{
a = b;
......@@ -137,28 +175,41 @@ struct Max
}
};
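
The new three-argument operator() reports through `changed` whether the accumulator was updated, which index-tracking reductions (MIN/MAX/AMAX with UseIndex) use to record the winning index. A minimal sketch, assuming the collapsed body sets `changed` on update; the data and index bookkeeping are illustrative:

// Sketch: argmax-style use of the (a, b, changed) overload.
ck::reduce::Max max_op{};
float best     = ck::reduce::Max::GetIdentityValue<float>(); // lowest float
int   best_idx = -1;
float data[]   = {0.5f, 2.0f, 1.0f};
for(int i = 0; i < 3; ++i)
{
    bool changed = false;
    max_op(best, data[i], changed);
    if(changed)
        best_idx = i; // last accepted update wins
}
// best == 2.0f, best_idx == 1
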
template <class T>
struct Min
{
using dataType = T;
__host__ __device__ static constexpr T GetIdentityValue() { return NumericLimits<T>::Max(); };
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return NumericLimits<T>::Max();
};
__device__ static constexpr bool
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
// ToChange: atomic_min to be added
return operation == InMemoryDataOperationEnum::Set;
};
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Min accumulator!");
if(a > b)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Min accumulator!");
if(a > b)
{
a = b;
......@@ -167,28 +218,41 @@ struct Min
}
};
template <class T>
struct AMax
{
using dataType = T;
__host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return type_convert<T>(0.0f);
};
__device__ static constexpr bool
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
// ToChange: atomic_max to be added
return operation == InMemoryDataOperationEnum::Set;
};
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the AMax accumulator!");
if(a < b)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the AMax accumulator!");
if(a < b)
{
a = b;
......@@ -198,7 +262,7 @@ struct AMax
};
template <typename T>
T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
T result = ck::type_convert<T>(0.0f);
......@@ -208,6 +272,44 @@ T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation
return (result);
};
template <InMemoryDataOperationEnum Operation, typename DataType>
struct InMemoryDataOperatonSupportedOnDataType
{
static constexpr bool value = false;
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value;
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicMax, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value;
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Set, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value ||
is_same<DataType, half_t>::value || is_same<DataType, bhalf_t>::value ||
is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value;
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Add, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value ||
is_same<DataType, half_t>::value || is_same<DataType, int8_t>::value ||
is_same<DataType, int32_t>::value;
};
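
The trait defaults to false and is specialized per operation, so an unsupported (operation, data type) pair is rejected at compile time. A minimal usage sketch, keeping the trait's spelling exactly as it appears in this change:

// Sketch: compile-time gate on an (operation, data type) pair.
static_assert(ck::reduce::InMemoryDataOperatonSupportedOnDataType<
                  ck::InMemoryDataOperationEnum::AtomicAdd,
                  float>::value,
              "AtomicAdd is specialized here for float and double only");
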
}; // end of namespace reduce
} // end of namespace ck
......
#ifndef CK_SEQUENCE_HPP
#define CK_SEQUENCE_HPP
#pragma once
#include "integral_constant.hpp"
#include "type.hpp"
......@@ -241,7 +240,13 @@ struct arithmetic_sequence_gen
}
};
using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
using type0 = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
using type1 = Sequence<>;
static constexpr bool kHasContent =
(Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd);
using type = typename conditional<kHasContent, type0, type1>::type;
};
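
With kHasContent, a degenerate range now resolves to the empty Sequence<> instead of instantiating sequence_gen with a non-positive length. A minimal sketch of both branches; the static_assert is illustrative:

// Sketch: empty vs. non-empty ranges under the new guard.
using Empty  = typename ck::arithmetic_sequence_gen<0, 0, 1>::type; // Sequence<>
using Upward = typename ck::arithmetic_sequence_gen<0, 6, 2>::type; // Sequence<0, 2, 4>
static_assert(Empty::Size() == 0 && Upward::Size() == 3, "");
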
// uniform sequence
......@@ -882,5 +887,10 @@ __host__ __device__ constexpr bool sequence_all_of(Seq, F f)
return flag;
}
template <typename Sx, typename Sy>
using sequence_merge_t = typename sequence_merge<Sx, Sy>::type;
template <index_t NSize, index_t I>
using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
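
A short sketch of the two new alias templates; the operand sequences are illustrative:

// Sketch: the aliases replace the nested typename ...::type spelling.
using Merged = ck::sequence_merge_t<ck::Sequence<0, 1>, ck::Sequence<2, 3>>; // Sequence<0, 1, 2, 3>
using Ones   = ck::uniform_sequence_gen_t<3, 1>;                             // Sequence<1, 1, 1>
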
} // namespace ck
#endif
#ifndef CK_TUPLE_HPP
#define CK_TUPLE_HPP
#pragma once
#include "integral_constant.hpp"
#include "sequence.hpp"
......@@ -17,14 +16,18 @@ struct TupleElementKey
};
template <typename Key, typename Data>
struct TupleElement
struct TupleElementKeyData
{
__host__ __device__ constexpr TupleElement() = default;
#if 0 // workaround compiler complaint about implicitly-deleted default constructor
__host__ __device__ constexpr TupleElementKeyData() = default;
#else
__host__ __device__ constexpr TupleElementKeyData() : mData{} {}
#endif
template <
typename T,
typename enable_if<!is_same<remove_cvref_t<T>, TupleElement>::value, bool>::type = false>
__host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
template <typename T,
typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value,
bool>::type = false>
__host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward<T>(v))
{
}
......@@ -32,20 +35,21 @@ struct TupleElement
};
template <typename Key, typename Data>
__host__ __device__ constexpr const Data& get_tuple_element(const TupleElement<Key, Data>& x)
__host__ __device__ constexpr const Data&
get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
{
return static_cast<const Data&>(x.mData);
}
template <typename Key, typename Data>
__host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x)
__host__ __device__ constexpr Data& get_tuple_element_data(TupleElementKeyData<Key, Data>& x)
{
return x.mData;
}
// TODO: not sure the use of reference is correct
template <typename Key, typename Data>
__host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x)
__host__ __device__ constexpr Data&& get_tuple_element_data(TupleElementKeyData<Key, Data>&& x)
{
return static_cast<Data&&>(x.mData);
}
......@@ -54,7 +58,7 @@ template <typename Indices, typename... Xs>
struct TupleImpl;
template <index_t... Is, typename... Xs>
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>...
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<Is>, Xs>...
{
__host__ __device__ constexpr TupleImpl() = default;
......@@ -63,13 +67,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
!is_same<remove_cvref_t<Y>, TupleImpl>::value,
bool>::type = false>
__host__ __device__ constexpr TupleImpl(Y&& y)
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
: TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
{
}
template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr TupleImpl(Ys&&... ys)
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
: TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
{
static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
"wrong! inconsistent size");
......@@ -78,15 +82,15 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
template <index_t I>
__host__ __device__ constexpr const auto& GetElementByKey(TupleElementKey<I>) const
__host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey<I>) const
{
return get_tuple_element<TupleElementKey<I>>(*this);
return get_tuple_element_data<TupleElementKey<I>>(*this);
}
template <index_t I>
__host__ __device__ constexpr auto& GetElementByKey(TupleElementKey<I>)
__host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey<I>)
{
return get_tuple_element<TupleElementKey<I>>(*this);
return get_tuple_element_data<TupleElementKey<I>>(*this);
}
};
......@@ -121,7 +125,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__ __device__ constexpr const auto& At(Number<I>) const
{
static_assert(I < base::Size(), "wrong! out of range");
return base::GetElementByKey(detail::TupleElementKey<I>{});
return base::GetElementDataByKey(detail::TupleElementKey<I>{});
}
// write access
......@@ -129,7 +133,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__ __device__ constexpr auto& At(Number<I>)
{
static_assert(I < base::Size(), "wrong! out of range");
return base::GetElementByKey(detail::TupleElementKey<I>{});
return base::GetElementDataByKey(detail::TupleElementKey<I>{});
}
// read access
......@@ -159,6 +163,31 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
};
template <>
struct Tuple<>
{
__host__ __device__ constexpr Tuple() = default;
__host__ __device__ static constexpr index_t Size() { return 0; }
template <typename T>
__host__ __device__ constexpr auto operator=(const T&)
{
return *this;
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
};
template <index_t I, typename TTuple>
struct tuple_element
{
using type = decltype(TTuple{}.At(Number<I>{}));
};
template <index_t I, typename TTuple>
using tuple_element_t = typename tuple_element<I, TTuple>::type;
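
Together with the empty-tuple specialization above, tuple_element/tuple_element_t give Tuple an std-style element query. A minimal sketch; note that the type is deduced from At(), so it may carry a reference qualifier:

// Sketch: element count and element-type query on Tuple.
constexpr ck::Tuple<> empty{}; // the new empty-tuple specialization
static_assert(empty.Size() == 0, "");
using E0 = ck::tuple_element_t<0, ck::Tuple<int, float>>; // deduced via At(Number<0>{})
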
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
......@@ -173,4 +202,3 @@ constexpr Tuple<Args&...> tie(Args&... args) noexcept
}
} // namespace ck
#endif
#ifndef CK_TUPLE_HELPER_HPP
#define CK_TUPLE_HELPER_HPP
#pragma once
#include "functional4.hpp"
#include "tuple.hpp"
......@@ -20,6 +19,17 @@ __host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
// tx and ty are tuples of references; the return type will be a tuple of references (not rvalues)
template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& tx,
const Tuple<Y&...>& ty)
{
return unpack2(
[&](auto&&... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
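
concat_tuple_of_reference spells the result as Tuple<decltype(zs)...> precisely so the reference-ness of the elements survives the unpack. A minimal sketch built on ck::tie; the variables are illustrative:

// Sketch: concatenating two reference tuples without decaying to values.
int a = 1, b = 2;
float c = 3.0f;
auto tx  = ck::tie(a, b); // Tuple<int&, int&>
auto ty  = ck::tie(c);    // Tuple<float&>
auto txy = ck::concat_tuple_of_reference(tx, ty);
static_assert(decltype(txy)::Size() == 3, "");
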
namespace detail {
template <typename F, typename X, index_t... Is>
......@@ -66,4 +76,3 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
}
} // namespace ck
#endif
......@@ -174,15 +174,18 @@ struct ReductionHost
const InDataType* in_data,
float beta,
OutDataType* out_data,
IndexDataType* out_indices)
IndexDataType* out_indices,
InElementwiseOperation in_elementwise_op,
AccElementwiseOperation acc_elementwise_op)
{
if constexpr(OutputIndex)
{
RunImpl_with_index(alpha, in_data, beta, out_data, out_indices);
RunImpl_with_index(
alpha, in_data, beta, out_data, out_indices, in_elementwise_op, acc_elementwise_op);
}
else
{
RunImpl_no_index(alpha, in_data, beta, out_data);
RunImpl_no_index(alpha, in_data, beta, out_data, in_elementwise_op, acc_elementwise_op);
};
};
......@@ -190,7 +193,9 @@ struct ReductionHost
const InDataType* in_data,
float beta,
OutDataType* out_data,
IndexDataType* out_indices)
IndexDataType* out_indices,
InElementwiseOperation in_elementwise_op,
AccElementwiseOperation acc_elementwise_op)
{
using ck::float_equal_one;
using ck::float_equal_zero;
......@@ -200,12 +205,10 @@ struct ReductionHost
ReduceOperation,
AccDataType,
IndexDataType>;
InElementwiseOperation in_elementwise_op(divider);
AccElementwiseOperation acc_elementwise_op(divider);
if constexpr(NumInvariantDim == 0)
{
AccDataType accuVal = ReduceOperation::GetIdentityValue();
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
IndexDataType accuIndex = 0;
for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
......@@ -236,7 +239,7 @@ struct ReductionHost
else
{
auto thread_reduce_func = [&](auto invariant_index) {
AccDataType accuVal = ReduceOperation::GetIdentityValue();
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
IndexDataType accuIndex = 0;
auto offset_invariant =
......@@ -297,7 +300,12 @@ struct ReductionHost
};
};
void RunImpl_no_index(float alpha, const InDataType* in_data, float beta, OutDataType* out_data)
void RunImpl_no_index(float alpha,
const InDataType* in_data,
float beta,
OutDataType* out_data,
InElementwiseOperation in_elementwise_op,
AccElementwiseOperation acc_elementwise_op)
{
using ck::float_equal_one;
using ck::float_equal_zero;
......@@ -306,12 +314,9 @@ struct ReductionHost
using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
InElementwiseOperation in_elementwise_op(divider);
AccElementwiseOperation acc_elementwise_op(divider);
if constexpr(NumInvariantDim == 0)
{
AccDataType accuVal = ReduceOperation::GetIdentityValue();
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
for(const auto& reduce_index : reduce_dim_indexes)
{
......@@ -338,7 +343,7 @@ struct ReductionHost
else
{
auto thread_reduce_func = [&](auto invariant_index) {
AccDataType accuVal = ReduceOperation::GetIdentityValue();
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
auto offset_invariant =
get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
......
......@@ -106,9 +106,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
}
float v_in;
arg.in_element_op_(v_in, v_acc);
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_in);
arg.in_element_op_(v_acc, v_acc);
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_acc);
};
make_ParallelTensorFunctor(f_ncw,
......
......@@ -66,8 +66,8 @@ struct ReferenceGemmBias2D : public device::BaseOperator
for(int k = 0; k < K; ++k)
{
arg.a_element_op_(a, arg.a_m_k_(m, k));
arg.b_element_op_(b, arg.b_k_n_(k, n));
arg.a_element_op_(a, ck::type_convert<AccDataType>(arg.a_m_k_(m, k)));
arg.b_element_op_(b, ck::type_convert<AccDataType>(arg.b_k_n_(k, n)));
acc += a * b;
}
......
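
The fix above converts the f16 operands to AccDataType before the elementwise ops run, so the multiply-accumulate happens in the accumulation type rather than half precision. A minimal sketch of the conversion primitive in isolation; the values are illustrative:

// Sketch: explicit up-conversion with ck::type_convert.
ck::half_t a_h = ck::type_convert<ck::half_t>(1.5f);
float      a_f = ck::type_convert<float>(a_h); // accumulate in float from here on
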
#ifndef CK_DEVICE_OPERATION_INSTANCE_HPP
#define CK_DEVICE_OPERATION_INSTANCE_HPP
#pragma once
#include <stdlib.h>
#include <vector>
namespace ck {
namespace tensor_operation {
......@@ -23,4 +22,3 @@ void add_device_operation_instances(std::vector<std::unique_ptr<OpInstance>>& op
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -61,10 +61,10 @@ using reduce_configuration_2_instances_blockwise = std::tuple<
>;
#endif
template <typename AccDataType, ReduceTensorOp ReduceOpId>
template <ReduceTensorOp ReduceOpId>
using deviceReduceBlockWisePtrType = DeviceReducePtr<
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation,
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>;
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation,
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation>;
template <typename InDataType,
typename AccDataType,
......@@ -75,14 +75,13 @@ template <typename InDataType,
bool PropagateNan,
bool UseIndex>
void add_device_reduce_instance_blockwise(
std::vector<deviceReduceBlockWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
std::vector<deviceReduceBlockWisePtrType<ReduceOpId>>& device_op_instances)
{
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
AccElementwiseOperation;
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
constexpr bool Indexable =
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
......@@ -137,7 +136,7 @@ void add_device_reduce_instance_blockwise(
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)
std::vector<deviceReduceBlockWisePtrType<ReduceOpId>> & device_op_instances)
#define ADD_BLOCKWISE_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
......@@ -150,21 +149,17 @@ void add_device_reduce_instance_blockwise(
Rank, \
NumReduceDim)
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_blockwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<DeviceReducePtr< \
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_blockwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceBlockWisePtrType<ReduceOpId>> & device_op_instances)
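
A sketch of how the rewritten REF_BY_TYPE macro is typically invoked in an instance-reference header; the concrete types, rank, and reduce-dim count below are placeholders, not taken from this diff:

// Sketch: declares the extern template for one instantiation.
ADD_BLOCKWISE_INST_REF_BY_TYPE(float, float, float, ReduceTensorOp::ADD, false, false, 4, 3);
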
#define ADD_BLOCKWISE_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
......
......@@ -61,12 +61,10 @@ using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple<
>;
#endif
template <typename AccDataType, ReduceTensorOp ReduceOperation>
using deviceReduceMultiBlockAtomicAddPtrType =
DeviceReducePtr<typename reduce_unary_operator<AccDataType, ReduceOperation, true, true>::
InElementwiseOperation,
typename reduce_unary_operator<AccDataType, ReduceOperation, true, true>::
AccElementwiseOperation>;
template <ReduceTensorOp ReduceOperation>
using deviceReduceMultiBlockAtomicAddPtrType = DeviceReducePtr<
typename reduce_unary_operator<ReduceOperation, true, true>::InElementwiseOperation,
typename reduce_unary_operator<ReduceOperation, true, true>::AccElementwiseOperation>;
template <typename InDataType,
typename AccDataType,
......@@ -77,15 +75,13 @@ template <typename InDataType,
bool PropagateNan,
bool UseIndex>
void add_device_reduce_instance_multiblock_atomic_add(
std::vector<deviceReduceMultiBlockAtomicAddPtrType<AccDataType, ReduceOpId>>&
device_op_instances)
std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>>& device_op_instances)
{
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
AccElementwiseOperation;
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
constexpr bool Indexable =
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
......@@ -158,8 +154,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceMultiBlockAtomicAddPtrType<compT, ReduceOpId>> & \
device_op_instances)
std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>> & device_op_instances)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
......@@ -172,21 +167,17 @@ void add_device_reduce_instance_multiblock_atomic_add(
Rank, \
NumReduceDim)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_multiblock_atomic_add<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<DeviceReducePtr< \
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_multiblock_atomic_add<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>> & device_op_instances)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
......
......@@ -47,10 +47,10 @@ using reduce_configuration_2_instances_threadwise = std::tuple<
>;
#endif
template <typename AccDataType, ReduceTensorOp ReduceOpId>
template <ReduceTensorOp ReduceOpId>
using deviceReduceThreadWisePtrType = DeviceReducePtr<
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation,
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>;
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation,
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation>;
template <typename InDataType,
typename AccDataType,
......@@ -61,14 +61,13 @@ template <typename InDataType,
bool PropagateNan,
bool UseIndex>
void add_device_reduce_instance_threadwise(
std::vector<deviceReduceThreadWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
std::vector<deviceReduceThreadWisePtrType<ReduceOpId>>& device_op_instances)
{
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
AccElementwiseOperation;
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
constexpr bool Indexable =
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
......@@ -114,7 +113,7 @@ void add_device_reduce_instance_threadwise(
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceThreadWisePtrType<compT, ReduceOpId>> & device_op_instances)
std::vector<deviceReduceThreadWisePtrType<ReduceOpId>> & device_op_instances)
#define ADD_THREADWISE_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
......@@ -127,21 +126,17 @@ void add_device_reduce_instance_threadwise(
Rank, \
NumReduceDim)
#define ADD_THREADWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_threadwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<DeviceReducePtr< \
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_THREADWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_threadwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceThreadWisePtrType<ReduceOpId>> & device_op_instances)
#define ADD_THREADWISE_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
......
......@@ -20,7 +20,7 @@ include_directories(BEFORE
function(add_instance_library INSTANCE_NAME)
message("adding instance ${INSTANCE_NAME}")
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
target_compile_features(${INSTANCE_NAME} PUBLIC)
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endfunction(add_instance_library INSTANCE_NAME)
......@@ -30,6 +30,7 @@ add_subdirectory(gemm_bias2d)
add_subdirectory(gemm_bias_relu)
add_subdirectory(gemm_bias_relu_add)
add_subdirectory(gemm_reduce)
add_subdirectory(gemm_bias_add_reduce)
add_subdirectory(batched_gemm)
add_subdirectory(conv1d_fwd)
add_subdirectory(conv2d_fwd)
......@@ -43,13 +44,14 @@ add_subdirectory(convnd_bwd_data)
add_subdirectory(grouped_gemm)
add_subdirectory(conv2d_bwd_weight)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(gemm_add_add_fastgelu)
add_library(device_operations STATIC
$<TARGET_OBJECTS:device_conv1d_fwd_instance>
$<TARGET_OBJECTS:device_batched_gemm_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
add_library(device_operations STATIC
$<TARGET_OBJECTS:device_conv1d_fwd_instance>
$<TARGET_OBJECTS:device_batched_gemm_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_atomic_add_instance>
$<TARGET_OBJECTS:device_gemm_instance>
......@@ -62,19 +64,20 @@ add_library(device_operations STATIC
$<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
$<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
$<TARGET_OBJECTS:device_conv3d_fwd_instance>
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
device_conv2d.cpp
)
add_library(composablekernels::device_operations ALIAS device_operations)
set(DEV_OPS_INC_DIRS
set(DEV_OPS_INC_DIRS
${PROJECT_SOURCE_DIR}/include/ck/
${PROJECT_SOURCE_DIR}/library/include/ck/
${PROJECT_SOURCE_DIR}/external/include/
)
target_compile_features(device_operations PUBLIC)
set_target_properties(device_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_include_directories(device_operations PUBLIC
target_include_directories(device_operations PUBLIC
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/utility>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_description>
......@@ -96,9 +99,11 @@ target_include_directories(device_operations PUBLIC
# once new arches are enabled, make this an option on the main cmake file
# and pass down here to be exported
target_compile_options(device_operations
PRIVATE --offload-arch=gfx908
target_compile_options(device_operations PRIVATE
--offload-arch=gfx908
--offload-arch=gfx90a
)
# install(TARGETS device_operations LIBRARY DESTINATION lib)
install(TARGETS device_operations
EXPORT device_operationsTargets
......@@ -108,8 +113,8 @@ install(TARGETS device_operations
INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)
install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck)
install(EXPORT device_operationsTargets
FILE composable_kerneldevice_operationsTargets.cmake
install(EXPORT device_operationsTargets
FILE composable_kerneldevice_operationsTargets.cmake
NAMESPACE composable_kernel::
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
)
......@@ -21,11 +21,11 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
......@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_in
>;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
std::vector<DeviceGemmReducePtr<DPtrsGlobal,
PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
std::vector<
DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
instances)
{
add_device_operation_instances(
instances,
......
......@@ -21,11 +21,11 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
......@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_in
>;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
std::vector<DeviceGemmReducePtr<DPtrsGlobal,
PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
std::vector<
DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
instances)
{
add_device_operation_instances(
instances,
......
......@@ -21,11 +21,11 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
......@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_in
>;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
std::vector<DeviceGemmReducePtr<DPtrsGlobal,
PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
std::vector<
DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
instances)
{
add_device_operation_instances(
instances,
......
......@@ -21,11 +21,11 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
......@@ -59,12 +59,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_in
>;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
std::vector<DeviceGemmReducePtr<DPtrsGlobal,
PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
std::vector<
DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
instances)
{
add_device_operation_instances(
instances,
......
# device_gemm_add_add_fastgelu_instance
set(DEVICE_GEMM_ADD_ADD_FASTGELU_INSTANCE_SOURCE
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp;
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp;
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
)
add_library(device_gemm_add_add_fastgelu_instance OBJECT ${DEVICE_GEMM_ADD_ADD_FASTGELU_INSTANCE_SOURCE})
target_compile_features(device_gemm_add_add_fastgelu_instance PUBLIC)
set_target_properties(device_gemm_add_add_fastgelu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_gemm_add_add_fastgelu_instance)
#include <stdlib.h>
#include "config.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
#include "device_gemm_multiple_d_xdl_cshuffle.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using F16_F16 = ck::Tuple<F16, F16>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// e = elementwise((a * b), d)
// output: e[m, n]
// input: a[k, m], b[k, n], d[m, n]
using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple<
// clang-format off
//##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceGemmMultipleDPtr<2, PassThrough, PassThrough, AddAddFastGelu>>& instances)
{
add_device_operation_instances(
instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances{});
}
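
A sketch of the consumer side: client code owns the vector of opaque pointers and lets the registration function fill it. The template arguments mirror the function signature above; the surrounding qualification is an assumption about the usual call site:

// Sketch: collect the registered instances for dispatch or profiling.
using PassThrough    = ck::tensor_operation::element_wise::PassThrough;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
std::vector<ck::tensor_operation::device::
                DeviceGemmMultipleDPtr<2, PassThrough, PassThrough, AddAddFastGelu>>
    instances;
ck::tensor_operation::device::device_gemm_instance::
    add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(instances);
// one entry per DeviceGemmMultipleD_Xdl_CShuffle row in the tuple above
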
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#include <stdlib.h>
#include "config.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
#include "device_gemm_multiple_d_xdl_cshuffle.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using F16_F16 = ck::Tuple<F16, F16>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// e = elementwise((a * b), d)
// output: e[m, n]
// input: a[k, m], b[n, k], d[m, n]
using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple<
// clang-format off
//##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmMultipleDPtr<2, PassThrough, PassThrough, AddAddFastGelu>>& instances)
{
add_device_operation_instances(
instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck