"python/ait_impl/samples/gemm_rrr_3.cpp" did not exist on "516bbdcbcc1fac0e8b149e971c38481549ee56fe"
Commit 0fff2a66 authored by Jing Zhang's avatar Jing Zhang
Browse files

add guards

parent 35e61bf6
...@@ -566,9 +566,12 @@ template <typename T, index_t N> ...@@ -566,9 +566,12 @@ template <typename T, index_t N>
__device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data, __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
T* addr) T* addr)
{ {
static_assert((is_same<T, bhalf_t>::value && (N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 2 || N == 4 || N == 8)),
"wrong! not implemented");
if constexpr(is_same<T, half_t>::value) if constexpr(is_same<T, half_t>::value)
{ {
static_assert(N % 2 == 0, "");
vector_type<half_t, N> tmp{src_thread_data}; vector_type<half_t, N> tmp{src_thread_data};
static_for<0, N / 2, 1>{}([&](auto i) { static_for<0, N / 2, 1>{}([&](auto i) {
__builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast<half2_t*>(addr) + i, __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast<half2_t*>(addr) + i,
...@@ -577,7 +580,6 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ ...@@ -577,7 +580,6 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ
} }
else if constexpr(is_same<T, bhalf_t>::value) else if constexpr(is_same<T, bhalf_t>::value)
{ {
static_assert(N % 2 == 0, "");
vector_type<bhalf_t, N> tmp{src_thread_data}; vector_type<bhalf_t, N> tmp{src_thread_data};
static_for<0, N / 2, 1>{}([&](auto i) { static_for<0, N / 2, 1>{}([&](auto i) {
__builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast<bhalf2_t*>(addr) + i, __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast<bhalf2_t*>(addr) + i,
...@@ -935,7 +937,6 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr ...@@ -935,7 +937,6 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
{ {
if(dst_thread_element_valid) if(dst_thread_element_valid)
{ {
amd_global_atomic_add_impl<scalar_t, vector_size>( amd_global_atomic_add_impl<scalar_t, vector_size>(
src_thread_data, p_dst_wave + dst_thread_element_offset); src_thread_data, p_dst_wave + dst_thread_element_offset);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment