Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b7ab5e92
Commit
b7ab5e92
authored
Nov 23, 2023
by
Umang Yadav
Browse files
merge latest develop into migraphx
parent
3c4fb1dd
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
93 additions
and
19 deletions
+93
-19
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+3
-3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+4
-4
include/ck/utility/is_detected.hpp
include/ck/utility/is_detected.hpp
+8
-7
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+1
-0
include/ck/utility/type.hpp
include/ck/utility/type.hpp
+70
-0
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+7
-5
No files found.
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
b7ab5e92
...
...
@@ -209,10 +209,10 @@ struct Bilinear
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
std
::
int8_t
,
std
::
int32_t
,
std
::
int8_t
>
(
std
::
int8_t
&
y
,
const
std
::
int32_t
&
x0
,
const
std
::
int8_t
&
x1
)
const
__host__
__device__
constexpr
void
operator
()
<
int8_t
,
int32_t
,
int8_t
>
(
int8_t
&
y
,
const
int32_t
&
x0
,
const
int8_t
&
x1
)
const
{
y
=
type_convert
<
std
::
int8_t
>
(
x0
+
ck
::
type_convert
<
std
::
int32_t
>
(
x1
));
y
=
type_convert
<
int8_t
>
(
x0
+
ck
::
type_convert
<
int32_t
>
(
x1
));
};
float
alpha_
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
b7ab5e92
...
...
@@ -411,9 +411,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
template
<
typename
DsLayout
,
GemmSpecialization
GemmSpec
>
__host__
__device__
static
auto
MakeDsGridDescriptor_M_N
(
const
std
::
a
rray
<
index_t
,
NumDTensor
>&
MRaws
,
const
std
::
a
rray
<
index_t
,
NumDTensor
>&
NRaws
,
const
std
::
a
rray
<
index_t
,
NumDTensor
>&
DsStride
)
MakeDsGridDescriptor_M_N
(
const
ck
::
A
rray
<
index_t
,
NumDTensor
>&
MRaws
,
const
ck
::
A
rray
<
index_t
,
NumDTensor
>&
NRaws
,
const
ck
::
A
rray
<
index_t
,
NumDTensor
>&
DsStride
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
...
...
@@ -877,7 +877,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const
index_t
K
,
const
index_t
StrideA
,
const
index_t
StrideB
,
const
std
::
a
rray
<
index_t
,
NumDTensor
>
StrideDs
,
const
ck
::
A
rray
<
index_t
,
NumDTensor
>
StrideDs
,
const
index_t
StrideE
,
const
Block2ETileMap
&
block_2_etile_map
)
{
...
...
include/ck/utility/is_detected.hpp
View file @
b7ab5e92
...
...
@@ -2,21 +2,22 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/type.hpp"
namespace
ck
{
namespace
detail
{
template
<
class
Default
,
class
AlwaysVoid
,
template
<
class
...
>
class
Op
,
class
...
Args
>
struct
detector
{
using
value_t
=
std
::
false_type
;
using
value_t
=
ck
::
false_type
;
using
type
=
Default
;
};
template
<
class
Default
,
template
<
class
...
>
class
Op
,
class
...
Args
>
struct
detector
<
Default
,
std
::
void_t
<
Op
<
Args
...
>>
,
Op
,
Args
...
>
struct
detector
<
Default
,
ck
::
void_t
<
Op
<
Args
...
>>
,
Op
,
Args
...
>
{
using
value_t
=
std
::
true_type
;
using
value_t
=
ck
::
true_type
;
using
type
=
Op
<
Args
...
>
;
};
}
// namespace detail
...
...
@@ -32,12 +33,12 @@ template <template <class...> class Op, class... Args>
using
is_detected
=
typename
detail
::
detector
<
nonesuch
,
void
,
Op
,
Args
...
>::
value_t
;
template
<
typename
T
>
using
is_pack2_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack2_invocable
);
using
is_pack2_invocable_t
=
decltype
(
ck
::
declval
<
T
&>
().
is_pack2_invocable
);
template
<
typename
T
>
using
is_pack4_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack4_invocable
);
using
is_pack4_invocable_t
=
decltype
(
ck
::
declval
<
T
&>
().
is_pack4_invocable
);
template
<
typename
T
>
using
is_pack8_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack8_invocable
);
using
is_pack8_invocable_t
=
decltype
(
ck
::
declval
<
T
&>
().
is_pack8_invocable
);
}
// namespace ck
include/ck/utility/math_v2.hpp
View file @
b7ab5e92
...
...
@@ -184,6 +184,7 @@ inline __host__ double expm1<double>(double x)
{
return
std
::
expm1
(
x
);
}
#endif // __HIPCC_RTC__
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
...
...
include/ck/utility/type.hpp
View file @
b7ab5e92
...
...
@@ -31,12 +31,24 @@ namespace ck {
}
CK_BUILTIN_TYPE_TRAIT1
(
is_class
);
CK_BUILTIN_TYPE_TRAIT1
(
is_const
);
CK_BUILTIN_TYPE_TRAIT1
(
is_pointer
);
CK_BUILTIN_TYPE_TRAIT1
(
is_reference
);
CK_BUILTIN_TYPE_TRAIT1
(
is_trivially_copyable
);
CK_BUILTIN_TYPE_TRAIT1
(
is_unsigned
);
CK_BUILTIN_TYPE_TRAIT2
(
is_base_of
);
template
<
class
T
>
struct
remove_const
{
typedef
T
type
;
};
template
<
class
T
>
struct
remove_const
<
const
T
>
{
typedef
T
type
;
};
template
<
class
T
>
struct
remove_cv
{
...
...
@@ -106,19 +118,71 @@ constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
{
return
static_cast
<
T
&&>
(
t_
);
}
template
<
typename
...
Ts
>
struct
make_void
{
typedef
void
type
;
};
template
<
typename
...
Ts
>
using
void_t
=
typename
make_void
<
Ts
...
>::
type
;
// namespace detail {
// template <class T>
// struct type_identity
// {
// using type = T;
// };
// template <class T> // Note that `cv void&` is a substitution failure
// auto try_add_lvalue_reference(int) -> type_identity<T&>;
// template <class T> // Handle T = cv void case
// auto try_add_lvalue_reference(...) -> type_identity<T>;
// template <class T>
// auto try_add_rvalue_reference(int) -> type_identity<T&&>;
// template <class T>
// auto try_add_rvalue_reference(...) -> type_identity<T>;
// } // namespace detail
// template <class T>
// struct add_lvalue_reference : decltype(detail::try_add_lvalue_reference<T>(0))
// {
// };
// template <class T>
// struct add_rvalue_reference : decltype(detail::try_add_rvalue_reference<T>(0))
// {
// };
// template <class T>
// typename add_rvalue_reference<T>::type declval();
template
<
class
T
,
class
U
=
T
&&
>
U
private_declval
(
int
);
template
<
class
T
>
T
private_declval
(
long
);
template
<
class
T
>
auto
declval
()
noexcept
->
decltype
(
private_declval
<
T
>
(
0
));
#else
#include <utility>
#include <type_traits>
using
std
::
declval
;
using
std
::
forward
;
using
std
::
is_base_of
;
using
std
::
is_class
;
using
std
::
is_const
;
using
std
::
is_pointer
;
using
std
::
is_reference
;
using
std
::
is_trivially_copyable
;
using
std
::
is_unsigned
;
using
std
::
remove_const
;
using
std
::
remove_cv
;
using
std
::
remove_pointer
;
using
std
::
remove_reference
;
using
std
::
void_t
;
#endif
template
<
typename
X
,
typename
Y
>
...
...
@@ -140,9 +204,15 @@ inline constexpr bool is_same_v = is_same<X, Y>::value;
template
<
typename
X
,
typename
Y
>
inline
constexpr
bool
is_base_of_v
=
is_base_of
<
X
,
Y
>::
value
;
template
<
typename
T
>
inline
constexpr
bool
is_const_v
=
is_const
<
T
>::
value
;
template
<
typename
T
>
inline
constexpr
bool
is_unsigned_v
=
is_unsigned
<
T
>::
value
;
template
<
class
T
>
using
remove_const_t
=
typename
remove_const
<
T
>::
type
;
template
<
typename
T
>
using
remove_reference_t
=
typename
remove_reference
<
T
>::
type
;
...
...
include/ck/utility/type_convert.hpp
View file @
b7ab5e92
...
...
@@ -4,6 +4,8 @@
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/enable_if.hpp"
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/random_gen.hpp"
...
...
@@ -23,7 +25,7 @@ __host__ __device__ constexpr Y type_convert(X x)
// Convert X to Y, either X or Y is a const data type.
template
<
typename
Y
,
typename
X
,
std
::
enable_if_t
<
ck
::
is_const_v
<
Y
>
||
ck
::
is_const_v
<
X
>
,
bool
>
=
false
>
ck
::
enable_if_t
<
ck
::
is_const_v
<
Y
>
||
ck
::
is_const_v
<
X
>
,
bool
>
=
false
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
static_assert
(
!
ck
::
is_reference_v
<
Y
>
&&
!
ck
::
is_reference_v
<
X
>
);
...
...
@@ -341,7 +343,7 @@ template <>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
{
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uint
ptr
_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uint
64
_t
>
(
&
x
),
x
);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
...
...
@@ -376,7 +378,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uint
ptr
_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uint
64
_t
>
(
&
x
),
x
);
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
...
...
@@ -388,7 +390,7 @@ template <>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
{
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uint
ptr
_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uint
64
_t
>
(
&
x
),
x
);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
...
...
@@ -424,7 +426,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
42
;
// as thread id is not available on host, use 0 for prn generation
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uint
ptr
_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uint
64
_t
>
(
&
x
),
x
);
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment