Commit ead87d72 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add bf8 to buffer addressing

parent 78bfffb2
...@@ -1127,8 +1127,16 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1127,8 +1127,16 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{ {
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>( auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
...@@ -1139,12 +1147,20 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1139,12 +1147,20 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#endif #endif
return amd_buffer_load_impl<scalar_t, vector_size, coherence>( return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
} }
#endif #endif
#else #else
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{ {
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>( auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, 0);
...@@ -1156,7 +1172,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1156,7 +1172,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>( vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? tmp : vector_t(0); return src_thread_element_valid ? tmp : vector_t(0);
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
} }
#endif #endif
#endif #endif
...@@ -1216,29 +1232,50 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1216,29 +1232,50 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{ {
auto tmp = auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(src_thread_data); src_thread_data);
amd_buffer_store_impl<int8_t, vector_size, coherence>( amd_buffer_store_impl<int8_t, vector_size, coherence>(
tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
} }
else else
{ {
#endif #endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>( amd_buffer_store_impl<scalar_t, vector_size, coherence>(src_thread_data,
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); dst_wave_buffer_resource,
#if defined CK_ENABLE_FP8 dst_addr_shift +
dst_thread_addr_offset,
0);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
} }
#endif #endif
#else #else
if(dst_thread_element_valid) if(dst_thread_element_valid)
{ {
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{ {
auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>( auto tmp =
bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
src_thread_data); src_thread_data);
amd_buffer_store_impl<int8_t, vector_size, coherence>( amd_buffer_store_impl<int8_t, vector_size, coherence>(
tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0); tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
...@@ -1248,7 +1285,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1248,7 +1285,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#endif #endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>( amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
} }
#endif #endif
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment