Commit 22adc4db authored by Astha Rai
Browse files

Resolved review comments: wrapped the reinterpret_cast calls that use size_t in CK_CODE_GEN_RTC preprocessor guards

parent 3d711481
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
//#include "ck/host_utility/io.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
...@@ -1021,14 +1021,24 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, ...@@ -1021,14 +1021,24 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
static_assert(bytes_per_thread == dword_bytes); static_assert(bytes_per_thread == dword_bytes);
#ifndef CK_CODE_GEN_RTC
const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
#else
const uint32_t* global_ptr = const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<size_t>(global_base_ptr)); reinterpret_cast<uint32_t*>(reinterpret_cast<size_t>(global_base_ptr));
#endif
const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size); const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000; const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM #if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
T* lds_ptr = lds_base_ptr + lds_offset; T* lds_ptr = lds_base_ptr + lds_offset;
#ifndef CK_CODE_GEN_RTC
auto const lds_ptr_sgpr =
__builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
#else
auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast<size_t>(lds_ptr))); auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast<size_t>(lds_ptr)));
#endif
asm volatile("s_mov_b32 m0, %0; \n\t" asm volatile("s_mov_b32 m0, %0; \n\t"
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
"v"(global_offset_bytes), "v"(global_offset_bytes),
...@@ -1037,8 +1047,13 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, ...@@ -1037,8 +1047,13 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
#else #else
// LDS pointer must be attributed with the LDS address space. // LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr = __attribute__((address_space(3))) uint32_t* lds_ptr =
#ifndef CK_CODE_GEN_RTC
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
#else
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<size_t>(lds_base_ptr + lds_offset)); reinterpret_cast<size_t>(lds_base_ptr + lds_offset));
#endif
llvm_amdgcn_raw_buffer_load_lds( llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
......
...@@ -825,7 +825,11 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f) ...@@ -825,7 +825,11 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
if constexpr(stochastic_rounding) if constexpr(stochastic_rounding)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f); #ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
#endif
} }
return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>( return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
f, rng); f, rng);
...@@ -841,7 +845,11 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) ...@@ -841,7 +845,11 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
if constexpr(stochastic_rounding) if constexpr(stochastic_rounding)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f); rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
#endif
} }
if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ) if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
......
...@@ -178,7 +178,11 @@ template <> ...@@ -178,7 +178,11 @@ template <>
inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x) inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x); #ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#if defined(__gfx94__) #if defined(__gfx94__)
union union
{ {
...@@ -219,7 +223,11 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x) ...@@ -219,7 +223,11 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739; constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
#endif
return utils::cast_to_f8<half_t, return utils::cast_to_f8<half_t,
f8_fnuz_t, f8_fnuz_t,
negative_zero_nan, negative_zero_nan,
...@@ -233,7 +241,11 @@ template <> ...@@ -233,7 +241,11 @@ template <>
inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x) inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x); #ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#if defined(__gfx94__) #if defined(__gfx94__)
union union
{ {
...@@ -276,7 +288,11 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x ...@@ -276,7 +288,11 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739; constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
#endif
return utils::cast_to_f8<half_t, return utils::cast_to_f8<half_t,
bf8_fnuz_t, bf8_fnuz_t,
negative_zero_nan, negative_zero_nan,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.