Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b7ab5e92
You need to sign in or sign up before continuing.
Commit
b7ab5e92
authored
Nov 23, 2023
by
Umang Yadav
Browse files
merge latest develop into migraphx
parent
3c4fb1dd
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
93 additions
and
19 deletions
+93
-19
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+3
-3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+4
-4
include/ck/utility/is_detected.hpp
include/ck/utility/is_detected.hpp
+8
-7
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+1
-0
include/ck/utility/type.hpp
include/ck/utility/type.hpp
+70
-0
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+7
-5
No files found.
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
b7ab5e92
...
...
@@ -209,10 +209,10 @@ struct Bilinear
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
std
::
int8_t
,
std
::
int32_t
,
std
::
int8_t
>
(
std
::
int8_t
&
y
,
const
std
::
int32_t
&
x0
,
const
std
::
int8_t
&
x1
)
const
__host__
__device__
constexpr
void
operator
()
<
int8_t
,
int32_t
,
int8_t
>
(
int8_t
&
y
,
const
int32_t
&
x0
,
const
int8_t
&
x1
)
const
{
y
=
type_convert
<
std
::
int8_t
>
(
x0
+
ck
::
type_convert
<
std
::
int32_t
>
(
x1
));
y
=
type_convert
<
int8_t
>
(
x0
+
ck
::
type_convert
<
int32_t
>
(
x1
));
};
float
alpha_
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
b7ab5e92
...
...
@@ -411,9 +411,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
template
<
typename
DsLayout
,
GemmSpecialization
GemmSpec
>
__host__
__device__
static
auto
MakeDsGridDescriptor_M_N
(
const
std
::
a
rray
<
index_t
,
NumDTensor
>&
MRaws
,
const
std
::
a
rray
<
index_t
,
NumDTensor
>&
NRaws
,
const
std
::
a
rray
<
index_t
,
NumDTensor
>&
DsStride
)
MakeDsGridDescriptor_M_N
(
const
ck
::
A
rray
<
index_t
,
NumDTensor
>&
MRaws
,
const
ck
::
A
rray
<
index_t
,
NumDTensor
>&
NRaws
,
const
ck
::
A
rray
<
index_t
,
NumDTensor
>&
DsStride
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
...
...
@@ -877,7 +877,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const
index_t
K
,
const
index_t
StrideA
,
const
index_t
StrideB
,
const
std
::
a
rray
<
index_t
,
NumDTensor
>
StrideDs
,
const
ck
::
A
rray
<
index_t
,
NumDTensor
>
StrideDs
,
const
index_t
StrideE
,
const
Block2ETileMap
&
block_2_etile_map
)
{
...
...
include/ck/utility/is_detected.hpp
View file @
b7ab5e92
...
...
@@ -2,21 +2,22 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/type.hpp"
namespace
ck
{
namespace
detail
{
template
<
class
Default
,
class
AlwaysVoid
,
template
<
class
...
>
class
Op
,
class
...
Args
>
struct
detector
{
using
value_t
=
std
::
false_type
;
using
value_t
=
ck
::
false_type
;
using
type
=
Default
;
};
template
<
class
Default
,
template
<
class
...
>
class
Op
,
class
...
Args
>
struct
detector
<
Default
,
std
::
void_t
<
Op
<
Args
...
>>
,
Op
,
Args
...
>
struct
detector
<
Default
,
ck
::
void_t
<
Op
<
Args
...
>>
,
Op
,
Args
...
>
{
using
value_t
=
std
::
true_type
;
using
value_t
=
ck
::
true_type
;
using
type
=
Op
<
Args
...
>
;
};
}
// namespace detail
...
...
@@ -32,12 +33,12 @@ template <template <class...> class Op, class... Args>
using
is_detected
=
typename
detail
::
detector
<
nonesuch
,
void
,
Op
,
Args
...
>::
value_t
;
template
<
typename
T
>
using
is_pack2_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack2_invocable
);
using
is_pack2_invocable_t
=
decltype
(
ck
::
declval
<
T
&>
().
is_pack2_invocable
);
template
<
typename
T
>
using
is_pack4_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack4_invocable
);
using
is_pack4_invocable_t
=
decltype
(
ck
::
declval
<
T
&>
().
is_pack4_invocable
);
template
<
typename
T
>
using
is_pack8_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack8_invocable
);
using
is_pack8_invocable_t
=
decltype
(
ck
::
declval
<
T
&>
().
is_pack8_invocable
);
}
// namespace ck
include/ck/utility/math_v2.hpp
View file @
b7ab5e92
...
...
@@ -184,6 +184,7 @@ inline __host__ double expm1<double>(double x)
{
return
std
::
expm1
(
x
);
}
#endif // __HIPCC_RTC__
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
...
...
include/ck/utility/type.hpp
View file @
b7ab5e92
...
...
@@ -31,12 +31,24 @@ namespace ck {
}
CK_BUILTIN_TYPE_TRAIT1
(
is_class
);
CK_BUILTIN_TYPE_TRAIT1
(
is_const
);
CK_BUILTIN_TYPE_TRAIT1
(
is_pointer
);
CK_BUILTIN_TYPE_TRAIT1
(
is_reference
);
CK_BUILTIN_TYPE_TRAIT1
(
is_trivially_copyable
);
CK_BUILTIN_TYPE_TRAIT1
(
is_unsigned
);
CK_BUILTIN_TYPE_TRAIT2
(
is_base_of
);
template
<
class
T
>
struct
remove_const
{
typedef
T
type
;
};
template
<
class
T
>
struct
remove_const
<
const
T
>
{
typedef
T
type
;
};
template
<
class
T
>
struct
remove_cv
{
...
...
@@ -106,19 +118,71 @@ constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
{
return
static_cast
<
T
&&>
(
t_
);
}
template
<
typename
...
Ts
>
struct
make_void
{
typedef
void
type
;
};
template
<
typename
...
Ts
>
using
void_t
=
typename
make_void
<
Ts
...
>::
type
;
// namespace detail {
// template <class T>
// struct type_identity
// {
// using type = T;
// };
// template <class T> // Note that `cv void&` is a substitution failure
// auto try_add_lvalue_reference(int) -> type_identity<T&>;
// template <class T> // Handle T = cv void case
// auto try_add_lvalue_reference(...) -> type_identity<T>;
// template <class T>
// auto try_add_rvalue_reference(int) -> type_identity<T&&>;
// template <class T>
// auto try_add_rvalue_reference(...) -> type_identity<T>;
// } // namespace detail
// template <class T>
// struct add_lvalue_reference : decltype(detail::try_add_lvalue_reference<T>(0))
// {
// };
// template <class T>
// struct add_rvalue_reference : decltype(detail::try_add_rvalue_reference<T>(0))
// {
// };
// template <class T>
// typename add_rvalue_reference<T>::type declval();
template
<
class
T
,
class
U
=
T
&&
>
U
private_declval
(
int
);
template
<
class
T
>
T
private_declval
(
long
);
template
<
class
T
>
auto
declval
()
noexcept
->
decltype
(
private_declval
<
T
>
(
0
));
#else
#include <utility>
#include <type_traits>
using
std
::
declval
;
using
std
::
forward
;
using
std
::
is_base_of
;
using
std
::
is_class
;
using
std
::
is_const
;
using
std
::
is_pointer
;
using
std
::
is_reference
;
using
std
::
is_trivially_copyable
;
using
std
::
is_unsigned
;
using
std
::
remove_const
;
using
std
::
remove_cv
;
using
std
::
remove_pointer
;
using
std
::
remove_reference
;
using
std
::
void_t
;
#endif
template
<
typename
X
,
typename
Y
>
...
...
@@ -140,9 +204,15 @@ inline constexpr bool is_same_v = is_same<X, Y>::value;
template
<
typename
X
,
typename
Y
>
inline
constexpr
bool
is_base_of_v
=
is_base_of
<
X
,
Y
>::
value
;
template
<
typename
T
>
inline
constexpr
bool
is_const_v
=
is_const
<
T
>::
value
;
template
<
typename
T
>
inline
constexpr
bool
is_unsigned_v
=
is_unsigned
<
T
>::
value
;
template
<
class
T
>
using
remove_const_t
=
typename
remove_const
<
T
>::
type
;
template
<
typename
T
>
using
remove_reference_t
=
typename
remove_reference
<
T
>::
type
;
...
...
include/ck/utility/type_convert.hpp
View file @
b7ab5e92
...
...
@@ -4,6 +4,8 @@
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/enable_if.hpp"
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/random_gen.hpp"
...
...
@@ -23,7 +25,7 @@ __host__ __device__ constexpr Y type_convert(X x)
// Convert X to Y, either X or Y is a const data type.
template
<
typename
Y
,
typename
X
,
std
::
enable_if_t
<
ck
::
is_const_v
<
Y
>
||
ck
::
is_const_v
<
X
>
,
bool
>
=
false
>
ck
::
enable_if_t
<
ck
::
is_const_v
<
Y
>
||
ck
::
is_const_v
<
X
>
,
bool
>
=
false
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
static_assert
(
!
ck
::
is_reference_v
<
Y
>
&&
!
ck
::
is_reference_v
<
X
>
);
...
...
@@ -341,7 +343,7 @@ template <>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
{
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uint
ptr
_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uint
64
_t
>
(
&
x
),
x
);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
...
...
@@ -376,7 +378,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uint
ptr
_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uint
64
_t
>
(
&
x
),
x
);
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
...
...
@@ -388,7 +390,7 @@ template <>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
{
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uint
ptr
_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uint
64
_t
>
(
&
x
),
x
);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
...
...
@@ -424,7 +426,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
42
;
// as thread id is not available on host, use 0 for prn generation
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uint
ptr
_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uint
64
_t
>
(
&
x
),
x
);
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment