Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7d9969ab
Commit
7d9969ab
authored
Sep 30, 2024
by
Astha Rai
Browse files
resolved conflict errors in a few utility files
parent
3624dc2a
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
44 additions
and
15 deletions
+44
-15
include/ck/ck.hpp
include/ck/ck.hpp
+1
-2
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+4
-5
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+1
-1
include/ck/utility/magic_division.hpp
include/ck/utility/magic_division.hpp
+4
-0
include/ck/utility/type.hpp
include/ck/utility/type.hpp
+27
-0
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+7
-7
No files found.
include/ck/ck.hpp
View file @
7d9969ab
...
@@ -10,12 +10,11 @@
...
@@ -10,12 +10,11 @@
#include "hip/hip_runtime.h"
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "hip/hip_fp16.h"
#endif
#endif
#endif
// environment variable to enable logging:
// environment variable to enable logging:
// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED
// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED
CK_DECLARE_ENV_VAR_BOOL
(
CK_LOGGING
)
CK_DECLARE_ENV_VAR_BOOL
(
CK_LOGGING
)
#endif
// to do: add various levels of logging with CK_LOG_LEVEL
// to do: add various levels of logging with CK_LOG_LEVEL
#define CK_TIME_KERNEL 1
#define CK_TIME_KERNEL 1
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
7d9969ab
...
@@ -1019,14 +1019,13 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
...
@@ -1019,14 +1019,13 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
static_assert
(
bytes_per_thread
==
dword_bytes
);
static_assert
(
bytes_per_thread
==
dword_bytes
);
const
uint32_t
*
global_ptr
=
const
uint32_t
*
global_ptr
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
uintptr
_t
>
(
global_base_ptr
));
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
size
_t
>
(
global_base_ptr
));
const
int32x4_t
src_resource
=
make_wave_buffer_resource
(
global_ptr
,
src_element_space_size
);
const
int32x4_t
src_resource
=
make_wave_buffer_resource
(
global_ptr
,
src_element_space_size
);
const
index_t
global_offset_bytes
=
is_valid
?
global_offset
*
sizeof
(
T
)
:
0x80000000
;
const
index_t
global_offset_bytes
=
is_valid
?
global_offset
*
sizeof
(
T
)
:
0x80000000
;
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
T
*
lds_ptr
=
lds_base_ptr
+
lds_offset
;
T
*
lds_ptr
=
lds_base_ptr
+
lds_offset
;
auto
const
lds_ptr_sgpr
=
auto
const
lds_ptr_sgpr
=
__builtin_amdgcn_readfirstlane
((
reinterpret_cast
<
size_t
>
(
lds_ptr
)));
__builtin_amdgcn_readfirstlane
((
reinterpret_cast
<
uintptr_t
>
(
lds_ptr
)));
asm
volatile
(
"s_mov_b32 m0, %0;
\n\t
"
asm
volatile
(
"s_mov_b32 m0, %0;
\n\t
"
"buffer_load_dword %1, %2, 0 offen lds;
\n\t
"
::
"s"
(
lds_ptr_sgpr
),
"buffer_load_dword %1, %2, 0 offen lds;
\n\t
"
::
"s"
(
lds_ptr_sgpr
),
"v"
(
global_offset_bytes
),
"v"
(
global_offset_bytes
),
...
@@ -1036,7 +1035,7 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
...
@@ -1036,7 +1035,7 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
// LDS pointer must be attributed with the LDS address space.
// LDS pointer must be attributed with the LDS address space.
__attribute__
((
address_space
(
3
)))
uint32_t
*
lds_ptr
=
__attribute__
((
address_space
(
3
)))
uint32_t
*
lds_ptr
=
reinterpret_cast
<
__attribute__
((
address_space
(
3
)))
uint32_t
*>
(
reinterpret_cast
<
__attribute__
((
address_space
(
3
)))
uint32_t
*>
(
reinterpret_cast
<
uintptr
_t
>
(
lds_base_ptr
+
lds_offset
));
reinterpret_cast
<
size
_t
>
(
lds_base_ptr
+
lds_offset
));
llvm_amdgcn_raw_buffer_load_lds
(
llvm_amdgcn_raw_buffer_load_lds
(
src_resource
,
lds_ptr
,
sizeof
(
uint32_t
),
global_offset_bytes
,
0
,
0
,
0
);
src_resource
,
lds_ptr
,
sizeof
(
uint32_t
),
global_offset_bytes
,
0
,
0
,
0
);
...
...
include/ck/utility/data_type.hpp
View file @
7d9969ab
...
@@ -13,7 +13,7 @@ using float_t = float;
...
@@ -13,7 +13,7 @@ using float_t = float;
#endif
#endif
namespace
ck
{
namespace
ck
{
#ifdef CK_CODE_GEN_RTC
_
#ifdef CK_CODE_GEN_RTC
using
byte
=
unsigned
char
;
using
byte
=
unsigned
char
;
#else
#else
using
std
::
byte
;
using
std
::
byte
;
...
...
include/ck/utility/magic_division.hpp
View file @
7d9969ab
...
@@ -9,6 +9,10 @@
...
@@ -9,6 +9,10 @@
#include "type.hpp"
#include "type.hpp"
#include "tuple.hpp"
#include "tuple.hpp"
#ifdef CK_CODE_GEN_RTC
#define INT32_MAX 2147483647
#endif
namespace
ck
{
namespace
ck
{
// magic number division
// magic number division
...
...
include/ck/utility/type.hpp
View file @
7d9969ab
...
@@ -94,6 +94,16 @@ struct remove_pointer<T* const volatile>
...
@@ -94,6 +94,16 @@ struct remove_pointer<T* const volatile>
{
{
typedef
T
type
;
typedef
T
type
;
};
};
template
<
class
T
>
struct
remove_const
{
typedef
T
type
;
};
template
<
class
T
>
struct
remove_const
<
const
T
>
{
typedef
T
type
;
};
template
<
typename
T
>
template
<
typename
T
>
constexpr
T
&&
forward
(
typename
remove_reference
<
T
>::
type
&
t_
)
noexcept
constexpr
T
&&
forward
(
typename
remove_reference
<
T
>::
type
&
t_
)
noexcept
...
@@ -116,6 +126,7 @@ using std::is_pointer;
...
@@ -116,6 +126,7 @@ using std::is_pointer;
using
std
::
is_reference
;
using
std
::
is_reference
;
using
std
::
is_trivially_copyable
;
using
std
::
is_trivially_copyable
;
using
std
::
is_unsigned
;
using
std
::
is_unsigned
;
using
std
::
remove_const
;
using
std
::
remove_cv
;
using
std
::
remove_cv
;
using
std
::
remove_pointer
;
using
std
::
remove_pointer
;
using
std
::
remove_reference
;
using
std
::
remove_reference
;
...
@@ -131,12 +142,25 @@ struct is_same<X, X> : public integral_constant<bool, true>
...
@@ -131,12 +142,25 @@ struct is_same<X, X> : public integral_constant<bool, true>
{
{
};
};
template
<
typename
X
>
struct
is_const
:
public
integral_constant
<
bool
,
false
>
{
};
template
<
typename
X
>
struct
is_const
<
const
X
>
:
public
integral_constant
<
bool
,
true
>
{
};
template
<
typename
T
>
template
<
typename
T
>
inline
constexpr
bool
is_reference_v
=
is_reference
<
T
>::
value
;
inline
constexpr
bool
is_reference_v
=
is_reference
<
T
>::
value
;
template
<
typename
X
,
typename
Y
>
template
<
typename
X
,
typename
Y
>
inline
constexpr
bool
is_same_v
=
is_same
<
X
,
Y
>::
value
;
inline
constexpr
bool
is_same_v
=
is_same
<
X
,
Y
>::
value
;
template
<
typename
X
>
inline
constexpr
bool
is_const_v
=
is_const
<
X
>::
value
;
template
<
typename
X
,
typename
Y
>
template
<
typename
X
,
typename
Y
>
inline
constexpr
bool
is_base_of_v
=
is_base_of
<
X
,
Y
>::
value
;
inline
constexpr
bool
is_base_of_v
=
is_base_of
<
X
,
Y
>::
value
;
...
@@ -158,6 +182,9 @@ using remove_cvref_t = remove_cv_t<remove_reference_t<T>>;
...
@@ -158,6 +182,9 @@ using remove_cvref_t = remove_cv_t<remove_reference_t<T>>;
template
<
typename
T
>
template
<
typename
T
>
using
remove_pointer_t
=
typename
remove_pointer
<
T
>::
type
;
using
remove_pointer_t
=
typename
remove_pointer
<
T
>::
type
;
template
<
typename
T
>
using
remove_const_t
=
typename
remove_const
<
T
>::
type
;
template
<
typename
T
>
template
<
typename
T
>
inline
constexpr
bool
is_pointer_v
=
is_pointer
<
T
>::
value
;
inline
constexpr
bool
is_pointer_v
=
is_pointer
<
T
>::
value
;
...
...
include/ck/utility/type_convert.hpp
View file @
7d9969ab
...
@@ -17,7 +17,7 @@ namespace ck {
...
@@ -17,7 +17,7 @@ namespace ck {
// Convert X to Y, both X and Y are non-const data types.
// Convert X to Y, both X and Y are non-const data types.
template
<
typename
Y
,
template
<
typename
Y
,
typename
X
,
typename
X
,
std
::
enable_if_t
<!
(
std
::
is_const_v
<
Y
>
||
std
::
is_const_v
<
X
>
),
bool
>
=
false
>
ck
::
enable_if_t
<!
(
ck
::
is_const_v
<
Y
>
||
ck
::
is_const_v
<
X
>
),
bool
>
=
false
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
{
static_assert
(
!
ck
::
is_reference_v
<
Y
>
&&
!
ck
::
is_reference_v
<
X
>
);
static_assert
(
!
ck
::
is_reference_v
<
Y
>
&&
!
ck
::
is_reference_v
<
X
>
);
...
@@ -28,13 +28,13 @@ __host__ __device__ constexpr Y type_convert(X x)
...
@@ -28,13 +28,13 @@ __host__ __device__ constexpr Y type_convert(X x)
// Convert X to Y, either X or Y is a const data type.
// Convert X to Y, either X or Y is a const data type.
template
<
typename
Y
,
template
<
typename
Y
,
typename
X
,
typename
X
,
std
::
enable_if_t
<
std
::
is_const_v
<
Y
>
||
std
::
is_const_v
<
X
>
,
bool
>
=
false
>
ck
::
enable_if_t
<
ck
::
is_const_v
<
Y
>
||
ck
::
is_const_v
<
X
>
,
bool
>
=
false
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
{
static_assert
(
!
ck
::
is_reference_v
<
Y
>
&&
!
ck
::
is_reference_v
<
X
>
);
static_assert
(
!
ck
::
is_reference_v
<
Y
>
&&
!
ck
::
is_reference_v
<
X
>
);
using
NonConstY
=
std
::
remove_const_t
<
Y
>
;
using
NonConstY
=
ck
::
remove_const_t
<
Y
>
;
using
NonConstX
=
std
::
remove_const_t
<
X
>
;
using
NonConstX
=
ck
::
remove_const_t
<
X
>
;
return
static_cast
<
Y
>
(
type_convert
<
NonConstY
,
NonConstX
>
(
x
));
return
static_cast
<
Y
>
(
type_convert
<
NonConstY
,
NonConstX
>
(
x
));
}
}
...
@@ -501,11 +501,11 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
...
@@ -501,11 +501,11 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#endif
#endif
}
}
template
<
typename
Y
,
typename
X
,
std
::
size_t
NumElems
>
template
<
typename
Y
,
typename
X
,
size_t
NumElems
>
inline
__host__
__device__
void
array_convert
(
std
::
array
<
Y
,
NumElems
>&
y
,
inline
__host__
__device__
void
array_convert
(
std
::
array
<
Y
,
NumElems
>&
y
,
const
std
::
array
<
X
,
NumElems
>&
x
)
const
std
::
array
<
X
,
NumElems
>&
x
)
{
{
for
(
std
::
size_t
i
=
0
;
i
<
NumElems
;
i
++
)
for
(
size_t
i
=
0
;
i
<
NumElems
;
i
++
)
{
{
y
[
i
]
=
type_convert
<
Y
>
(
x
[
i
]);
y
[
i
]
=
type_convert
<
Y
>
(
x
[
i
]);
}
}
...
@@ -514,7 +514,7 @@ inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
...
@@ -514,7 +514,7 @@ inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
template
<
typename
Y
,
typename
X
,
index_t
NumElems
>
template
<
typename
Y
,
typename
X
,
index_t
NumElems
>
inline
__host__
__device__
void
array_convert
(
Array
<
Y
,
NumElems
>&
y
,
const
Array
<
X
,
NumElems
>&
x
)
inline
__host__
__device__
void
array_convert
(
Array
<
Y
,
NumElems
>&
y
,
const
Array
<
X
,
NumElems
>&
x
)
{
{
for
(
std
::
size_t
i
=
0
;
i
<
NumElems
;
i
++
)
for
(
size_t
i
=
0
;
i
<
NumElems
;
i
++
)
{
{
y
[
i
]
=
type_convert
<
Y
>
(
x
[
i
]);
y
[
i
]
=
type_convert
<
Y
>
(
x
[
i
]);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment