Commit ead87d72 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add bf8 to buffer addressing

parent 78bfffb2
...@@ -1127,37 +1127,53 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1127,37 +1127,53 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
{
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
return bit_cast<vector_t>(tmp);
}
else
{
#endif #endif
return amd_buffer_load_impl<scalar_t, vector_size, coherence>( #if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); if constexpr(is_same<scalar_t, f8_t>::value)
#if defined CK_ENABLE_FP8 #endif
} #if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
return bit_cast<vector_t>(tmp);
}
else
{
#endif
return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif #endif
#else #else
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
{
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? bit_cast<vector_t>(tmp) : vector_t(0);
}
else
{
#endif #endif
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>( #if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
src_wave_buffer_resource, src_thread_addr_offset, 0); if constexpr(is_same<scalar_t, f8_t>::value)
return src_thread_element_valid ? tmp : vector_t(0); #endif
#if defined CK_ENABLE_FP8 #if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
} if constexpr(is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? bit_cast<vector_t>(tmp) : vector_t(0);
}
else
{
#endif
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? tmp : vector_t(0);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif #endif
#endif #endif
} }
...@@ -1216,40 +1232,61 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1216,40 +1232,61 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8
if constexpr(is_same<scalar_t, f8_t>::value) #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
{ if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
auto tmp =
bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(src_thread_data);
amd_buffer_store_impl<int8_t, vector_size, coherence>(
tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
}
else
{
#endif #endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>( #if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); if constexpr(is_same<scalar_t, f8_t>::value)
#if defined CK_ENABLE_FP8 #endif
} #if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
src_thread_data);
amd_buffer_store_impl<int8_t, vector_size, coherence>(
tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
}
else
{
#endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>(src_thread_data,
dst_wave_buffer_resource,
dst_addr_shift +
dst_thread_addr_offset,
0);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif #endif
#else #else
if(dst_thread_element_valid) if(dst_thread_element_valid)
{ {
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
{
auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
src_thread_data);
amd_buffer_store_impl<int8_t, vector_size, coherence>(
tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
else
{
#endif #endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>( #if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); if constexpr(is_same<scalar_t, f8_t>::value)
#if defined CK_ENABLE_FP8 #endif
} #if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto tmp =
bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
src_thread_data);
amd_buffer_store_impl<int8_t, vector_size, coherence>(
tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
else
{
#endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif #endif
} }
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment