Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f74cf520
"next_docs/vscode:/vscode.git/clone" did not exist on "cb57e84cd3af1749a258881813183b845a01e6f1"
Commit
f74cf520
authored
Aug 31, 2021
by
ltqin
Browse files
Merge branch 'develop' into backward_weight_v4r4r2_xdlops
parents
7bc4254d
10bb8110
Changes
17
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
465 additions
and
235 deletions
+465
-235
composable_kernel/include/tensor_description/tensor_adaptor.hpp
...able_kernel/include/tensor_description/tensor_adaptor.hpp
+1
-2
composable_kernel/include/tensor_description/tensor_descriptor.hpp
...e_kernel/include/tensor_description/tensor_descriptor.hpp
+3
-4
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
...rnel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
+5
-7
composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp
...include/tensor_operation/threadwise_contraction_dlops.hpp
+18
-24
composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
...nel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
+9
-12
composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp
.../include/tensor_operation/threadwise_tensor_slice_set.hpp
+2
-2
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
...ude/tensor_operation/threadwise_tensor_slice_transfer.hpp
+39
-38
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
.../tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
+16
-19
composable_kernel/include/utility/amd_buffer_addressing.hpp
composable_kernel/include/utility/amd_buffer_addressing.hpp
+271
-55
composable_kernel/include/utility/array.hpp
composable_kernel/include/utility/array.hpp
+1
-1
composable_kernel/include/utility/config.hpp
composable_kernel/include/utility/config.hpp
+2
-2
composable_kernel/include/utility/data_type.hpp
composable_kernel/include/utility/data_type.hpp
+11
-0
composable_kernel/include/utility/dynamic_buffer.hpp
composable_kernel/include/utility/dynamic_buffer.hpp
+80
-58
composable_kernel/include/utility/tuple.hpp
composable_kernel/include/utility/tuple.hpp
+1
-1
composable_kernel/include/utility/tuple_helper.hpp
composable_kernel/include/utility/tuple_helper.hpp
+1
-3
composable_kernel/include/utility/type.hpp
composable_kernel/include/utility/type.hpp
+3
-0
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp
...ution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp
+2
-7
No files found.
composable_kernel/include/tensor_description/tensor_adaptor.hpp
View file @
f74cf520
...
...
@@ -189,8 +189,7 @@ struct TensorAdaptor
bool
is_known
=
true
;
static_for
<
0
,
Transforms
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
is_known
&=
remove_cv_t
<
remove_reference_t
<
decltype
(
Transforms
{}[
i
])
>>::
IsKnownAtCompileTime
();
is_known
&=
remove_cvref_t
<
decltype
(
Transforms
{}[
i
])
>::
IsKnownAtCompileTime
();
});
return
is_known
&&
is_known_at_compile_time
<
ElementSize
>::
value
;
...
...
composable_kernel/include/tensor_description/tensor_descriptor.hpp
View file @
f74cf520
...
...
@@ -185,8 +185,7 @@ struct TensorDescriptor
bool
is_known
=
true
;
static_for
<
0
,
Transforms
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
is_known
&=
remove_cv_t
<
remove_reference_t
<
decltype
(
Transforms
{}[
i
])
>>::
IsKnownAtCompileTime
();
is_known
&=
remove_cvref_t
<
decltype
(
Transforms
{}[
i
])
>::
IsKnownAtCompileTime
();
});
return
is_known
&&
is_known_at_compile_time
<
ElementSize
>::
value
&&
...
...
@@ -587,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc&
template
<
typename
TensorDesc
>
using
TensorCoordinate_t
=
decltype
(
make_tensor_coordinate
(
TensorDesc
{},
MultiIndex
<
remove_cv
_t
<
remove_reference
_t
<
TensorDesc
>
>
::
GetNumOfDimension
()
>
{}));
TensorDesc
{},
MultiIndex
<
remove_cv
ref
_t
<
TensorDesc
>::
GetNumOfDimension
()
>
{}));
template
<
typename
TensorDesc
>
using
TensorCoordinateStep_t
=
decltype
(
make_tensor_coordinate_step
(
TensorDesc
{},
MultiIndex
<
remove_cv
_t
<
remove_reference
_t
<
TensorDesc
>
>
::
GetNumOfDimension
()
>
{}));
TensorDesc
{},
MultiIndex
<
remove_cv
ref
_t
<
TensorDesc
>::
GetNumOfDimension
()
>
{}));
}
// namespace ck
#endif
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
View file @
f74cf520
...
...
@@ -110,12 +110,10 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
const
BThreadBuffer
&
b_thread_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
ABlockBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatA
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
BThreadBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatB
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
CThreadBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatC
>>>::
value
&&
static_assert
(
is_same
<
remove_cvref_t
<
typename
ABlockBuffer
::
type
>
,
remove_cvref_t
<
FloatA
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
BThreadBuffer
::
type
>
,
remove_cvref_t
<
FloatB
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CThreadBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp
View file @
f74cf520
...
...
@@ -55,18 +55,15 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
CBuffer
&
c_buf
,
COriginIdx
)
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
AOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
BOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
COriginIdx
>>>::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
AOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
BOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
COriginIdx
>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
ABuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatA
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
BBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatB
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
CBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatC
>>>::
value
&&
static_assert
(
is_same
<
remove_cvref_t
<
typename
ABuffer
::
type
>
,
remove_cvref_t
<
FloatA
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
BBuffer
::
type
>
,
remove_cvref_t
<
FloatB
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -157,18 +154,15 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_
CBuffer
&
c_buf
,
COriginIdx
)
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
AOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
BOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
COriginIdx
>>>::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
AOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
BOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
COriginIdx
>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
ABuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatA
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
BBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatB
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
CBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatC
>>>::
value
&&
static_assert
(
is_same
<
remove_cvref_t
<
typename
ABuffer
::
type
>
,
remove_cvref_t
<
FloatA
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
BBuffer
::
type
>
,
remove_cvref_t
<
FloatB
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
View file @
f74cf520
...
...
@@ -41,18 +41,15 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
CDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
AOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
BOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
COriginIdx
>>>::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
AOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
BOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
COriginIdx
>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
ABuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatA
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
BBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatB
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
CBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatC
>>>::
value
&&
static_assert
(
is_same
<
remove_cvref_t
<
typename
ABuffer
::
type
>
,
remove_cvref_t
<
FloatA
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
BBuffer
::
type
>
,
remove_cvref_t
<
FloatB
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp
View file @
f74cf520
...
...
@@ -30,11 +30,11 @@ struct ThreadwiseTensorSliceSet_v1
static_assert
(
Buffer
::
IsStaticBuffer
(),
"wrong! DstBuffer need to be StaticBuffer"
);
static_assert
(
is_known_at_compile_time
<
remove_cv
_t
<
remove_reference
_t
<
OriginIdx
>>
>
::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cv
ref
_t
<
OriginIdx
>>::
value
,
"wrong! OriginIdx need to be known at compile-time"
);
// Desc is known at compile-time
constexpr
auto
desc
=
remove_cv
_t
<
remove_reference
_t
<
Desc
>
>
{};
constexpr
auto
desc
=
remove_cv
ref
_t
<
Desc
>
{};
// OriginIdx is known at compile-time
constexpr
auto
origin_idx
=
to_multi_index
(
OriginIdx
{});
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
View file @
f74cf520
...
...
@@ -95,18 +95,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
(),
"wrong! SrcDesc need to known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
SrcSliceOriginIdx
>>>::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
SrcSliceOriginIdx
>>::
value
,
"wrong! SrcSliceOrigin need to known at compile-time"
);
static_assert
(
SrcBuffer
::
IsStaticBuffer
(),
"wrong! SrcBuffer need to be StaticBuffer"
);
// static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
// remove_cv_t<remove_reference_t<SrcData>>>::value,
//"wrong! SrcBuffer data type is wrong");
// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr
auto
src_desc
=
remove_cv
_t
<
remove_reference
_t
<
SrcDesc
>
>
{};
constexpr
auto
src_desc
=
remove_cv
ref
_t
<
SrcDesc
>
{};
constexpr
auto
src_slice_origin_idx
=
to_multi_index
(
SrcSliceOriginIdx
{});
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -208,10 +203,20 @@ struct ThreadwiseTensorSliceTransfer_v1r3
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
dst_desc
,
dst_coord_
);
// copy data from dst_vector into dst_buf
if
constexpr
(
DstInMemOp
==
InMemoryDataOperationEnum_t
::
Set
)
{
dst_buf
.
template
Set
<
dst_vector_t
>(
dst_coord_
.
GetOffset
(),
is_dst_valid
,
dst_vector
.
template
AsType
<
dst_vector_t
>()[
Number
<
0
>
{}]);
}
else
if
constexpr
(
DstInMemOp
==
InMemoryDataOperationEnum_t
::
AtomicAdd
)
{
dst_buf
.
template
AtomicAdd
<
dst_vector_t
>(
dst_coord_
.
GetOffset
(),
is_dst_valid
,
dst_vector
.
template
AsType
<
dst_vector_t
>()[
Number
<
0
>
{}]);
}
constexpr
auto
move_on_dim
=
[
&
]()
constexpr
{
...
...
@@ -411,16 +416,15 @@ struct ThreadwiseTensorSliceTransfer_v2
static_assert
(
DstDesc
::
IsKnownAtCompileTime
(),
"wrong! DstDesc need to known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
DstSliceOriginIdx
>>>::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
DstSliceOriginIdx
>>::
value
,
"wrong! DstSliceOrigin need to known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
DstBuffer
::
type
>>
,
remove_cv_t
<
remove_ref
erence
_t
<
DstData
>>
>
::
value
&&
static_assert
(
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_
cv
ref_t
<
DstData
>>::
value
&&
"wrong! inconsistent type"
);
// DstDesc and dst_slice_origin_idx are known at compile-time
constexpr
auto
dst_desc
=
remove_cv
_t
<
remove_reference
_t
<
DstDesc
>
>
{};
constexpr
auto
dst_desc
=
remove_cv
ref
_t
<
DstDesc
>
{};
constexpr
auto
dst_slice_origin_idx
=
DstSliceOriginIdx
{};
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -732,8 +736,8 @@ struct ThreadwiseTensorSliceTransfer_v3
SrcBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Lds
,
"wrong!"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
SrcBuffer
::
type
>>
,
remove_cv_t
<
remove_ref
erence
_t
<
SrcData
>>
>
::
value
,
static_assert
(
is_same
<
remove_cvref_t
<
typename
SrcBuffer
::
type
>
,
remove_
cv
ref_t
<
SrcData
>>::
value
,
"wrong! SrcBuffer and SrcData data type are inconsistent"
);
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -889,8 +893,8 @@ struct ThreadwiseTensorSliceTransfer_v3
DstBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Lds
,
"wrong!"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
DstBuffer
::
type
>>
,
remove_cv_t
<
remove_ref
erence
_t
<
DstData
>>
>
::
value
,
static_assert
(
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_
cv
ref_t
<
DstData
>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -1305,24 +1309,21 @@ struct ThreadwiseTensorSliceTransfer_v4
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
()
&&
DstDesc
::
IsKnownAtCompileTime
(),
"wrong! SrcDesc and DstDesc need to known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
SrcBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
SrcData
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
DstBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
DstData
>>>::
value
,
static_assert
(
is_same
<
remove_cvref_t
<
typename
SrcBuffer
::
type
>
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_cvref_t
<
DstData
>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
static_assert
(
DstBuffer
::
IsStaticBuffer
(),
"wrong! DstBuffer need to be StaticBuffer"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
SrcRefToOriginDisplacement
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
DstOriginIdx
>>>::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
SrcRefToOriginDisplacement
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
DstOriginIdx
>>::
value
,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time"
);
// SrcDesc and DstDesc are known at compile-time
constexpr
auto
src_desc
=
remove_cv
_t
<
remove_reference
_t
<
SrcDesc
>
>
{};
constexpr
auto
dst_desc
=
remove_cv
_t
<
remove_reference
_t
<
DstDesc
>
>
{};
constexpr
auto
src_desc
=
remove_cv
ref
_t
<
SrcDesc
>
{};
constexpr
auto
dst_desc
=
remove_cv
ref
_t
<
DstDesc
>
{};
// SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
constexpr
auto
src_ref_to_origin_disp_idx
=
to_multi_index
(
SrcRefToOriginDisplacement
{});
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
View file @
f74cf520
...
...
@@ -80,8 +80,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
SrcBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Lds
,
"wrong!"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
SrcBuffer
::
type
>>
,
remove_cv_t
<
remove_ref
erence
_t
<
SrcData
>>
>
::
value
,
static_assert
(
is_same
<
remove_cvref_t
<
typename
SrcBuffer
::
type
>
,
remove_
cv
ref_t
<
SrcData
>>::
value
,
"wrong! SrcBuffer and SrcData data type are inconsistent"
);
// tensor descriptor for src_vector
...
...
@@ -248,8 +248,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
DstBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Lds
,
"wrong!"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
DstBuffer
::
type
>>
,
remove_cv_t
<
remove_ref
erence
_t
<
DstData
>>
>
::
value
,
static_assert
(
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_
cv
ref_t
<
DstData
>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
// tensor descriptor for dst_vector
...
...
@@ -669,24 +669,21 @@ struct ThreadwiseTensorSliceTransfer_v4r1
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
()
&&
DstDesc
::
IsKnownAtCompileTime
(),
"wrong! SrcDesc and DstDesc need to known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
SrcBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
SrcData
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
DstBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
DstData
>>>::
value
,
static_assert
(
is_same
<
remove_cvref_t
<
typename
SrcBuffer
::
type
>
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_cvref_t
<
DstData
>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
static_assert
(
DstBuffer
::
IsStaticBuffer
(),
"wrong! DstBuffer need to be StaticBuffer"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
SrcRefToOriginDisplacement
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
DstOriginIdx
>>>::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
SrcRefToOriginDisplacement
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
DstOriginIdx
>>::
value
,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time"
);
// SrcDesc and DstDesc are known at compile-time
constexpr
auto
src_desc
=
remove_cv
_t
<
remove_reference
_t
<
SrcDesc
>
>
{};
constexpr
auto
dst_desc
=
remove_cv
_t
<
remove_reference
_t
<
DstDesc
>
>
{};
constexpr
auto
src_desc
=
remove_cv
ref
_t
<
SrcDesc
>
{};
constexpr
auto
dst_desc
=
remove_cv
ref
_t
<
DstDesc
>
{};
// SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
constexpr
auto
src_ref_to_origin_disp_idx
=
to_multi_index
(
SrcRefToOriginDisplacement
{});
...
...
composable_kernel/include/utility/amd_buffer_addressing.hpp
View file @
f74cf520
...
...
@@ -202,6 +202,22 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v4f32"
);
// atomic add
// int
__device__
int32_t
llvm_amdgcn_raw_buffer_atomic_add_i32
(
int32_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.add.i32"
);
// float
__device__
float
llvm_amdgcn_raw_buffer_atomic_add_fp32
(
float
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fadd.f32"
);
template
<
typename
T
,
index_t
N
>
__device__
typename
vector_type
<
T
,
N
>::
type
amd_buffer_load_impl
(
int32x4_t
src_wave_buffer_resource
,
...
...
@@ -209,13 +225,49 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
index_t
src_wave_addr_offset
)
{
static_assert
(
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
)),
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
float
>::
value
)
if
constexpr
(
is_same
<
T
,
double
>::
value
)
{
// use fp32 load to mimic fp64 load
if
constexpr
(
N
==
1
)
{
const
float2_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
as_type
<
double
>
(
tmp
);
}
else
if
constexpr
(
N
==
2
)
{
const
float4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
as_type
<
double2_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
4
)
{
const
float4_t
f32_0
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
const
float4_t
f32_1
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
float
),
0
);
vector_type
<
double
,
4
>
tmp
;
tmp
.
AsType
<
double2_t
>
()(
Number
<
0
>
{})
=
as_type
<
double2_t
>
(
f32_0
);
tmp
.
AsType
<
double2_t
>
()(
Number
<
1
>
{})
=
as_type
<
double2_t
>
(
f32_1
);
return
tmp
.
AsType
<
double4_t
>
()(
Number
<
0
>
{});
}
}
else
if
constexpr
(
is_same
<
T
,
float
>::
value
)
{
if
constexpr
(
N
==
1
)
{
...
...
@@ -267,25 +319,11 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
}
else
if
constexpr
(
N
==
8
)
{
#if 0
vector_type<half_t, 8> tmp;
tmp.AsType<half4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.AsType<half4_t>()(Number<1>{}) =
llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(half_t),
0);
return tmp.AsType<half8_t>()(Number<0>{});
#else
// use fp32 load to mimic fp16 load
float4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
as_type
<
half8_t
>
(
tmp
);
#endif
}
}
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
...
...
@@ -417,13 +455,34 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
index_t
dst_wave_addr_offset
)
{
static_assert
(
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
)),
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
float
>::
value
)
if
constexpr
(
is_same
<
T
,
double
>::
value
)
{
// use fp32 store to mimic fp64 store
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_fp32x2
(
as_type
<
float2_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp32x4
(
as_type
<
float4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
}
else
if
constexpr
(
is_same
<
T
,
float
>::
value
)
{
if
constexpr
(
N
==
1
)
{
...
...
@@ -450,6 +509,49 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
0
);
}
}
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_fp16
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp16x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_fp16x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
half_t
,
8
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_store_fp16x4
(
tmp
.
AsType
<
half4_t
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
llvm_amdgcn_raw_buffer_store_fp16x4
(
tmp
.
AsType
<
half4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
half_t
),
0
);
}
}
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
{
if
constexpr
(
N
==
1
)
...
...
@@ -536,11 +638,23 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
0
);
}
}
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
}
template
<
typename
T
,
index_t
N
>
__device__
void
amd_buffer_atomic_add_impl
(
const
typename
vector_type
<
T
,
N
>::
type
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
)
{
static_assert
((
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
)),
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
float
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_
s
to
re
_fp
16
(
src_thread_data
,
llvm_amdgcn_raw_buffer_
a
to
mic_add
_fp
32
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
...
...
@@ -548,41 +662,108 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp16x2
(
src_thread_data
,
vector_type
<
float
,
2
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_atomic_add_fp32
(
tmp
.
AsType
<
float
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
llvm_amdgcn_raw_buffer_atomic_add_fp32
(
tmp
.
AsType
<
float
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
float
),
0
);
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_fp16x4
(
src_thread_data
,
vector_type
<
float
,
4
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_atomic_add_fp32
(
tmp
.
AsType
<
float
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
llvm_amdgcn_raw_buffer_atomic_add_fp32
(
tmp
.
AsType
<
float
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
float
),
0
);
llvm_amdgcn_raw_buffer_atomic_add_fp32
(
tmp
.
AsType
<
float
>
()[
Number
<
2
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
2
*
sizeof
(
float
),
0
);
llvm_amdgcn_raw_buffer_atomic_add_fp32
(
tmp
.
AsType
<
float
>
()[
Number
<
3
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
3
*
sizeof
(
float
),
0
);
}
else
if
constexpr
(
N
==
8
)
}
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
{
vector_type
<
half_t
,
8
>
tmp
{
src_thread_data
};
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_atomic_add_i32
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
2
)
{
vector_type
<
int32_t
,
2
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_
s
to
re_fp16x4
(
tmp
.
AsType
<
half4
_t
>
()[
Number
<
0
>
{}],
llvm_amdgcn_raw_buffer_
a
to
mic_add_i32
(
tmp
.
AsType
<
int32
_t
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
llvm_amdgcn_raw_buffer_
s
to
re_fp16x4
(
tmp
.
AsType
<
half4
_t
>
()[
Number
<
1
>
{}],
llvm_amdgcn_raw_buffer_
a
to
mic_add_i32
(
tmp
.
AsType
<
int32
_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
half_t
),
dst_wave_addr_offset
+
sizeof
(
int32_t
),
0
);
}
else
if
constexpr
(
N
==
4
)
{
vector_type
<
int32_t
,
4
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_atomic_add_i32
(
tmp
.
AsType
<
int32_t
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
llvm_amdgcn_raw_buffer_atomic_add_i32
(
tmp
.
AsType
<
int32_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
),
0
);
llvm_amdgcn_raw_buffer_atomic_add_i32
(
tmp
.
AsType
<
int32_t
>
()[
Number
<
2
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
2
*
sizeof
(
int32_t
),
0
);
llvm_amdgcn_raw_buffer_atomic_add_i32
(
tmp
.
AsType
<
int32_t
>
()[
Number
<
3
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
3
*
sizeof
(
int32_t
),
0
);
}
}
}
// buffer_load requires:
// 1) p_src_wave must
be in
global memory space
// 1) p_src_wave must
point to
global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template
<
typename
T
,
index_t
N
>
...
...
@@ -616,7 +797,7 @@ amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave,
}
// buffer_load requires:
// 1) p_src_wave must
be in
global memory space
// 1) p_src_wave must
point to
global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template
<
typename
T
,
index_t
N
>
...
...
@@ -644,8 +825,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
}
// buffer_store requires:
// 1) p_dst_wave must
be
global memory
// 2) p_dst_wave t
o
be a wavewise pointer.
// 1) p_dst_wave must
point to
global memory
// 2) p_dst_wave
mus
t be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template
<
typename
T
,
index_t
N
>
__device__
void
amd_buffer_store
(
const
typename
vector_type_maker
<
T
,
N
>::
type
::
type
src_thread_data
,
...
...
@@ -677,5 +858,40 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#endif
}
// buffer_atomic_add requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template
<
typename
T
,
index_t
N
>
__device__
void
amd_buffer_atomic_add
(
const
typename
vector_type_maker
<
T
,
N
>::
type
::
type
src_thread_data
,
T
*
p_dst_wave
,
const
index_t
dst_thread_element_offset
,
const
bool
dst_thread_element_valid
,
const
index_t
dst_element_space_size
)
{
const
int32x4_t
dst_wave_buffer_resource
=
make_wave_buffer_resource
(
p_dst_wave
,
dst_element_space_size
);
index_t
dst_thread_addr_offset
=
dst_thread_element_offset
*
sizeof
(
T
);
using
vector_t
=
typename
vector_type_maker
<
T
,
N
>::
type
::
type
;
using
scalar_t
=
typename
scalar_type
<
vector_t
>::
type
;
constexpr
index_t
vector_size
=
scalar_type
<
vector_t
>::
vector_size
;
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x7fffffff
;
amd_buffer_atomic_add_impl
<
scalar_t
,
vector_size
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
#else
if
(
dst_thread_element_valid
)
{
amd_buffer_atomic_add_impl
<
scalar_t
,
vector_size
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
}
#endif
}
}
// namespace ck
#endif
composable_kernel/include/utility/array.hpp
View file @
f74cf520
...
...
@@ -48,7 +48,7 @@ struct Array<TData, 0>
template
<
typename
X
,
typename
...
Xs
>
__host__
__device__
constexpr
auto
make_array
(
X
&&
x
,
Xs
&&
...
xs
)
{
using
data_type
=
remove_cv
_t
<
remove_reference
_t
<
X
>
>
;
using
data_type
=
remove_cv
ref
_t
<
X
>
;
return
Array
<
data_type
,
sizeof
...(
Xs
)
+
1
>
{{
std
::
forward
<
X
>
(
x
),
std
::
forward
<
Xs
>
(
xs
)...}};
}
...
...
composable_kernel/include/utility/config.hpp
View file @
f74cf520
...
...
@@ -85,8 +85,8 @@
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
#endif
#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK 1
#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_
ADD_
OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_
ADD_
OOB_CHECK_OFFSET_TRICK 1
#endif
// pass tensor descriptor by value or void*
...
...
composable_kernel/include/utility/data_type.hpp
View file @
f74cf520
...
...
@@ -73,6 +73,13 @@ struct scalar_type<vector_type<T, N>>
};
//
template
<
>
struct
scalar_type
<
double
>
{
using
type
=
double
;
static
constexpr
index_t
vector_size
=
1
;
};
template
<
>
struct
scalar_type
<
float
>
{
...
...
@@ -864,6 +871,10 @@ struct vector_type<T, 256>
}
};
// fp64
using
double2_t
=
typename
vector_type
<
double
,
2
>::
type
;
using
double4_t
=
typename
vector_type
<
double
,
4
>::
type
;
// fp32
using
float2_t
=
typename
vector_type
<
float
,
2
>::
type
;
using
float4_t
=
typename
vector_type
<
float
,
4
>::
type
;
...
...
composable_kernel/include/utility/dynamic_buffer.hpp
View file @
f74cf520
...
...
@@ -39,18 +39,15 @@ struct DynamicBuffer
}
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>
>>::
type
,
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
type
>::
value
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
auto
Get
(
index_t
i
,
bool
is_valid_element
)
const
{
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
;
static_assert
(
scalar_per_x_vector
%
scalar_per_t_vector
==
0
,
"wrong! X need to be multiple T"
);
...
...
@@ -67,14 +64,13 @@ struct DynamicBuffer
if
constexpr
(
InvalidElementUseNumericalZeroValue
)
{
return
amd_buffer_load_invalid_element_return_return_zero
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
t_per_x
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
return
amd_buffer_load_invalid_element_return_return_zero
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
}
else
{
return
amd_buffer_load_invalid_element_return_customized_value
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
return
amd_buffer_load_invalid_element_return_customized_value
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
,
invalid_element_value_
);
}
...
...
@@ -94,18 +90,15 @@ struct DynamicBuffer
}
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>
>>::
type
,
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
type
>::
value
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
void
Set
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
{
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
;
static_assert
(
scalar_per_x_vector
%
scalar_per_t_vector
==
0
,
"wrong! X need to be multiple T"
);
...
...
@@ -115,7 +108,7 @@ struct DynamicBuffer
#if CK_USE_AMD_BUFFER_ADDRESSING
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_store
<
remove_cv
_t
<
remove_reference
_t
<
T
>
>
,
t_per_x
>
(
amd_buffer_store
<
remove_cv
ref
_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
#else
if
(
is_valid_element
)
...
...
@@ -136,70 +129,65 @@ struct DynamicBuffer
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
// ds_write_b128
// TODO: remove this after compiler fix
if
constexpr
(
is_same
<
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
type
,
int8_t
>::
value
)
{
static_assert
(
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8_t
>::
value
)
||
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x2_t
>::
value
)
||
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x4_t
>::
value
)
||
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8x4_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x4_t
>::
value
)
||
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8x8_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x8_t
>::
value
)
||
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8x16_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x16_t
>::
value
),
if
constexpr
(
is_same
<
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
,
int8_t
>::
value
)
{
static_assert
((
is_same
<
remove_cvref_t
<
T
>
,
int8_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8_t
>::
value
)
||
(
is_same
<
remove_cvref_t
<
T
>
,
int8_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x2_t
>::
value
)
||
(
is_same
<
remove_cvref_t
<
T
>
,
int8_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x4_t
>::
value
)
||
(
is_same
<
remove_cvref_t
<
T
>
,
int8x4_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x4_t
>::
value
)
||
(
is_same
<
remove_cvref_t
<
T
>
,
int8x8_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x8_t
>::
value
)
||
(
is_same
<
remove_cvref_t
<
T
>
,
int8x16_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x16_t
>::
value
),
"wrong! not implemented for this combination, please add "
"implementation"
);
if
constexpr
(
is_same
<
remove_cv
_t
<
remove_reference
_t
<
T
>
>
,
int8_t
>::
value
&&
is_same
<
remove_cv
_t
<
remove_reference
_t
<
X
>
>
,
int8_t
>::
value
)
if
constexpr
(
is_same
<
remove_cv
ref
_t
<
T
>
,
int8_t
>::
value
&&
is_same
<
remove_cv
ref
_t
<
X
>
,
int8_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*
c_style_pointer_cast
<
int8_t
*>
(
&
p_data_
[
i
])
=
*
c_style_pointer_cast
<
const
int8_t
*>
(
&
x
);
}
else
if
constexpr
(
is_same
<
remove_cv
_t
<
remove_reference
_t
<
T
>
>
,
int8_t
>::
value
&&
is_same
<
remove_cv
_t
<
remove_reference
_t
<
X
>
>
,
int8x2_t
>::
value
)
else
if
constexpr
(
is_same
<
remove_cv
ref
_t
<
T
>
,
int8_t
>::
value
&&
is_same
<
remove_cv
ref
_t
<
X
>
,
int8x2_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*
c_style_pointer_cast
<
int16_t
*>
(
&
p_data_
[
i
])
=
*
c_style_pointer_cast
<
const
int16_t
*>
(
&
x
);
}
else
if
constexpr
(
is_same
<
remove_cv
_t
<
remove_reference
_t
<
T
>
>
,
int8_t
>::
value
&&
is_same
<
remove_cv
_t
<
remove_reference
_t
<
X
>
>
,
int8x4_t
>::
value
)
else
if
constexpr
(
is_same
<
remove_cv
ref
_t
<
T
>
,
int8_t
>::
value
&&
is_same
<
remove_cv
ref
_t
<
X
>
,
int8x4_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*
c_style_pointer_cast
<
int32_t
*>
(
&
p_data_
[
i
])
=
*
c_style_pointer_cast
<
const
int32_t
*>
(
&
x
);
}
else
if
constexpr
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8x4_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x4_t
>::
value
)
else
if
constexpr
(
is_same
<
remove_cvref_t
<
T
>
,
int8x4_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x4_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*
c_style_pointer_cast
<
int32_t
*>
(
&
p_data_
[
i
])
=
*
c_style_pointer_cast
<
const
int32_t
*>
(
&
x
);
}
else
if
constexpr
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8x8_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x8_t
>::
value
)
else
if
constexpr
(
is_same
<
remove_cvref_t
<
T
>
,
int8x8_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x8_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*
c_style_pointer_cast
<
int32x2_t
*>
(
&
p_data_
[
i
])
=
*
c_style_pointer_cast
<
const
int32x2_t
*>
(
&
x
);
}
else
if
constexpr
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8x16_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x16_t
>::
value
)
else
if
constexpr
(
is_same
<
remove_cvref_t
<
T
>
,
int8x16_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x16_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
...
...
@@ -223,6 +211,35 @@ struct DynamicBuffer
}
}
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
void
AtomicAdd
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
{
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
;
static_assert
(
scalar_per_x_vector
%
scalar_per_t_vector
==
0
,
"wrong! X need to be multiple T"
);
static_assert
(
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Global
,
"only support global mem"
);
#if CK_USE_AMD_BUFFER_ADDRESSING
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_atomic_add
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
#else
if
(
is_valid_element
)
{
atomicAdd
(
&
p_data_
[
i
],
x
);
}
#endif
}
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
false
;
}
__host__
__device__
static
constexpr
bool
IsDynamicBuffer
()
{
return
true
;
}
...
...
@@ -234,9 +251,14 @@ __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize el
return
DynamicBuffer
<
BufferAddressSpace
,
T
,
ElementSpaceSize
,
true
>
{
p
,
element_space_size
};
}
template
<
AddressSpaceEnum_t
BufferAddressSpace
,
typename
T
,
typename
ElementSpaceSize
>
template
<
AddressSpaceEnum_t
BufferAddressSpace
,
typename
T
,
typename
ElementSpaceSize
,
typename
X
,
typename
enable_if
<
is_same
<
remove_cvref_t
<
T
>,
remove_cvref_t
<
X
>>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
auto
make_dynamic_buffer
(
T
*
p
,
ElementSpaceSize
element_space_size
,
T
invalid_element_value
)
make_dynamic_buffer
(
T
*
p
,
ElementSpaceSize
element_space_size
,
X
invalid_element_value
)
{
return
DynamicBuffer
<
BufferAddressSpace
,
T
,
ElementSpaceSize
,
false
>
{
p
,
element_space_size
,
invalid_element_value
};
...
...
composable_kernel/include/utility/tuple.hpp
View file @
f74cf520
...
...
@@ -159,7 +159,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
template
<
typename
...
Xs
>
__host__
__device__
constexpr
auto
make_tuple
(
Xs
&&
...
xs
)
{
return
Tuple
<
remove_cv
_t
<
remove_reference
_t
<
Xs
>
>
...
>
(
std
::
forward
<
Xs
>
(
xs
)...);
return
Tuple
<
remove_cv
ref
_t
<
Xs
>
...
>
(
std
::
forward
<
Xs
>
(
xs
)...);
}
}
// namespace ck
...
...
composable_kernel/include/utility/tuple_helper.hpp
View file @
f74cf520
...
...
@@ -14,9 +14,7 @@ struct is_known_at_compile_time<Tuple<Ts...>>
return
container_reduce
(
Tuple
<
Ts
...
>
{},
[](
auto
x
,
bool
r
)
{
return
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
decltype
(
x
)
>>>::
value
&
r
;
return
is_known_at_compile_time
<
remove_cvref_t
<
decltype
(
x
)
>>::
value
&
r
;
},
true
);
}
...
...
composable_kernel/include/utility/type.hpp
View file @
f74cf520
...
...
@@ -22,6 +22,9 @@ using remove_reference_t = typename std::remove_reference<T>::type;
template
<
typename
T
>
using
remove_cv_t
=
typename
std
::
remove_cv
<
T
>::
type
;
template
<
typename
T
>
using
remove_cvref_t
=
remove_cv_t
<
std
::
remove_reference_t
<
T
>>
;
template
<
typename
T
>
inline
constexpr
bool
is_pointer_v
=
std
::
is_pointer
<
T
>::
value
;
...
...
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp
View file @
f74cf520
...
...
@@ -374,13 +374,8 @@ extern "C" __global__ void
CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1
{},
CGridBlockCluster_BlockId_To_GM10_GN10
{}));
const
auto
desc_tuple
=
*
reinterpret_cast
<
const
DescTuple
*>
(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
// TODO: how to cast?
(
const
void
*
)
p_desc_tuple
#pragma clang diagnostic pop
);
const
auto
desc_tuple
=
*
reinterpret_cast
<
const
DescTuple
*>
(
cast_pointer_to_generic_address_space
(
p_desc_tuple
));
const
auto
a_grid_desc_gk0_gm0_gm10_gm11_gk1
=
desc_tuple
[
I0
];
const
auto
b_grid_desc_gk0_gn0_gn10_gn11_gk1
=
desc_tuple
[
I1
];
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment