Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
709f13a6
Commit
709f13a6
authored
Jun 04, 2019
by
Chao Liu
Browse files
use more constexpr
parent
498e71b0
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
292 additions
and
165 deletions
+292
-165
driver/driver.hip.cpp
driver/driver.hip.cpp
+15
-3
src/include/Array.hip.hpp
src/include/Array.hip.hpp
+52
-3
src/include/ConstantMergedTensorDescriptor.hip.hpp
src/include/ConstantMergedTensorDescriptor.hip.hpp
+3
-16
src/include/Sequence.hip.hpp
src/include/Sequence.hip.hpp
+36
-18
src/include/blockwise_generic_tensor_slice_op.hip.hpp
src/include/blockwise_generic_tensor_slice_op.hip.hpp
+13
-0
src/include/common.hip.hpp
src/include/common.hip.hpp
+1
-0
src/include/functional.hip.hpp
src/include/functional.hip.hpp
+9
-45
src/include/functional2.hip.hpp
src/include/functional2.hip.hpp
+47
-73
src/include/functional3.hip.hpp
src/include/functional3.hip.hpp
+109
-0
src/include/gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp
...implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp
+1
-1
src/include/threadwise_generic_tensor_slice_op.hip.hpp
src/include/threadwise_generic_tensor_slice_op.hip.hpp
+6
-6
No files found.
driver/driver.hip.cpp
View file @
709f13a6
...
@@ -443,7 +443,7 @@ int main(int argc, char* argv[])
...
@@ -443,7 +443,7 @@ int main(int argc, char* argv[])
constexpr
index_t
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif
0
#elif
1
// 3x3 filter, 28x28 image
// 3x3 filter, 28x28 image
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
constexpr
index_t
C
=
256
;
...
@@ -455,7 +455,7 @@ int main(int argc, char* argv[])
...
@@ -455,7 +455,7 @@ int main(int argc, char* argv[])
constexpr
index_t
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif
1
#elif
0
// 1x1 filter, 28x28 image
// 1x1 filter, 28x28 image
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
512
;
constexpr
index_t
C
=
512
;
...
@@ -549,12 +549,24 @@ int main(int argc, char* argv[])
...
@@ -549,12 +549,24 @@ int main(int argc, char* argv[])
constexpr
index_t
Y
=
1
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
constexpr
index_t
X
=
1
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif 0
// 1x1 filter, 7x7 image
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
512
;
constexpr
index_t
HI
=
7
;
constexpr
index_t
WI
=
7
;
constexpr
index_t
K
=
2048
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif 0
#elif 0
// 1x1 filter, 73x73 image
// 1x1 filter, 73x73 image
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
C
=
512
;
constexpr
index_t
HI
=
73
;
constexpr
index_t
HI
=
73
;
constexpr
index_t
WI
=
73
;
constexpr
index_t
WI
=
73
;
constexpr
index_t
K
=
128
;
constexpr
index_t
K
=
128
;
...
...
src/include/Array.hip.hpp
View file @
709f13a6
#pragma once
#pragma once
#include "Sequence.hip.hpp"
#include "Sequence.hip.hpp"
#include "functional.hip.hpp"
#include "functional
2
.hip.hpp"
template
<
class
TData
,
index_t
NSize
>
template
<
class
TData
,
index_t
NSize
>
struct
Array
struct
Array
...
@@ -25,14 +25,17 @@ struct Array
...
@@ -25,14 +25,17 @@ struct Array
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
TData
Get
(
Number
<
I
>
)
const
__host__
__device__
constexpr
TData
Get
(
Number
<
I
>
)
const
{
{
static_assert
(
I
<
NSize
,
"wrong!"
);
return
mData
[
I
];
return
mData
[
I
];
}
}
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
bool
Set
(
Number
<
I
>
,
TData
x
)
__host__
__device__
constexpr
void
Set
(
Number
<
I
>
,
TData
x
)
{
{
static_assert
(
I
<
NSize
,
"wrong!"
);
mData
[
I
]
=
x
;
mData
[
I
]
=
x
;
return
true
;
// for constexpr
}
}
__host__
__device__
constexpr
auto
PushBack
(
TData
x
)
const
__host__
__device__
constexpr
auto
PushBack
(
TData
x
)
const
...
@@ -59,6 +62,7 @@ __host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
...
@@ -59,6 +62,7 @@ __host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
template
<
class
TData
,
index_t
NSize
>
template
<
class
TData
,
index_t
NSize
>
__host__
__device__
constexpr
auto
make_zero_array
()
__host__
__device__
constexpr
auto
make_zero_array
()
{
{
#if 0
Array<TData, NSize> a;
Array<TData, NSize> a;
static_for<0, NSize, 1>{}([&](auto I) {
static_for<0, NSize, 1>{}([&](auto I) {
...
@@ -67,6 +71,11 @@ __host__ __device__ constexpr auto make_zero_array()
...
@@ -67,6 +71,11 @@ __host__ __device__ constexpr auto make_zero_array()
});
});
return a;
return a;
#else
constexpr
auto
zero_sequence
=
typename
uniform_sequence_gen
<
NSize
,
0
>::
SeqType
{};
constexpr
auto
zero_array
=
sequence2array
(
zero_sequence
);
return
zero_array
;
#endif
}
}
template
<
class
TData
,
index_t
NSize
,
index_t
...
IRs
>
template
<
class
TData
,
index_t
NSize
,
index_t
...
IRs
>
...
@@ -85,6 +94,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData
...
@@ -85,6 +94,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData
return
new_array
;
return
new_array
;
}
}
#if 0
template <class TData, index_t NSize, index_t... IRs>
template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
Sequence<IRs...> old2new)
Sequence<IRs...> old2new)
...
@@ -100,6 +110,45 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData
...
@@ -100,6 +110,45 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData
return new_array;
return new_array;
}
}
#else
template
<
class
TData
,
index_t
NSize
,
class
MapOld2New
>
struct
reorder_array_given_old2new_impl
{
const
Array
<
TData
,
NSize
>&
old_array_ref
;
Array
<
TData
,
NSize
>&
new_array_ref
;
__host__
__device__
constexpr
reorder_array_given_old2new_impl
(
const
Array
<
TData
,
NSize
>&
old_array
,
Array
<
TData
,
NSize
>&
new_array
)
:
old_array_ref
(
old_array
),
new_array_ref
(
new_array
)
{
}
template
<
index_t
IOldDim
>
__host__
__device__
constexpr
void
operator
()(
Number
<
IOldDim
>
)
const
{
TData
old_data
=
old_array_ref
.
Get
(
Number
<
IOldDim
>
{});
constexpr
index_t
INewDim
=
MapOld2New
::
Get
(
Number
<
IOldDim
>
{});
new_array_ref
.
Set
(
Number
<
INewDim
>
{},
old_data
);
}
};
template
<
class
TData
,
index_t
NSize
,
index_t
...
IRs
>
__host__
__device__
constexpr
auto
reorder_array_given_old2new
(
const
Array
<
TData
,
NSize
>&
old_array
,
Sequence
<
IRs
...
>
old2new
)
{
Array
<
TData
,
NSize
>
new_array
;
static_assert
(
NSize
==
sizeof
...(
IRs
),
"NSize not consistent"
);
static_for
<
0
,
NSize
,
1
>
{}(
reorder_array_given_old2new_impl
<
TData
,
NSize
,
Sequence
<
IRs
...
>>
(
old_array
,
new_array
));
return
new_array
;
}
#endif
template
<
class
TData
,
index_t
NSize
,
class
ExtractSeq
>
template
<
class
TData
,
index_t
NSize
,
class
ExtractSeq
>
__host__
__device__
constexpr
auto
extract_array
(
const
Array
<
TData
,
NSize
>&
old_array
,
ExtractSeq
)
__host__
__device__
constexpr
auto
extract_array
(
const
Array
<
TData
,
NSize
>&
old_array
,
ExtractSeq
)
...
...
src/include/ConstantMergedTensorDescriptor.hip.hpp
View file @
709f13a6
...
@@ -115,15 +115,13 @@ struct ConstantMergedTensorDescriptor
...
@@ -115,15 +115,13 @@ struct ConstantMergedTensorDescriptor
}
}
template
<
index_t
I
>
template
<
index_t
I
>
constexpr
__host__
__device__
bool
operator
()(
Number
<
I
>
)
const
__host__
__device__
constexpr
void
operator
()(
Number
<
I
>
)
const
{
{
constexpr
index_t
idim_original
=
OriginalDimsPartial
::
Get
(
Number
<
I
>
{});
constexpr
index_t
idim_original
=
OriginalDimsPartial
::
Get
(
Number
<
I
>
{});
index_t
itmp
=
original_multi_id_partial_ref
.
Get
(
Number
<
I
>
{});
index_t
itmp
=
original_multi_id_partial_ref
.
Get
(
Number
<
I
>
{});
original_multi_id_ref
.
Set
(
Number
<
idim_original
>
{},
itmp
);
original_multi_id_ref
.
Set
(
Number
<
idim_original
>
{},
itmp
);
return
true
;
}
}
};
};
...
@@ -139,7 +137,7 @@ struct ConstantMergedTensorDescriptor
...
@@ -139,7 +137,7 @@ struct ConstantMergedTensorDescriptor
}
}
template
<
index_t
IDim
>
template
<
index_t
IDim
>
constexpr
__host__
__device__
bool
operator
()(
Number
<
IDim
>
)
const
__host__
__device__
constexpr
void
operator
()(
Number
<
IDim
>
)
const
{
{
constexpr
auto
original_dims_partial
=
constexpr
auto
original_dims_partial
=
std
::
get
<
IDim
>
(
std
::
tuple
<
OriginalDimMergeSeqs
...
>
{});
std
::
get
<
IDim
>
(
std
::
tuple
<
OriginalDimMergeSeqs
...
>
{});
...
@@ -152,11 +150,10 @@ struct ConstantMergedTensorDescriptor
...
@@ -152,11 +150,10 @@ struct ConstantMergedTensorDescriptor
static_for
<
0
,
original_dims_partial
.
GetSize
(),
1
>
{}(
static_for
<
0
,
original_dims_partial
.
GetSize
(),
1
>
{}(
GetOriginalMultiIndexFromMultiIndex_impl1
<
decltype
(
original_dims_partial
)
>
(
GetOriginalMultiIndexFromMultiIndex_impl1
<
decltype
(
original_dims_partial
)
>
(
original_multi_id_partial
,
original_multi_id_ref
));
original_multi_id_partial
,
original_multi_id_ref
));
return
true
;
}
}
};
};
// return type is Array<...>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
GetOriginalMultiIndexFromMultiIndex
(
Array
<
index_t
,
nDim
>
multi_id
)
GetOriginalMultiIndexFromMultiIndex
(
Array
<
index_t
,
nDim
>
multi_id
)
{
{
...
@@ -179,16 +176,6 @@ struct ConstantMergedTensorDescriptor
...
@@ -179,16 +176,6 @@ struct ConstantMergedTensorDescriptor
}
}
#endif
#endif
#if 0
// return type is Sequence<...>
template <index_t... Is>
__host__ __device__ static constexpr auto GetOriginalMultiIndexFromMultiIndex(Sequence<Is...>)
{
// not implemented
return Sequence<>{};
}
#endif
__host__
__device__
static
constexpr
index_t
__host__
__device__
static
constexpr
index_t
GetOffsetFromMultiIndex
(
Array
<
index_t
,
nDim
>
multi_id
)
GetOffsetFromMultiIndex
(
Array
<
index_t
,
nDim
>
multi_id
)
{
{
...
...
src/include/Sequence.hip.hpp
View file @
709f13a6
...
@@ -37,10 +37,11 @@ struct Sequence
...
@@ -37,10 +37,11 @@ struct Sequence
template <class MapOld2New>
template <class MapOld2New>
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New /*old2new*/)
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New /*old2new*/)
{
{
#if 0
static_assert(is_same<sequence_sort<MapOld2New>::SortedSeqType,
static_assert(is_same<sequence_sort<MapOld2New>::SortedSeqType,
arithmetic_sequence_gen<0, mSize, 1>::SeqType>::value,
arithmetic_sequence_gen<0, mSize, 1>::SeqType>::value,
"wrong! invalid old2new map");
"wrong! invalid old2new map");
#endif
constexpr auto map_new2old = typename sequence_map_inverse<MapOld2New>::SeqMapType{};
constexpr auto map_new2old = typename sequence_map_inverse<MapOld2New>::SeqMapType{};
return ReorderGivenNew2Old(map_new2old);
return ReorderGivenNew2Old(map_new2old);
...
@@ -99,6 +100,7 @@ struct Sequence
...
@@ -99,6 +100,7 @@ struct Sequence
__host__
__device__
static
constexpr
auto
Modify
(
Number
<
I
>
,
Number
<
X
>
);
__host__
__device__
static
constexpr
auto
Modify
(
Number
<
I
>
,
Number
<
X
>
);
};
};
// merge sequence
template
<
class
,
class
>
template
<
class
,
class
>
struct
sequence_merge
;
struct
sequence_merge
;
...
@@ -108,6 +110,7 @@ struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
...
@@ -108,6 +110,7 @@ struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
using
SeqType
=
Sequence
<
Xs
...,
Ys
...
>
;
using
SeqType
=
Sequence
<
Xs
...,
Ys
...
>
;
};
};
// arithmetic sqeuence
template
<
index_t
IBegin
,
index_t
NSize
,
index_t
Increment
>
template
<
index_t
IBegin
,
index_t
NSize
,
index_t
Increment
>
struct
arithmetic_sequence_gen_impl
struct
arithmetic_sequence_gen_impl
{
{
...
@@ -139,7 +142,31 @@ struct arithmetic_sequence_gen
...
@@ -139,7 +142,31 @@ struct arithmetic_sequence_gen
typename
arithmetic_sequence_gen_impl
<
IBegin
,
IEnd
-
IBegin
,
Increment
>::
SeqType
;
typename
arithmetic_sequence_gen_impl
<
IBegin
,
IEnd
-
IBegin
,
Increment
>::
SeqType
;
};
};
// reverse scan with init
// transform sequence
template
<
class
,
class
>
struct
sequence_transform
;
template
<
class
F
,
index_t
...
Is
>
struct
sequence_transform
<
F
,
Sequence
<
Is
...
>>
{
using
SeqType
=
Sequence
<
F
{}(
Is
)...
>
;
};
// uniform sequence
template
<
index_t
NSize
,
index_t
I
>
struct
uniform_sequence_gen
{
struct
return_constant
{
__host__
__device__
constexpr
index_t
operator
()(
index_t
)
const
{
return
I
;
}
};
using
SeqType
=
typename
sequence_transform
<
return_constant
,
typename
arithmetic_sequence_gen
<
0
,
NSize
,
1
>::
SeqType
>::
SeqType
;
};
// reverse inclusive scan (with init) sequence
template
<
class
,
class
,
index_t
>
template
<
class
,
class
,
index_t
>
struct
sequence_reverse_inclusive_scan
;
struct
sequence_reverse_inclusive_scan
;
...
@@ -166,22 +193,7 @@ struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
...
@@ -166,22 +193,7 @@ struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
using
SeqType
=
Sequence
<>
;
using
SeqType
=
Sequence
<>
;
};
};
#if 0
// extract sequence
// reverse scan with token
template <class, class, index_t>
struct sequence_reverse_inclusive_token_scan;
template <index_t I, index_t... Is, class F, index_t Token>
struct sequence_reverse_inclusive_token_scan<Sequence<I, Is...>, F, Token>
{
using old_scan = typename sequence_reverse_inclusive_token_scan<Sequence<Is...>, F, Token>::SeqType;
static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());
using SeqType = typename sequence_merge<Sequence<new_reduce>, old_scan>::SeqType;
};
#endif
template
<
class
,
class
>
template
<
class
,
class
>
struct
sequence_extract
;
struct
sequence_extract
;
...
@@ -191,6 +203,7 @@ struct sequence_extract<Seq, Sequence<Is...>>
...
@@ -191,6 +203,7 @@ struct sequence_extract<Seq, Sequence<Is...>>
using
SeqType
=
Sequence
<
Seq
{}.
Get
(
Number
<
Is
>
{})...
>
;
using
SeqType
=
Sequence
<
Seq
{}.
Get
(
Number
<
Is
>
{})...
>
;
};
};
// split sequence
template
<
class
Seq
,
index_t
I
>
template
<
class
Seq
,
index_t
I
>
struct
sequence_split
struct
sequence_split
{
{
...
@@ -203,6 +216,7 @@ struct sequence_split
...
@@ -203,6 +216,7 @@ struct sequence_split
using
SeqType1
=
typename
sequence_extract
<
Seq
,
range1
>::
SeqType
;
using
SeqType1
=
typename
sequence_extract
<
Seq
,
range1
>::
SeqType
;
};
};
// reverse sequence
template
<
class
Seq
>
template
<
class
Seq
>
struct
sequence_reverse
struct
sequence_reverse
{
{
...
@@ -308,8 +322,10 @@ __host__ __device__ constexpr auto operator-(Sequence<Xs...> seq_x, Sequence<Ys.
...
@@ -308,8 +322,10 @@ __host__ __device__ constexpr auto operator-(Sequence<Xs...> seq_x, Sequence<Ys.
{
{
static_assert
(
sizeof
...(
Xs
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
static_assert
(
sizeof
...(
Xs
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
#if 0
static_for<0, seq_x.GetSize(), 1>{}(
static_for<0, seq_x.GetSize(), 1>{}(
[&](auto I) { static_assert(seq_x.Get(I) >= seq_y.Get(I), "wrong! going to undeflow"); });
[&](auto I) { static_assert(seq_x.Get(I) >= seq_y.Get(I), "wrong! going to undeflow"); });
#endif
return
Sequence
<
(
Xs
-
Ys
)...
>
{};
return
Sequence
<
(
Xs
-
Ys
)...
>
{};
}
}
...
@@ -388,10 +404,12 @@ __host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
...
@@ -388,10 +404,12 @@ __host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{
{
constexpr
auto
seq_x
=
Sequence
<
Xs
...
>
{};
constexpr
auto
seq_x
=
Sequence
<
Xs
...
>
{};
#if 0
static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) {
static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) {
constexpr auto I = decltype(Iter){};
constexpr auto I = decltype(Iter){};
static_assert(seq_x.Get(I) <= Y, "wrong! going to underflow");
static_assert(seq_x.Get(I) <= Y, "wrong! going to underflow");
});
});
#endif
return
Sequence
<
(
Y
-
Xs
)...
>
{};
return
Sequence
<
(
Y
-
Xs
)...
>
{};
}
}
...
...
src/include/blockwise_generic_tensor_slice_op.hip.hpp
View file @
709f13a6
...
@@ -256,6 +256,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
...
@@ -256,6 +256,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
make_ConstantTensorDescriptor_packed
(
thread_sub_tensor_lengths
*
repeat_lengths
);
make_ConstantTensorDescriptor_packed
(
thread_sub_tensor_lengths
*
repeat_lengths
);
static_ford
<
decltype
(
repeat_lengths
)
>
{}([
&
](
auto
repeat_multi_id_
)
{
static_ford
<
decltype
(
repeat_lengths
)
>
{}([
&
](
auto
repeat_multi_id_
)
{
#if 0
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
const auto clipboard_data_multi_id_begin =
const auto clipboard_data_multi_id_begin =
...
@@ -269,6 +270,18 @@ struct BlockwiseGenericTensorSliceCopy_v1
...
@@ -269,6 +270,18 @@ struct BlockwiseGenericTensorSliceCopy_v1
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(
dst_data_multi_id_begin); // cannot not constexpr, why?
dst_data_multi_id_begin); // cannot not constexpr, why?
#else
constexpr
auto
clipboard_data_multi_id_begin
=
repeat_multi_id_
*
thread_sub_tensor_lengths
;
constexpr
auto
dst_data_multi_id_begin
=
repeat_multi_id_
*
data_per_cluster_per_dims
;
constexpr
index_t
clipboard_offset
=
thread_tensor_desc
.
GetOffsetFromMultiIndex
(
clipboard_data_multi_id_begin
);
constexpr
index_t
dst_offset
=
DstDesc
{}.
GetOffsetFromMultiIndex
(
dst_data_multi_id_begin
);
#endif
threadwise_generic_tensor_slice_copy_v1
(
thread_tensor_desc
,
threadwise_generic_tensor_slice_copy_v1
(
thread_tensor_desc
,
p_clipboard
+
clipboard_offset
,
p_clipboard
+
clipboard_offset
,
...
...
src/include/common.hip.hpp
View file @
709f13a6
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "Array.hip.hpp"
#include "Array.hip.hpp"
#include "functional.hip.hpp"
#include "functional.hip.hpp"
#include "functional2.hip.hpp"
#include "functional2.hip.hpp"
#include "functional3.hip.hpp"
#if USE_AMD_INLINE_ASM
#if USE_AMD_INLINE_ASM
#include "amd_inline_asm.hip.hpp"
#include "amd_inline_asm.hip.hpp"
...
...
src/include/functional.hip.hpp
View file @
709f13a6
#pragma once
#pragma once
#include "integral_constant.hip.hpp"
#include "integral_constant.hip.hpp"
#include "Sequence.hip.hpp"
struct
forwarder
struct
forwarder
{
{
...
@@ -10,6 +11,14 @@ struct forwarder
...
@@ -10,6 +11,14 @@ struct forwarder
}
}
};
};
struct
swallow
{
template
<
class
...
Ts
>
__host__
__device__
constexpr
swallow
(
Ts
&&
...
ts
)
{
}
};
#if 0
#if 0
template<class F>
template<class F>
__host__ __device__ constexpr auto unpacker(F f)
__host__ __device__ constexpr auto unpacker(F f)
...
@@ -72,51 +81,6 @@ struct static_if<false>
...
@@ -72,51 +81,6 @@ struct static_if<false>
return
Type
{};
return
Type
{};
}
}
};
};
template
<
index_t
Iter
,
index_t
Remaining
,
index_t
Increment
>
struct
static_for_impl
{
template
<
class
F
>
constexpr
__host__
__device__
void
operator
()(
F
f
)
const
{
static_assert
(
Remaining
%
Increment
==
0
,
"wrong! Remaining % Increment != 0"
);
static_assert
(
Increment
<=
Remaining
,
"will go out-of-range"
);
f
(
Number
<
Iter
>
{});
static_for_impl
<
Iter
+
Increment
,
Remaining
-
Increment
,
Increment
>
{}(
f
);
}
};
template
<
index_t
Iter
,
index_t
Increment
>
struct
static_for_impl
<
Iter
,
0
,
Increment
>
{
template
<
class
F
>
constexpr
__host__
__device__
void
operator
()(
F
)
const
{
// no work left, just return
return
;
}
};
// F signature: F(Number<Iter>)
template
<
index_t
NBegin
,
index_t
NEnd
,
index_t
Increment
>
struct
static_for
{
template
<
class
F
>
constexpr
__host__
__device__
void
operator
()(
F
f
)
const
{
static_assert
(
NBegin
<=
NEnd
,
"wrongs! should have NBegin <= NEnd"
);
static_assert
((
NEnd
-
NBegin
)
%
Increment
==
0
,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0"
);
#if 0
static_if<(NBegin < NEnd)>{}(
[&](auto fwd) { static_for_impl<NBegin, NEnd - NBegin, fwd(Increment)>{}(f); });
#else
static_for_impl
<
NBegin
,
NEnd
-
NBegin
,
Increment
>
{}(
f
);
#endif
}
};
template
<
index_t
NLoop
>
template
<
index_t
NLoop
>
struct
static_const_reduce_n
struct
static_const_reduce_n
...
...
src/include/functional2.hip.hpp
View file @
709f13a6
#pragma once
#pragma once
#include "functional.hip.hpp"
#include "Sequence.hip.hpp"
#include "Sequence.hip.hpp"
// RemainLengths: Sequence<...>
#if 0
template
<
class
RemainLengths
>
template <
index_t Iter, index_t Remaining, index_t Increment
>
struct
static_for
d
_impl
struct static_for_impl
{
{
// F signature: F(Sequence<...> multi_id)
template <class F>
// CurrentMultiIndex: Sequence<...>
constexpr __host__ __device__ void operator()(F f) const
template
<
class
F
,
class
CurrentMultiIndex
>
__host__
__device__
void
operator
()(
F
f
,
CurrentMultiIndex
)
const
{
{
static_assert
(
RemainLengths
::
GetSize
()
>
0
,
"wrong! should not get here"
);
static_assert(Remaining % Increment == 0, "wrong! Remaining % Increment != 0");
static_assert(Increment <= Remaining, "will go out-of-range");
static_for
<
0
,
RemainLengths
::
Front
(),
1
>
{}([
=
](
auto
I
)
{
static_ford_impl
<
decltype
(
RemainLengths
::
PopFront
())
>
{}(
f
,
CurrentMultiIndex
::
PushBack
(
I
));
});
}
};
template
<
>
f(Number<Iter>{});
struct
static_ford_impl
<
Sequence
<>>
static_for_impl<Iter + Increment, Remaining - Increment, Increment>{}(f);
{
// F signature: F(Sequence<...> multi_id)
// CurrentMultiIndex: Sequence<...>
template
<
class
F
,
class
CurrentMultiIndex
>
__host__
__device__
void
operator
()(
F
f
,
CurrentMultiIndex
)
const
{
f
(
CurrentMultiIndex
{});
}
}
};
};
// Lengths is Sequence<...>
template <index_t Iter, index_t Increment>
template
<
class
Lengths
>
struct static_for_impl<Iter, 0, Increment>
struct
static_ford
{
{
// F signature: F(Sequence<...> multi_id)
template <class F>
template <class F>
__host__
__device__
void
operator
()(
F
f
)
const
constexpr
__host__ __device__ void operator()(F) const
{
{
static_assert
(
Lengths
::
GetSize
()
>
0
,
"wrong! Lengths is empty"
);
// no work left, just return
return;
static_ford_impl
<
Lengths
>
{}(
f
,
Sequence
<>
{});
}
}
};
};
template
<
index_t
RemainDim
>
// F signature: F(Number<Iter>)
struct
ford_impl
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
{
{
// F signature: F(Array<...> multi_id)
template <class F>
// CurrentMultiIndex: Array<...>
constexpr __host__ __device__ void operator()(F f) const
// RemainLengths: Sequence<...>
template
<
class
F
,
class
CurrentMultiIndex
,
class
RemainLengths
>
__host__
__device__
void
operator
()(
F
f
,
CurrentMultiIndex
current_multi_id
,
RemainLengths
)
const
{
{
static_assert
(
RemainLengths
::
GetSize
()
==
RemainDim
,
"wrong!"
);
static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd");
static_assert
(
RemainDim
>
1
,
"wrong!"
);
constexpr
auto
next_length
=
RemainLengths
{}.
Front
();
static_assert((NEnd - NBegin) % Increment == 0,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
for
(
index_t
i
=
0
;
i
<
next_length
;
++
i
)
#if 0
{
static_if<(NBegin < NEnd)>{}(
ford_impl
<
RemainDim
-
1
>
{}(
f
,
current_multi_id
.
PushBack
(
i
),
RemainLengths
{}.
PopFront
());
[&](auto fwd) { static_for_impl<NBegin, NEnd - NBegin, fwd(Increment)>{}(f); });
}
#else
static_for_impl<NBegin, NEnd - NBegin, Increment>{}(f);
#endif
}
}
}
;
}
;
#else
template
<
class
>
struct
static_for_impl
;
template
<
>
template
<
index_t
...
Is
>
struct
for
d
_impl
<
1
>
struct
static_
for_impl
<
Sequence
<
Is
...
>
>
{
{
// F signature: F(Array<...> multi_id)
template
<
class
F
>
// CurrentMultiIndex: Array<...>
__host__
__device__
constexpr
void
operator
()(
F
f
)
const
// RemainLengths: Sequence<...>
template
<
class
F
,
class
CurrentMultiIndex
,
class
RemainLengths
>
__host__
__device__
void
operator
()(
F
f
,
CurrentMultiIndex
current_multi_id
,
RemainLengths
)
const
{
{
static_assert
(
RemainLengths
::
GetSize
()
==
1
,
"wrong!"
);
swallow
{(
f
(
Number
<
Is
>
{}),
0
)...};
constexpr
index_t
last_length
=
RemainLengths
{}.
Front
();
for
(
index_t
i
=
0
;
i
<
last_length
;
++
i
)
{
f
(
current_multi_id
.
PushBack
(
i
));
}
}
}
};
};
//
Lengths is Sequence<...>
//
F signature: F(Number<Iter>)
template
<
class
Lengths
>
template
<
index_t
NBegin
,
index_t
NEnd
,
index_t
Increment
>
struct
for
d
struct
static_
for
{
{
// F signature: F(Array<...> multi_id)
template
<
class
F
>
template
<
class
F
>
__host__
__device__
void
operator
()(
F
f
)
const
__host__
__device__
constexpr
void
operator
()(
F
f
)
const
{
{
constexpr
index_t
first_length
=
Lengths
{}.
Front
();
static_assert
(
NBegin
<=
NEnd
,
"wrongs! should have NBegin <= NEnd"
);
static_assert
((
NEnd
-
NBegin
)
%
Increment
==
0
,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0"
);
for
(
index_t
i
=
0
;
i
<
first_length
;
++
i
)
static_for_impl
<
typename
arithmetic_sequence_gen
<
NBegin
,
NEnd
,
Increment
>::
SeqType
>
{}(
f
);
{
ford_impl
<
Lengths
::
GetSize
()
-
1
>
{}(
f
,
Array
<
index_t
,
1
>
{
i
},
Lengths
{}.
PopFront
());
}
}
}
};
};
#endif
src/include/functional3.hip.hpp
0 → 100644
View file @
709f13a6
#pragma once
#include "functional.hip.hpp"
#include "functional2.hip.hpp"
#include "Sequence.hip.hpp"
#include "Array.hip.hpp"
// RemainLengths: Sequence<...>
template
<
class
RemainLengths
>
struct
static_ford_impl
{
// F signature: F(Sequence<...> multi_id)
// CurrentMultiIndex: Sequence<...>
template
<
class
F
,
class
CurrentMultiIndex
>
__host__
__device__
void
operator
()(
F
f
,
CurrentMultiIndex
)
const
{
static_assert
(
RemainLengths
::
GetSize
()
>
0
,
"wrong! should not get here"
);
static_for
<
0
,
RemainLengths
::
Front
(),
1
>
{}([
=
](
auto
I
)
{
static_ford_impl
<
decltype
(
RemainLengths
::
PopFront
())
>
{}(
f
,
CurrentMultiIndex
::
PushBack
(
I
));
});
}
};
template
<
>
struct
static_ford_impl
<
Sequence
<>>
{
// F signature: F(Sequence<...> multi_id)
// CurrentMultiIndex: Sequence<...>
template
<
class
F
,
class
CurrentMultiIndex
>
__host__
__device__
void
operator
()(
F
f
,
CurrentMultiIndex
)
const
{
f
(
CurrentMultiIndex
{});
}
};
// Lengths is Sequence<...>
template
<
class
Lengths
>
struct
static_ford
{
// F signature: F(Sequence<...> multi_id)
template
<
class
F
>
__host__
__device__
void
operator
()(
F
f
)
const
{
static_assert
(
Lengths
::
GetSize
()
>
0
,
"wrong! Lengths is empty"
);
static_ford_impl
<
Lengths
>
{}(
f
,
Sequence
<>
{});
}
};
template
<
index_t
RemainDim
>
struct
ford_impl
{
// F signature: F(Array<...> multi_id)
// CurrentMultiIndex: Array<...>
// RemainLengths: Sequence<...>
template
<
class
F
,
class
CurrentMultiIndex
,
class
RemainLengths
>
__host__
__device__
void
operator
()(
F
f
,
CurrentMultiIndex
current_multi_id
,
RemainLengths
)
const
{
static_assert
(
RemainLengths
::
GetSize
()
==
RemainDim
,
"wrong!"
);
static_assert
(
RemainDim
>
1
,
"wrong!"
);
constexpr
auto
next_length
=
RemainLengths
{}.
Front
();
for
(
index_t
i
=
0
;
i
<
next_length
;
++
i
)
{
ford_impl
<
RemainDim
-
1
>
{}(
f
,
current_multi_id
.
PushBack
(
i
),
RemainLengths
{}.
PopFront
());
}
}
};
template
<
>
struct
ford_impl
<
1
>
{
// F signature: F(Array<...> multi_id)
// CurrentMultiIndex: Array<...>
// RemainLengths: Sequence<...>
template
<
class
F
,
class
CurrentMultiIndex
,
class
RemainLengths
>
__host__
__device__
void
operator
()(
F
f
,
CurrentMultiIndex
current_multi_id
,
RemainLengths
)
const
{
static_assert
(
RemainLengths
::
GetSize
()
==
1
,
"wrong!"
);
constexpr
index_t
last_length
=
RemainLengths
{}.
Front
();
for
(
index_t
i
=
0
;
i
<
last_length
;
++
i
)
{
f
(
current_multi_id
.
PushBack
(
i
));
}
}
};
// Lengths is Sequence<...>
template
<
class
Lengths
>
struct
ford
{
// F signature: F(Array<...> multi_id)
template
<
class
F
>
__host__
__device__
void
operator
()(
F
f
)
const
{
constexpr
index_t
first_length
=
Lengths
{}.
Front
();
for
(
index_t
i
=
0
;
i
<
first_length
;
++
i
)
{
ford_impl
<
Lengths
::
GetSize
()
-
1
>
{}(
f
,
Array
<
index_t
,
1
>
{
i
},
Lengths
{}.
PopFront
());
}
}
};
src/include/gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp
View file @
709f13a6
...
@@ -246,7 +246,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
...
@@ -246,7 +246,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// choose GEMM implementation here
// choose GEMM implementation here
const
auto
run_blockwise_gemm
=
[
&
](
auto
...
Xs
)
{
const
auto
run_blockwise_gemm
=
[
&
](
auto
...
Xs
)
{
#if
0
#if
1
return
blockwise_gemm
.
Run
(
Xs
...);
return
blockwise_gemm
.
Run
(
Xs
...);
#else
#else
return
blockwise_gemm
.
Run_asm
(
Xs
...);
return
blockwise_gemm
.
Run_asm
(
Xs
...);
...
...
src/include/threadwise_generic_tensor_slice_op.hip.hpp
View file @
709f13a6
...
@@ -77,14 +77,14 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
...
@@ -77,14 +77,14 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
*
reinterpret_cast
<
const
vector_t
*>
(
&
p_src
[
src_index
]);
*
reinterpret_cast
<
const
vector_t
*>
(
&
p_src
[
src_index
]);
});
});
#else
#else
static_ford
<
decltype
(
access_lengths
)
>
{}([
&
](
auto
access_multi_id
_
)
{
static_ford
<
decltype
(
access_lengths
)
>
{}([
&
](
auto
access_multi_id
)
{
const
auto
access_multi_id
=
sequence2array
(
access_multi_id_
)
;
const
expr
index_t
itmp
=
access_multi_id
.
Back
()
*
DataPerAccess
;
auto
data_multi_id_in_access_order
=
access_multi_id
;
constexpr
auto
data_multi_id_in_access_order
=
data_multi_id_in_access_order
[
nDim
-
1
]
=
access_multi_id
[
nDim
-
1
]
*
DataPerAccess
;
access_multi_id
.
Modify
(
Number
<
nDim
-
1
>
{},
Number
<
itmp
>
{})
;
const
auto
data_multi_id
=
const
expr
auto
data_multi_id
=
reorder_array_given_old2new
(
reorder_array_given_old2new
(
data_multi_id_in_access_order
,
DimAccessOrder
{});
sequence2array
(
data_multi_id_in_access_order
)
,
DimAccessOrder
{});
const
index_t
src_index
=
const
index_t
src_index
=
SrcDesc
::
GetOffsetFromMultiIndex
(
src_multi_id_begin
+
data_multi_id
);
SrcDesc
::
GetOffsetFromMultiIndex
(
src_multi_id_begin
+
data_multi_id
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment