"docs/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "c38ecd2eeac6d8497160343ba8080d3b9e780bd3"
Commit 39274378 authored by raman.jana's avatar raman.jana
Browse files

assembly functions for softmax primitives

parent 38f48480
......@@ -130,4 +130,143 @@ struct BlockwiseSoftmax
BufferType sum_value_buf;
};
// Blockwise softmax, variant 1: for each of the MRepeat rows held in a
// thread's registers it computes
//   max_value_buf(m) = max_k S(m, k)
//   in_thread_buf    = exp(S - rowmax)        (the softmax numerator P)
//   sum_value_buf(m) = sum_k exp(S(m, k) - rowmax)
// using pairwise threadwise reductions (reduce::Max3 / reduce::fast_Add)
// followed by wave-level LDS tree reductions (WaveReduce).
// IgnoreNaN maps NaN inputs to 0 instead of propagating them.
template <index_t BlockSize,
typename AccDataType,
typename ThreadMap_M_K, // thread_id to m_k
typename ThreadClusterDesc_M_K,
typename ThreadSliceDesc_M_K,
bool IgnoreNaN = false>
struct BlockwiseSoftmax_v1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
// Per-thread slice lengths along the M (rows) and K (reduced) dimensions.
static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);
// 1-D descriptor holding one reduction result per M row.
using ThreadSliceDesc_M = decltype(
make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0))));
// Pairwise max reduction (two K elements per step via reduce::Max3);
// the IgnoreNaN variant drops NaN contributions.
using ThreadwiseMaxReduce = typename conditional<
IgnoreNaN,
ThreadwiseReductionDouble<AccDataType,
ThreadSliceDesc_M_K,
ThreadSliceDesc_M,
reduce::Max3,
false,
detail::AccumulateWithNanIgnore<reduce::Max3, AccDataType>>,
ThreadwiseReductionDouble<AccDataType,
ThreadSliceDesc_M_K,
ThreadSliceDesc_M,
reduce::Max3,
false>>::type;
// Per-thread sum reduction using the asm-accelerated fast_Add.
using ThreadwiseSumReduce = typename conditional<
IgnoreNaN,
ThreadwiseReduction<AccDataType,
ThreadSliceDesc_M_K,
ThreadSliceDesc_M,
reduce::fast_Add,
false,
detail::AccumulateWithNanIgnore<reduce::fast_Add, AccDataType>>,
ThreadwiseReduction<AccDataType,
ThreadSliceDesc_M_K,
ThreadSliceDesc_M,
reduce::fast_Add,
false>>::type;
using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths());
// LDS-based cross-thread reducers; WaveReduce (see
// PartitionedBlockwiseReduction_v2) is used below instead of the full
// block-wide Run.
using BlockwiseMaxReduce = PartitionedBlockwiseReduction_v2<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadMap_M_K,
reduce::fast_Max,
false>;
using BlockwiseSumReduce = PartitionedBlockwiseReduction_v2<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadMap_M_K,
reduce::Add,
false>;
// One accumulator register per M row.
using BufferType = StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MRepeat, true>;
// In-place softmax pass: on return in_thread_buf holds exp(S - rowmax),
// max_value_buf the row maxima and sum_value_buf the row sums.
// reduce_work_buf is the shared LDS workspace for the cross-thread steps.
template <typename CThreadBuffer, typename WorkspaceBuffer>
__host__ __device__ void Run(CThreadBuffer& in_thread_buf, WorkspaceBuffer& reduce_work_buf)
{
// find max value
// NOTE(review): identity comes from reduce::Max while the reductions use
// Max3 / fast_Max — the identities coincide (numeric lowest), but confirm.
static_for<0, MRepeat, 1>{}([&](auto I) {
max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
});
ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
static_for<0, MRepeat, 1>{}([&](auto I) {
BlockwiseMaxReduce::WaveReduce(reduce_work_buf, max_value_buf(I));
});
// IgnoreNaN is a compile-time constant; both branches are instantiated,
// only one ever executes for a given specialization.
if (IgnoreNaN)
{
// calculate exp for elements, P=exp(s-max); NaN inputs become 0
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = ck::math::isnan(in_thread_buf[offset])
? 0
: math::exp(in_thread_buf[offset] - max_value_buf(iM));
});
});
}
else
{
// calculate exp for elements, P=exp(s-max)
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = math::exp(in_thread_buf[offset] - max_value_buf(iM));
});
});
}
// sum data
static_for<0, MRepeat, 1>{}([&](auto I) {
sum_value_buf(I) = reduce::fast_Add::template GetIdentityValue<AccDataType>();
});
ThreadwiseSumReduce::Reduce(in_thread_buf, sum_value_buf);
static_for<0, MRepeat, 1>{}([&](auto I) {
BlockwiseSumReduce::WaveReduce(reduce_work_buf, sum_value_buf(I));
});
}
// Normalization using pre-computed log-sum-exp statistics: replaces each
// element S with exp(S - LSE(row)), producing already-normalized P without
// any reduction. lse_thread_buf holds one LSE value per M row.
template <typename CThreadBuffer, typename LSEBuffer>
__host__ __device__ void RunWithPreCalcStats(CThreadBuffer& in_thread_buf,
const LSEBuffer& lse_thread_buf)
{
// calculate exp for elements using pre-calculated stats LSE (log-sum-exp)
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
if (IgnoreNaN)
{
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = ck::math::isnan(in_thread_buf[offset])
? 0
: math::exp(in_thread_buf[offset] - lse_thread_buf[iM]);
});
});
}
else
{
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = math::exp(in_thread_buf[offset] - lse_thread_buf[iM]);
});
});
}
}
// Per-row maxima and per-row sums, filled by Run().
BufferType max_value_buf;
BufferType sum_value_buf;
};
} // namespace ck
......@@ -152,6 +152,44 @@ struct PartitionedBlockwiseReduction_v2
in_out_value = work_buffer[offset];
};
// Tree reduction over the K dimension of the thread cluster, staged through
// the LDS work_buffer: each thread publishes in_out_value to its (m, k) slot,
// then log2(K) halving steps combine slot k with slot k + stride; finally
// every thread reads back the reduced value from its row's k == 0 slot.
// NOTE(review): synchronization between steps is lds_waitcnt(0), not
// s_barrier — that only orders LDS traffic within the issuing wavefront, so
// this presumably assumes each K-cluster lies inside a single wave ("Wave"
// in the name); confirm against how ThreadClusterLengths_M_K is chosen.
template <typename BufferType>
__device__ static void WaveReduce(BufferType& work_buffer, AccDataType& in_out_value)
{
static_assert(is_same<typename BufferType::type, AccDataType>{},
"Buffer data type should be consistent as AccDataType!");
// Number of halving steps: log2 of the K cluster length (get_shift).
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
// Publish this thread's partial value into its LDS slot.
work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
lds_waitcnt(0);
static_for<0, cluster_len_shift, 1>{}([&](auto I) {
// Stride halves every iteration: K/2, K/4, ..., 1.
constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());
// Only the lower half of the K cluster combines; upper half idles.
if(thread_k_cluster_id < indOffset)
{
index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
make_tuple(0, indOffset));
AccDataType opData1 = work_buffer[offset1];
AccDataType opData2 = work_buffer[offset2];
Accumulation::Calculate(opData1, opData2);
work_buffer(offset1) = opData1;
}
lds_waitcnt(0);
});
// Broadcast: every thread in the row reads the k == 0 result.
index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
in_out_value = work_buffer[offset];
};
};
// clang-format off
......
......@@ -47,6 +47,46 @@ struct ThreadwiseReduction
};
};
// Assume
// 1) SrcDesc is known at compile-time
// 2) DstDesc is known at compile-time
// 3) SrcBuffer is static buffer
// 4) DstBuffer is static buffer
// Threadwise reduction that consumes the K dimension two elements per step,
// feeding element pairs into a ternary accumulator (e.g. reduce::Max3 via
// AccumulateWithNanCheck/AccumulateWithNanIgnore's three-argument Calculate).
//
// Assume
// 1) SrcDesc is known at compile-time
// 2) DstDesc is known at compile-time
// 3) SrcBuffer is static buffer
// 4) DstBuffer is static buffer
template <typename AccDataType,
          typename SrcThreadDesc_M_K,
          typename DstThreadDesc_M,
          typename OpReduce,
          bool PropagateNan,
          typename Accumulation =
              detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
struct ThreadwiseReductionDouble
{
    static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
    static constexpr auto dst_thread_desc_m   = DstThreadDesc_M{};

    static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
    static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
    static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});

    static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
    // Reduce() reads elements (iK, iK + 1) per step; an odd K length would
    // read one element past the end of each row.
    static_assert(src_length_k % 2 == 0,
                  "K length must be even for the pairwise (double) reduction!");

    using Op = OpReduce;

    // Reduce each M-row of src_buf into dst_buf(iM), two K elements at a time.
    template <typename SrcBufferType, typename DstBufferType>
    __device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf)
    {
        static_for<0, src_length_m, 1>{}([&](auto iM) {
            constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM));

            static_for<0, src_length_k, 2>{}([&](auto iK) {
                constexpr auto offset0 = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
                constexpr auto offset1 =
                    src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK + 1));
                Accumulation::Calculate(dst_buf(Number<out_offset>{}),
                                        src_buf[Number<offset0>{}],
                                        src_buf[Number<offset1>{}]);
            });
        });
    };
};
// Assume
// 1) SrcDesc is known at compile-time
// 2) DstDesc is known at compile-time
......
......@@ -37,4 +37,38 @@ constexpr __device__ index_t get_shift<1>()
return (0);
}
// Wavefront-wide sum reduction using DPP row-shift/broadcast adds; the total
// accumulates into lane 63 of %1, is extracted with v_readlane_b32 and then
// splatted back so every lane of src ends up holding the full-wave sum.
// Hard-codes a 64-lane wave (readlane 63, row_bcast:31).
// NOTE(review): several constraint problems to verify against generated ISA:
//  - 'val' is consumed through the tied "0"(val) constraint while still
//    uninitialized;
//  - the asm writes %1 (src) and %2 (sumVal), both declared as inputs only,
//    and "s"(sumVal) is a scalar register yet v_readlane_b32's destination
//    here must be an SGPR the compiler knows is clobbered;
//  - the intended result is presumably returned via the T& parameter, which
//    only works if the compiler keeps src in the same VGPR — confirm.
template<typename T>
__host__ __device__ void waveReduceSum(T& src)
{
T val;
index_t sumVal = 0;
// = __builtin_amdgcn_readlane(src,63);
asm volatile("\n \
v_add_f32 %0, %1, %1 row_shr:1 bound_ctrl:0\n \
v_add_f32 %0, %1, %0 row_shr:2 bound_ctrl:0\n \
v_add_f32 %0, %1, %0 row_shr:3 bound_ctrl:0\n \
v_nop\n \
v_nop\n \
v_add_f32 %0, %0, %0 row_shr:4 bound_ctrl:0\n \
v_nop\n \
v_nop\n \
v_add_f32 %0, %0, %0 row_shr:8 bound_ctrl:0\n \
v_nop\n \
v_nop\n \
v_add_f32 %1, %0, %0 row_bcast:15 row_mask:0xa\n \
v_nop\n \
v_nop\n \
v_add_f32 %1, %1, %1 row_bcast:31 row_mask:0xc\n \
v_nop\n \
v_nop\n \
v_readlane_b32 %2, %1, 63\n \
v_nop\n \
v_nop\n \
v_mov_b32 %1, %2\n \
"
: "=v"(val)
: "v"(src), "s"(sumVal),
"0"(val));
}
} // namespace ck
......@@ -22,6 +22,13 @@ struct AccumulateWithNanIgnore
ReduceOperation{}(accuVal, currVal);
}
};
// Ternary accumulate that silently drops the pair when either incoming
// value is NaN, leaving accuVal untouched.
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal, AccDataType currVal1)
{
    const bool skip = ck::math::isnan(currVal) || ck::math::isnan(currVal1);

    if(!skip)
    {
        ReduceOperation{}(accuVal, currVal, currVal1);
    }
};
};
template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
......@@ -40,6 +47,10 @@ struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
{
ReduceOperation{}(accuVal, currVal);
};
// No-NaN-check variant: hand both values straight to the ternary reduction.
__host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal, AccDataType currVal1)
{
    constexpr ReduceOperation reduce_op{};
    reduce_op(accuVal, currVal, currVal1);
};
};
// Check for NaN; guarantees NaNs be propagated to result
......@@ -59,6 +70,23 @@ struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
ReduceOperation{}(accuVal, currVal);
};
};
// NaN-propagating ternary accumulate: the first NaN among the incoming
// values overwrites the accumulator; otherwise both feed the reduction op.
__host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal, AccDataType currVal1)
{
    using ck::math::isnan;

    if(isnan(currVal))
    {
        accuVal = currVal;
        return;
    }

    if(isnan(currVal1))
    {
        accuVal = currVal1;
        return;
    }

    ReduceOperation{}(accuVal, currVal, currVal1);
};
};
template <bool PropagateNan, typename ReduceOperation, typename AccDataType, typename IndexDataType>
......
......@@ -239,6 +239,519 @@ struct AMax
}
};
// In-place add a += b implemented as a fused multiply-add (a = a * 1 + b)
// via hand-written VALU assembly for float / half / double.
struct fast_Add
{
// Identity element for addition: x + 0 == x.
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return type_convert<T>(0.0f);
};
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
return operation == InMemoryDataOperationEnum::AtomicAdd ||
operation == InMemoryDataOperationEnum::Set;
};
// a <- a + b.
// NOTE(review): the dispatch is a runtime `if` over compile-time type
// predicates, so every asm branch is instantiated for every T (e.g. the
// v_fma_f64 text is emitted even when T is float) — relies on the AMDGPU
// assembler/register allocator tolerating that; confirm.
template <typename T>
__host__ __device__ inline void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T,half_t>::value,
"The data type is not supported by the Add accumulator!");
// Multiplier of 1 turns the FMA into a plain add: a = a * c + b.
T c{1.0f};
if(is_same<T,float>::value)
{
asm volatile("\n \
v_fma_f32 %0, %0, %1, %2\n \
"
: "=v"(a)
: "v"(c), "v"(b), "0"(a));
}
else if(is_same<T,half_t>::value)
{
asm volatile("\n \
v_fma_f16 %0, %0, %1, %2\n \
"
: "=v"(a)
: "v"(c), "v"(b),"0"(a));
}
else if(is_same<T,double>::value)
{
asm volatile("\n \
v_fma_f64 %0, %0, %1, %2\n \
"
: "=v"(a)
: "v"(c), "v"(b),"0"(a));
}
else
{
// Unreachable given the static_assert above; scalar fallback.
a = a + b;
}
}
};
// In-place subtract a -= b implemented as a fused multiply-add
// (a = b * (-1) + a) via hand-written VALU assembly for float/half/double.
struct fast_Sub
{
    // Identity element: x - 0 == x.
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };

    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::AtomicAdd ||
               operation == InMemoryDataOperationEnum::Set;
    };

    // a <- a - b.
    // NOTE: the dispatch is a runtime `if` over compile-time predicates, so
    // every asm branch is instantiated for every T.
    template <typename T>
    __host__ __device__ inline void operator()(T& a, T b) const
    {
        // Fixed: error message previously claimed "Add accumulator".
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value,
                      "The data type is not supported by the Sub accumulator!");

        // Multiplier of -1 turns the FMA into a subtract: a = b * c + a.
        T c{-1.0f};
        if(is_same<T, float>::value)
        {
            asm volatile("\n \
            v_fma_f32 %0, %2, %1, %0\n \
            "
                         : "=v"(a)
                         : "v"(c), "v"(b), "0"(a));
        }
        else if(is_same<T, half_t>::value)
        {
            asm volatile("\n \
            v_fma_f16 %0, %2, %1, %0\n \
            "
                         : "=v"(a)
                         : "v"(c), "v"(b), "0"(a));
        }
        else if(is_same<T, double>::value)
        {
            asm volatile("\n \
            v_fma_f64 %0, %2, %1, %0\n \
            "
                         : "=v"(a)
                         : "v"(c), "v"(b), "0"(a));
        }
        else
        {
            // Unreachable given the static_assert above; scalar fallback.
            a = a - b;
        }
    }
};
// Packed two-lane in-place add a += b for float2_t / half2_t, implemented as
// a packed FMA (a = a * 1 + b).
struct Add2
{
    // Identity element for addition.
    template <typename T>
    __host__ __device__ static T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };

    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::AtomicAdd ||
               operation == InMemoryDataOperationEnum::Set;
    };

    // a <- a + b, lane-wise on the packed pair.
    template <typename T>
    __host__ __device__ inline void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float2_t>::value || is_same<T, half2_t>::value,
                      "The data type is not supported by the Add accumulator!");

        // Fixed: `T c{1.0f}` list-initialized only lane 0 and zero-filled
        // lane 1, making the packed FMA compute a.y * 0 + b.y (i.e. dropping
        // a.y). Both lanes of the multiplier must be 1.
        T c{1.0f, 1.0f};
        if(is_same<T, float2_t>::value)
        {
            asm volatile("\n \
            v_pk_fma_f32 %0, %0, %1, %2\n \
            "
                         : "=v"(a)
                         : "v"(c), "v"(b),
                           "0"(a));
        }
        else if(is_same<T, half2_t>::value)
        {
            asm volatile("\n \
            v_pk_fma_f16 %0, %0, %1, %2\n \
            "
                         : "=v"(a)
                         : "v"(c), "v"(b),
                           "0"(a));
        }
    }
};
// Packed two-lane in-place subtract a -= b for float2_t / half2_t,
// implemented as a packed FMA (a = b * (-1) + a).
struct Sub2
{
    // Identity element: x - 0 == x.
    template <typename T>
    __host__ __device__ static T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };

    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::AtomicAdd ||
               operation == InMemoryDataOperationEnum::Set;
    };

    // a <- a - b, lane-wise on the packed pair.
    template <typename T>
    __host__ __device__ inline void operator()(T& a, T b) const
    {
        // Fixed: error message previously claimed "Add accumulator".
        static_assert(is_same<T, float2_t>::value || is_same<T, half2_t>::value,
                      "The data type is not supported by the Sub accumulator!");

        // Fixed: `T c{-1.0f}` list-initialized only lane 0 and zero-filled
        // lane 1, making the packed FMA compute b.y * 0 + a.y (i.e. ignoring
        // b.y). Both lanes of the multiplier must be -1.
        T c{-1.0f, -1.0f};
        if(is_same<T, float2_t>::value)
        {
            asm volatile("\n \
            v_pk_fma_f32 %0, %2, %1, %0\n \
            "
                         : "=v"(a)
                         : "v"(c), "v"(b),
                           "0"(a));
        }
        else if(is_same<T, half2_t>::value)
        {
            asm volatile("\n \
            v_pk_fma_f16 %0, %2, %1, %0\n \
            "
                         : "=v"(a)
                         : "v"(c), "v"(b),
                           "0"(a));
        }
    }
};
// Packed two-lane in-place multiply a *= b for float2_t / half2_t using the
// v_pk_mul instructions.
struct Mul2
{
    // Identity element for multiplication: x * 1 == x.
    template <typename T>
    __host__ __device__ static T GetIdentityValue()
    {
        return type_convert<T>(1.0f);
    };

    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::Set;
    };

    // a <- a * b, lane-wise on the packed pair.
    template <typename T>
    __host__ __device__ inline void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float2_t>::value || is_same<T, half2_t>::value,
                      "The data type is not supported by the Mul accumulator!");

        if(is_same<T, float2_t>::value)
        {
            asm volatile("\n \
            v_pk_mul_f32 %0, %0, %1\n \
            "
                         : "=v"(a)
                         : "v"(b),
                           "0"(a));
        }
        // Fixed: this branch previously tested the scalar half_t, which can
        // never equal T here (the static_assert admits only the packed
        // types), so the packed f16 multiply was unreachable.
        else if(is_same<T, half2_t>::value)
        {
            asm volatile("\n \
            v_pk_mul_f16 %0, %0, %1\n \
            "
                         : "=v"(a)
                         : "v"(b),
                           "0"(a));
        }
    }
};
// In-place max a = max(a, b) with asm fast paths for float / half and a
// scalar fallback for the remaining supported types (double, int32, int8).
struct fast_Max
{
// Identity element for max: the most negative representable value.
template <typename T>
__host__ __device__ static T GetIdentityValue()
{
return NumericLimits<T>::Lowest();
};
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
// ToChange: atomic_max to be added
return operation == InMemoryDataOperationEnum::Set;
};
// a <- max(a, b). Runtime `if` over compile-time predicates: all asm
// branches are instantiated for every T.
template <typename T>
__host__ __device__ inline void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Max accumulator!");
if(is_same<T,float>::value)
{
asm volatile("\n \
v_max_f32 %0, %0, %1\n \
"
: "=v"(a)
: "v"(b),
"0"(a));
}
else if(is_same<T,half_t>::value)
{
asm volatile("\n \
v_max_f16 %0, %0, %1\n \
"
: "=v"(a)
: "v"(b),
"0"(a));
}
else
{
// Scalar fallback (double / int32_t / int8_t).
if(a < b)
a = b;
}
}
// Max with change tracking: replaces a when b is strictly greater and
// records that a replacement happened (used by arg-max style reductions).
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Max accumulator!");
if(a < b)
{
a = b;
changed = true;
}
}
};
// Three-input max accumulator a = max(a, b, c) using the v_max3_*
// instructions; consumed by ThreadwiseReductionDouble, which folds two
// source elements per step into the accumulator.
struct Max3
{
// Identity element for max: the most negative representable value.
template <typename T>
__host__ __device__ static T GetIdentityValue()
{
return NumericLimits<T>::Lowest();
};
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
// ToChange: atomic_max to be added
return operation == InMemoryDataOperationEnum::Set;
};
// a <- max(a, b, c). Runtime `if` over compile-time predicates: all asm
// branches are instantiated for every T; the final else handles the only
// remaining supported type, int32_t.
template <typename T>
__host__ __device__ inline void operator()(T& a, T b, T c) const
{
static_assert(is_same<T, float>::value || is_same<T, half_t>::value ||
is_same<T, int32_t>::value,
"The data type is not supported by the Max accumulator!");
if(is_same<T,float>::value)
{
asm volatile("\n \
v_max3_f32 %0, %0, %1, %2\n \
"
: "=v"(a)
: "v"(b), "v"(c),
"0"(a));
}
else if(is_same<T,half_t>::value)
{
asm volatile("\n \
v_max3_f16 %0, %0, %1, %2\n \
"
: "=v"(a)
: "v"(b), "v"(c),
"0"(a));
}
else
{
asm volatile("\n \
v_max3_i32 %0, %0, %1, %2\n \
"
: "=v"(a)
: "v"(b), "v"(c),
"0"(a));
}
}
// Two-input max with change tracking (arg-max style reductions); note the
// wider type list than the three-input overload above.
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Max accumulator!");
if(a < b)
{
a = b;
changed = true;
}
}
};
// In-place min a = min(a, b) with asm fast paths for float / half / double /
// int32 and a scalar fallback for the remaining supported type (int8_t).
struct fast_Min
{
    // Identity element for min: the largest representable value.
    template <typename T>
    __host__ __device__ static T GetIdentityValue()
    {
        return NumericLimits<T>::Max();
    };

    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_min to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

    // a <- min(a, b). Runtime `if` over compile-time predicates: all asm
    // branches are instantiated for every T.
    template <typename T>
    __host__ __device__ inline void operator()(T& a, T b) const
    {
        // Fixed: error message previously claimed "Max accumulator".
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Min accumulator!");

        if(is_same<T, float>::value)
        {
            asm volatile("\n \
            v_min_f32 %0, %0, %1\n \
            "
                         : "=v"(a)
                         : "v"(b),
                           "0"(a));
        }
        else if(is_same<T, half_t>::value)
        {
            asm volatile("\n \
            v_min_f16 %0, %0, %1\n \
            "
                         : "=v"(a)
                         : "v"(b),
                           "0"(a));
        }
        else if(is_same<T, double>::value)
        {
            asm volatile("\n \
            v_min_f64 %0, %0, %1\n \
            "
                         : "=v"(a)
                         : "v"(b),
                           "0"(a));
        }
        else if(is_same<T, int32_t>::value)
        {
            asm volatile("\n \
            v_min_i32 %0, %0, %1\n \
            "
                         : "=v"(a)
                         : "v"(b),
                           "0"(a));
        }
        else
        {
            // Scalar fallback (int8_t).
            // Fixed: this previously computed max (`if(a < b) a = b`).
            if(b < a)
                a = b;
        }
    }

    // Min with change tracking: replaces a when b is strictly smaller and
    // records that a replacement happened (arg-min style reductions).
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Min accumulator!");

        // Fixed: comparison was inverted (max semantics) for this Min op.
        if(b < a)
        {
            a = b;
            changed = true;
        }
    }
};
// Three-input min accumulator a = min(a, b, c) using the v_min3_*
// instructions; consumed by ThreadwiseReductionDouble, which folds two
// source elements per step into the accumulator.
struct Min3
{
    // Identity element for min: the largest representable value.
    template <typename T>
    __host__ __device__ static T GetIdentityValue()
    {
        return NumericLimits<T>::Max();
    };

    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_min to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

    // a <- min(a, b, c). Runtime `if` over compile-time predicates: all asm
    // branches are instantiated for every T; the final else handles the only
    // remaining supported type, int32_t.
    template <typename T>
    __host__ __device__ inline void operator()(T& a, T b, T c) const
    {
        // Fixed: error message previously claimed "Max accumulator".
        static_assert(is_same<T, float>::value || is_same<T, half_t>::value ||
                          is_same<T, int32_t>::value,
                      "The data type is not supported by the Min accumulator!");

        if(is_same<T, float>::value)
        {
            asm volatile("\n \
            v_min3_f32 %0, %0, %1, %2\n \
            "
                         : "=v"(a)
                         : "v"(b), "v"(c),
                           "0"(a));
        }
        else if(is_same<T, half_t>::value)
        {
            asm volatile("\n \
            v_min3_f16 %0, %0, %1, %2\n \
            "
                         : "=v"(a)
                         : "v"(b), "v"(c),
                           "0"(a));
        }
        else
        {
            asm volatile("\n \
            v_min3_i32 %0, %0, %1, %2\n \
            "
                         : "=v"(a)
                         : "v"(b), "v"(c),
                           "0"(a));
        }
    }

    // Two-input min with change tracking (arg-min style reductions).
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Min accumulator!");

        // Fixed: comparison was inverted (max semantics) for this Min op.
        if(b < a)
        {
            a = b;
            changed = true;
        }
    }
};
template <typename T>
constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
......@@ -288,5 +801,6 @@ struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Add, D
is_same<DataType, int32_t>::value;
};
} // namespace reduce
} // namespace ck
......@@ -28,5 +28,102 @@ __device__ void s_nop()
__builtin_amdgcn_sched_barrier(0);
#endif
}
// Workgroup-wide execution barrier (s_barrier).
// The "memory" clobber is required: s_barrier only synchronizes execution,
// and without the clobber the compiler is free to move loads/stores across
// the asm statement, defeating the barrier's purpose.
__device__ void wg_sync()
{
    asm volatile("s_barrier" ::: "memory");
}
// Raise this wavefront's issue priority to the maximum (3).
// Fixed: the AMDGPU assembler takes a bare immediate ("s_setprio 3");
// the previous parenthesized form "s_setprio(3)" is not valid syntax.
__device__ void raise_priority()
{
    asm volatile("s_setprio 3" ::);
}
// Restore this wavefront's issue priority to the default (0).
// Fixed: the AMDGPU assembler takes a bare immediate ("s_setprio 0");
// the previous parenthesized form "s_setprio(0)" is not valid syntax.
__device__ void lower_priority()
{
    asm volatile("s_setprio 0" ::);
}
// Wait until at most `cnt` vector-memory (global/buffer) operations are
// still outstanding, via s_waitcnt vmcnt(N). Because the s_waitcnt immediate
// must be a literal, only the counts {0, 2, 4, 8, 12} are honored exactly;
// any other value falls through to vmcnt(16).
// NOTE(review): confirm callers never pass other values expecting an exact
// wait — the fallback waits for "at most 16 outstanding", which for small
// cnt values is far weaker than requested.
// NOTE(review): no "memory" clobber on the asm — presumably the callers
// already order the surrounding accesses; verify.
__device__ void vm_waitcnt(const uint32_t cnt)
{
if (cnt == 0) {
asm volatile("\
s_waitcnt vmcnt(0) \n \
"
:: );
}
else if (cnt == 2) {
asm volatile("\
s_waitcnt vmcnt(2) \n \
"
:: );
}
else if (cnt == 4) {
asm volatile("\
s_waitcnt vmcnt(4) \n \
"
:: );
}
else if (cnt == 8) {
asm volatile("\
s_waitcnt vmcnt(8) \n \
"
:: );
}
else if (cnt == 12) {
asm volatile("\
s_waitcnt vmcnt(12) \n \
"
:: );
}
else {
// Fallback bucket for all other counts.
asm volatile("\
s_waitcnt vmcnt(16) \n \
"
:: );
}
}
// Wait until at most `cnt` LDS/GDS/constant ("lgkm") operations are still
// outstanding, via s_waitcnt lgkmcnt(N). Only the counts {0, 4, 8, 12} are
// honored exactly; any other value falls through to lgkmcnt(16).
// lds_waitcnt(0) is used by WaveReduce above as an intra-wave LDS fence.
// NOTE(review): as with vm_waitcnt, the fallback bucket and the missing
// "memory" clobber should be confirmed against caller expectations.
__device__ void lds_waitcnt(const uint32_t cnt)
{
if (cnt == 0) {
asm volatile("\
s_waitcnt lgkmcnt(0) \n \
"
:: );
}
else if (cnt == 4) {
asm volatile("\
s_waitcnt lgkmcnt(4) \n \
"
:: );
}
else if (cnt == 8) {
asm volatile("\
s_waitcnt lgkmcnt(8) \n \
"
:: );
}
else if (cnt == 12) {
asm volatile("\
s_waitcnt lgkmcnt(12) \n \
"
:: );
}
else {
// Fallback bucket for all other counts.
asm volatile("\
s_waitcnt lgkmcnt(16) \n \
"
:: );
}
}
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment