Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
22adc4db
Commit
22adc4db
authored
Jan 23, 2025
by
Astha Rai
Browse files
Resolved comments from review: put the reinterpret_cast<size_t> calls behind CK_CODE_GEN_RTC preprocessor guards
parent
3d711481
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
43 additions
and
5 deletions
+43
-5
include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
...gen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+0
-1
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+16
-1
include/ck/utility/amd_ck_fp8.hpp
include/ck/utility/amd_ck_fp8.hpp
+9
-1
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+18
-2
No files found.
include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
View file @
22adc4db
...
@@ -27,7 +27,6 @@
...
@@ -27,7 +27,6 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
//#include "ck/host_utility/io.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
22adc4db
...
@@ -1021,14 +1021,24 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
...
@@ -1021,14 +1021,24 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
constexpr
auto
bytes_per_thread
=
sizeof
(
T
)
*
NumElemsPerThread
;
constexpr
auto
bytes_per_thread
=
sizeof
(
T
)
*
NumElemsPerThread
;
static_assert
(
bytes_per_thread
==
dword_bytes
);
static_assert
(
bytes_per_thread
==
dword_bytes
);
#ifndef CK_CODE_GEN_RTC
const
uint32_t
*
global_ptr
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
uintptr_t
>
(
global_base_ptr
));
#else
const
uint32_t
*
global_ptr
=
const
uint32_t
*
global_ptr
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
size_t
>
(
global_base_ptr
));
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
size_t
>
(
global_base_ptr
));
#endif
const
int32x4_t
src_resource
=
make_wave_buffer_resource
(
global_ptr
,
src_element_space_size
);
const
int32x4_t
src_resource
=
make_wave_buffer_resource
(
global_ptr
,
src_element_space_size
);
const
index_t
global_offset_bytes
=
is_valid
?
global_offset
*
sizeof
(
T
)
:
0x80000000
;
const
index_t
global_offset_bytes
=
is_valid
?
global_offset
*
sizeof
(
T
)
:
0x80000000
;
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
T
*
lds_ptr
=
lds_base_ptr
+
lds_offset
;
T
*
lds_ptr
=
lds_base_ptr
+
lds_offset
;
#ifndef CK_CODE_GEN_RTC
auto
const
lds_ptr_sgpr
=
__builtin_amdgcn_readfirstlane
((
reinterpret_cast
<
uintptr_t
>
(
lds_ptr
)));
#else
auto
const
lds_ptr_sgpr
=
__builtin_amdgcn_readfirstlane
((
reinterpret_cast
<
size_t
>
(
lds_ptr
)));
auto
const
lds_ptr_sgpr
=
__builtin_amdgcn_readfirstlane
((
reinterpret_cast
<
size_t
>
(
lds_ptr
)));
#endif
asm
volatile
(
"s_mov_b32 m0, %0;
\n\t
"
asm
volatile
(
"s_mov_b32 m0, %0;
\n\t
"
"buffer_load_dword %1, %2, 0 offen lds;
\n\t
"
::
"s"
(
lds_ptr_sgpr
),
"buffer_load_dword %1, %2, 0 offen lds;
\n\t
"
::
"s"
(
lds_ptr_sgpr
),
"v"
(
global_offset_bytes
),
"v"
(
global_offset_bytes
),
...
@@ -1037,8 +1047,13 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
...
@@ -1037,8 +1047,13 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
#else
#else
// LDS pointer must be attributed with the LDS address space.
// LDS pointer must be attributed with the LDS address space.
__attribute__
((
address_space
(
3
)))
uint32_t
*
lds_ptr
=
__attribute__
((
address_space
(
3
)))
uint32_t
*
lds_ptr
=
#ifndef CK_CODE_GEN_RTC
reinterpret_cast
<
__attribute__
((
address_space
(
3
)))
uint32_t
*>
(
reinterpret_cast
<
uintptr_t
>
(
lds_base_ptr
+
lds_offset
));
#else
reinterpret_cast
<
__attribute__
((
address_space
(
3
)))
uint32_t
*>
(
reinterpret_cast
<
__attribute__
((
address_space
(
3
)))
uint32_t
*>
(
reinterpret_cast
<
size_t
>
(
lds_base_ptr
+
lds_offset
));
reinterpret_cast
<
size_t
>
(
lds_base_ptr
+
lds_offset
));
#endif
llvm_amdgcn_raw_buffer_load_lds
(
llvm_amdgcn_raw_buffer_load_lds
(
src_resource
,
lds_ptr
,
sizeof
(
uint32_t
),
global_offset_bytes
,
0
,
0
,
0
);
src_resource
,
lds_ptr
,
sizeof
(
uint32_t
),
global_offset_bytes
,
0
,
0
,
0
);
...
...
include/ck/utility/amd_ck_fp8.hpp
View file @
22adc4db
...
@@ -825,7 +825,11 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
...
@@ -825,7 +825,11 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
if
constexpr
(
stochastic_rounding
)
if
constexpr
(
stochastic_rounding
)
{
{
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
size_t
>
(
&
f
),
f
);
#ifndef CK_CODE_GEN_RTC
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
f
),
f
);
#else
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
size_t
>
(
&
f
),
f
);
#endif
}
}
return
cast_to_f8_from_f32
<
interp
,
sat
==
ck_saturation_t
::
CK_SATFINITE
,
stochastic_rounding
>
(
return
cast_to_f8_from_f32
<
interp
,
sat
==
ck_saturation_t
::
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
f
,
rng
);
...
@@ -841,7 +845,11 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
...
@@ -841,7 +845,11 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
if
constexpr
(
stochastic_rounding
)
if
constexpr
(
stochastic_rounding
)
{
{
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
#ifndef CK_CODE_GEN_RTC
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
f
),
f
);
#else
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
size_t
>
(
&
f
),
f
);
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
size_t
>
(
&
f
),
f
);
#endif
}
}
if
constexpr
(
interp
==
ck_fp8_interpretation_t
::
CK_E4M3_FNUZ
)
if
constexpr
(
interp
==
ck_fp8_interpretation_t
::
CK_E4M3_FNUZ
)
...
...
include/ck/utility/type_convert.hpp
View file @
22adc4db
...
@@ -178,7 +178,11 @@ template <>
...
@@ -178,7 +178,11 @@ template <>
inline
__host__
__device__
f8_fnuz_t
f8_convert_sr
<
f8_fnuz_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_fnuz_t
f8_convert_sr
<
f8_fnuz_t
,
float
>
(
float
x
)
{
{
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
size_t
>
(
&
x
),
x
);
#ifndef CK_CODE_GEN_RTC
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#else
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
size_t
>
(
&
x
),
x
);
#endif
#if defined(__gfx94__)
#if defined(__gfx94__)
union
union
{
{
...
@@ -219,7 +223,11 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
...
@@ -219,7 +223,11 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
#ifndef CK_CODE_GEN_RTC
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#else
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
size_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
size_t
>
(
&
x
),
x
);
#endif
return
utils
::
cast_to_f8
<
half_t
,
return
utils
::
cast_to_f8
<
half_t
,
f8_fnuz_t
,
f8_fnuz_t
,
negative_zero_nan
,
negative_zero_nan
,
...
@@ -233,7 +241,11 @@ template <>
...
@@ -233,7 +241,11 @@ template <>
inline
__host__
__device__
bf8_fnuz_t
f8_convert_sr
<
bf8_fnuz_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_fnuz_t
f8_convert_sr
<
bf8_fnuz_t
,
float
>
(
float
x
)
{
{
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
size_t
>
(
&
x
),
x
);
#ifndef CK_CODE_GEN_RTC
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#else
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
size_t
>
(
&
x
),
x
);
#endif
#if defined(__gfx94__)
#if defined(__gfx94__)
union
union
{
{
...
@@ -276,7 +288,11 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x
...
@@ -276,7 +288,11 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
#ifndef CK_CODE_GEN_RTC
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#else
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
size_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
size_t
>
(
&
x
),
x
);
#endif
return
utils
::
cast_to_f8
<
half_t
,
return
utils
::
cast_to_f8
<
half_t
,
bf8_fnuz_t
,
bf8_fnuz_t
,
negative_zero_nan
,
negative_zero_nan
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment