Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
7a7fe160
Commit
7a7fe160
authored
Sep 09, 2019
by
Chao Liu
Browse files
more utility code
parent
625838de
Changes
21
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
809 additions
and
440 deletions
+809
-440
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp
..._convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp
+54
-37
composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp
...l/include/tensor_description/ConstantTensorDescriptor.hpp
+2
-2
composable_kernel/include/tensor_description/multi_index_transform.hpp
...rnel/include/tensor_description/multi_index_transform.hpp
+4
-4
composable_kernel/include/tensor_description/tensor_descriptor.hpp
...e_kernel/include/tensor_description/tensor_descriptor.hpp
+100
-62
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp
...l/include/tensor_description/tensor_descriptor_helper.hpp
+1
-1
composable_kernel/include/utility/Array.hpp
composable_kernel/include/utility/Array.hpp
+116
-143
composable_kernel/include/utility/Sequence.hpp
composable_kernel/include/utility/Sequence.hpp
+162
-98
composable_kernel/include/utility/array_helper.hpp
composable_kernel/include/utility/array_helper.hpp
+93
-0
composable_kernel/include/utility/common_header.hpp
composable_kernel/include/utility/common_header.hpp
+5
-0
composable_kernel/include/utility/functional.hpp
composable_kernel/include/utility/functional.hpp
+25
-5
composable_kernel/include/utility/functional2.hpp
composable_kernel/include/utility/functional2.hpp
+6
-1
composable_kernel/include/utility/functional3.hpp
composable_kernel/include/utility/functional3.hpp
+28
-38
composable_kernel/include/utility/functional4.hpp
composable_kernel/include/utility/functional4.hpp
+34
-0
composable_kernel/include/utility/integral_constant.hpp
composable_kernel/include/utility/integral_constant.hpp
+0
-49
composable_kernel/include/utility/math.hpp
composable_kernel/include/utility/math.hpp
+12
-0
composable_kernel/include/utility/number.hpp
composable_kernel/include/utility/number.hpp
+44
-0
composable_kernel/include/utility/sequence_helper.hpp
composable_kernel/include/utility/sequence_helper.hpp
+46
-0
composable_kernel/include/utility/tuple.hpp
composable_kernel/include/utility/tuple.hpp
+32
-0
composable_kernel/include/utility/type.hpp
composable_kernel/include/utility/type.hpp
+41
-0
driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp
...ce_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp
+4
-0
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp
View file @
7a7fe160
...
@@ -47,6 +47,19 @@ template <index_t GridSize,
...
@@ -47,6 +47,19 @@ template <index_t GridSize,
index_t
OutThreadCopyDataPerAccess_N
>
index_t
OutThreadCopyDataPerAccess_N
>
struct
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
struct
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
I8
=
Number
<
8
>
{};
static
constexpr
auto
I9
=
Number
<
9
>
{};
static
constexpr
auto
I10
=
Number
<
10
>
{};
static
constexpr
auto
I11
=
Number
<
11
>
{};
#if 0
#if 0
__device__ void Run(const Float* const __restrict__ p_in_global,
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
const Float* const __restrict__ p_wei_global,
...
@@ -60,11 +73,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
...
@@ -60,11 +73,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
GemmNPerThreadSubC % NPerThread == 0)),
GemmNPerThreadSubC % NPerThread == 0)),
"wrong!");
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{};
constexpr auto False = integral_constant<bool, false>{};
...
@@ -487,58 +495,67 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
...
@@ -487,58 +495,67 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
Float
*
const
__restrict__
p_out_global
)
const
Float
*
const
__restrict__
p_out_global
)
const
{
{
#if 0
#if 0
constexpr auto tmp = std::tuple<bool>{};
constexpr auto a = make_tuple(true, Sequence<1>{}, index_t(99));
constexpr auto flag = std::get<0>(tmp);
#else
constexpr
auto
a
=
Tuple
<
bool
,
Sequence
<
1
>
,
index_t
>
(
true
,
Sequence
<
1
>
{},
99
);
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
{
printf
(
"
adsas
%d
\n
"
,
a
.
At
(
Number
<
0
>
{}
));
printf("
[0]
%d\n", a.At(
I0
));
print_Sequence
(
"
seq
"
,
a
.
At
(
Number
<
1
>
{}
));
print_Sequence("
[1]
", a.At(
I1
));
printf
(
"
adsas
%lu
\n
"
,
a
.
At
(
Number
<
2
>
{}
));
printf("
[2]
%lu\n", a.At(
I2
));
}
}
auto
b
=
Tuple
<
bool
,
Sequence
<
1
>
,
index_t
>
(
true
,
Sequence
<
1
>
{},
99
)
;
bool flag = true
;
b
.
At
(
Number
<
0
>
{})
=
false
;
auto b = make_tuple(flag, Sequence<1>{}, 99);
b.At(I0) = false;
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
{
printf
(
"adsas %d
\n
"
,
b
.
At
(
Number
<
0
>
{}));
printf("[0] %d\n", b.At(I0));
print_Sequence
(
"seq"
,
b
.
At
(
Number
<
1
>
{}));
print_Sequence("[1]", b.At(I1));
printf
(
"adsas %lu
\n
"
,
b
.
At
(
Number
<
2
>
{}));
printf("[2] %lu\n", b.At(I2));
printf("flag %d\n", flag);
}
}
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
{
printf
(
"adsas %d
\n
"
,
printf("[0] %d\n", make_tuple(true, Sequence<1>(), index_t(99)).At(I0));
Tuple
<
bool
,
Sequence
<
1
>
,
index_t
>
(
true
,
Sequence
<
1
>
(),
99
).
At
(
Number
<
0
>
{}));
print_Sequence("[1]", make_tuple(true, Sequence<1>(), index_t(99)).At(I1));
print_Sequence
(
printf("[2] %d\n", make_tuple(true, Sequence<1>(), index_t(99)).At(I2));
"seq"
,
Tuple
<
bool
,
Sequence
<
1
>
,
index_t
>
(
true
,
Sequence
<
1
>
(),
99
).
At
(
Number
<
1
>
{}));
printf
(
"adsas %d
\n
"
,
Tuple
<
bool
,
Sequence
<
1
>
,
index_t
>
(
true
,
Sequence
<
1
>
(),
99
).
At
(
Number
<
2
>
{}));
}
}
#endif
#elif
1
#if 0
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
// create a native tensor descriptor
// create a native tensor descriptor
constexpr auto in_
n_
c_h_w_global_desc =
constexpr
auto
in_c_h_w_
n_
global_desc
=
make_NativeTensorDescriptor
(
InGlobalDesc
::
GetLengths
(),
InGlobalDesc
::
GetStrides
());
make_NativeTensorDescriptor
(
InGlobalDesc
::
GetLengths
(),
InGlobalDesc
::
GetStrides
());
constexpr
index_t
C
=
in_c_h_w_n_global_desc
.
GetLength
(
I0
);
constexpr
index_t
Hi
=
in_c_h_w_n_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Wi
=
in_c_h_w_n_global_desc
.
GetLength
(
I2
);
constexpr
index_t
N
=
in_c_h_w_n_global_desc
.
GetLength
(
I3
);
constexpr
auto
pad_h_w
=
Pad
<
Sequence
<
Hi
,
Wi
>
,
LowerPads
,
UpperPads
>
{};
constexpr
auto
pass_c
=
PassThrough
<
C
>
{};
constexpr
auto
pass_n
=
PassThrough
<
N
>
{};
constexpr
auto
trans
=
make_tuple
(
pass_c
,
pad_h_w
,
pass_n
);
constexpr
auto
lower_dim_groups
=
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{});
constexpr
auto
upper_dim_groups
=
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{});
constexpr
auto
in_c_h_w_n_padded_global_desc
=
transform_tensor_descriptor
(
in_c_h_w_n_global_desc
,
trans
,
lower_dim_groups
,
upper_dim_groups
);
if
(
get_thread_local_1d_id
()
==
0
&&
get_block_1d_id
()
==
0
)
if
(
get_thread_local_1d_id
()
==
0
&&
get_block_1d_id
()
==
0
)
{
{
print_tensor_descriptor("in_n_c_h_w_global_desc", in_n_c_h_w_global_desc);
print_tensor_descriptor
(
"in_c_h_w_n_global_desc"
,
in_c_h_w_n_global_desc
);
}
// transform the tensor descriptor once
printf
(
"offset: %lu
\n
"
,
in_c_h_w_n_global_desc
.
GetOffset
({
1
,
2
,
3
,
4
}));
//
// calculate the offset of some entry
printf
(
"padded offset: %lu
\n
"
,
in_c_h_w_n_padded_global_desc
.
GetOffset
({
1
,
4
,
5
,
4
}));
}
#endif
#endif
}
}
#endif
#endif
...
...
composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp
View file @
7a7fe160
...
@@ -178,7 +178,7 @@ struct ConstantTensorDescriptor
...
@@ -178,7 +178,7 @@ struct ConstantTensorDescriptor
{
{
constexpr
auto
IDim
=
IDim_
{};
constexpr
auto
IDim
=
IDim_
{};
constexpr
index_t
stride
=
PackedStrides
::
Get
(
IDim
);
constexpr
index_t
stride
=
PackedStrides
::
Get
(
IDim
);
multi_id
.
Set
(
IDim
,
id
/
stride
)
;
multi_id
(
IDim
)
=
id
/
stride
;
id
-=
multi_id
[
IDim
]
*
stride
;
id
-=
multi_id
[
IDim
]
*
stride
;
}
}
};
};
...
@@ -192,7 +192,7 @@ struct ConstantTensorDescriptor
...
@@ -192,7 +192,7 @@ struct ConstantTensorDescriptor
// calculate index in each of the dimensions in the order of their dimension
// calculate index in each of the dimensions in the order of their dimension
static_for
<
0
,
nDim
-
1
,
1
>
{}(
lambda_GetMultiIndexFrom1dIndex
<
PackedStrides
>
(
id
,
multi_id
));
static_for
<
0
,
nDim
-
1
,
1
>
{}(
lambda_GetMultiIndexFrom1dIndex
<
PackedStrides
>
(
id
,
multi_id
));
multi_id
.
Set
(
Number
<
nDim
-
1
>
{}
,
id
/
PackedStrides
::
Get
(
Number
<
nDim
-
1
>
{})
)
;
multi_id
(
Number
<
nDim
-
1
>
{}
)
=
id
/
PackedStrides
::
Get
(
Number
<
nDim
-
1
>
{});
return
multi_id
;
return
multi_id
;
}
}
...
...
composable_kernel/include/tensor_description/multi_index_transform.hpp
View file @
7a7fe160
...
@@ -33,7 +33,7 @@ struct PassThrough
...
@@ -33,7 +33,7 @@ struct PassThrough
};
};
// LowLengths: Sequence<...>
// LowLengths: Sequence<...>
template
<
class
LowLengths
,
class
LeftPads
,
class
RightPads
>
template
<
typename
LowLengths
,
typename
LeftPads
,
typename
RightPads
>
struct
Pad
struct
Pad
{
{
static
constexpr
index_t
nDim
=
LowLengths
::
GetSize
();
static
constexpr
index_t
nDim
=
LowLengths
::
GetSize
();
...
@@ -67,7 +67,7 @@ struct Pad
...
@@ -67,7 +67,7 @@ struct Pad
#if 0
#if 0
// LowLengths: Sequence<...>
// LowLengths: Sequence<...>
template <
class
LowLengths>
template <
typename
LowLengths>
struct Merge
struct Merge
{
{
static constexpr index_t nDimLow = LowLengths::GetSize();
static constexpr index_t nDimLow = LowLengths::GetSize();
...
@@ -113,7 +113,7 @@ struct Merge
...
@@ -113,7 +113,7 @@ struct Merge
#endif
#endif
// UpLengths: Sequence<...>
// UpLengths: Sequence<...>
template
<
index_t
LowLength
,
class
UpLengths
>
template
<
index_t
LowLength
,
typename
UpLengths
>
struct
Unmerge
struct
Unmerge
{
{
static
constexpr
index_t
nDimLow
=
1
;
static
constexpr
index_t
nDimLow
=
1
;
...
@@ -161,7 +161,7 @@ struct Unmerge
...
@@ -161,7 +161,7 @@ struct Unmerge
// UpLengths: Sequence<...>
// UpLengths: Sequence<...>
// Coefficients: Sequence<...>
// Coefficients: Sequence<...>
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
template
<
index_t
LowLength
,
class
UpLengths
,
class
Coefficients
>
template
<
index_t
LowLength
,
typename
UpLengths
,
typename
Coefficients
>
struct
Embed
struct
Embed
{
{
static
constexpr
index_t
nDimLow
=
1
;
static
constexpr
index_t
nDimLow
=
1
;
...
...
composable_kernel/include/tensor_description/tensor_descriptor.hpp
View file @
7a7fe160
...
@@ -7,12 +7,12 @@
...
@@ -7,12 +7,12 @@
namespace
ck
{
namespace
ck
{
template
<
class
...
NativeDimensions
>
template
<
typename
...
NativeDimensions
>
struct
NativeTensorDescriptor
struct
NativeTensorDescriptor
{
{
using
type
=
NativeTensorDescriptor
;
using
type
=
NativeTensorDescriptor
;
static
constexpr
auto
mDimensions
=
Tuple
<
NativeDimensions
...
>
{}
;
static
constexpr
index_t
nDim
=
sizeof
...(
NativeDimensions
)
;
static
constexpr
index_t
nDim
=
m
Dimensions
.
GetSize
(
);
static
constexpr
auto
mDimensions
=
make_tuple
(
Native
Dimensions
{}...
);
using
Index
=
MultiIndex
<
nDim
>
;
using
Index
=
MultiIndex
<
nDim
>
;
...
@@ -20,7 +20,7 @@ struct NativeTensorDescriptor
...
@@ -20,7 +20,7 @@ struct NativeTensorDescriptor
struct
lambda_GetLength
struct
lambda_GetLength
{
{
template
<
class
IDim
>
template
<
typename
IDim
>
__host__
__device__
constexpr
auto
operator
()(
IDim
)
const
__host__
__device__
constexpr
auto
operator
()(
IDim
)
const
{
{
return
GetLength
(
IDim
{});
return
GetLength
(
IDim
{});
...
@@ -34,7 +34,7 @@ struct NativeTensorDescriptor
...
@@ -34,7 +34,7 @@ struct NativeTensorDescriptor
struct
lambda_GetStride
struct
lambda_GetStride
{
{
template
<
class
IDim
>
template
<
typename
IDim
>
__host__
__device__
constexpr
auto
operator
()(
IDim
)
const
__host__
__device__
constexpr
auto
operator
()(
IDim
)
const
{
{
return
GetStride
(
IDim
{});
return
GetStride
(
IDim
{});
...
@@ -49,16 +49,16 @@ struct NativeTensorDescriptor
...
@@ -49,16 +49,16 @@ struct NativeTensorDescriptor
template
<
index_t
IDim
>
template
<
index_t
IDim
>
__host__
__device__
static
constexpr
auto
GetLength
(
Number
<
IDim
>
)
__host__
__device__
static
constexpr
auto
GetLength
(
Number
<
IDim
>
)
{
{
return
mDimensions
.
Ge
t
(
Number
<
IDim
>
{}).
GetLength
();
return
mDimensions
.
A
t
(
Number
<
IDim
>
{}).
GetLength
();
}
}
template
<
index_t
IDim
>
template
<
index_t
IDim
>
__host__
__device__
static
constexpr
auto
GetStride
(
Number
<
IDim
>
)
__host__
__device__
static
constexpr
auto
GetStride
(
Number
<
IDim
>
)
{
{
return
mDimensions
.
Ge
t
(
Number
<
IDim
>
{}).
GetStride
();
return
mDimensions
.
A
t
(
Number
<
IDim
>
{}).
GetStride
();
}
}
__host__
__device__
static
constexpr
index_t
GetOffset
(
Index
idx
)
__host__
__device__
static
constexpr
index_t
GetOffset
(
const
Index
&
idx
)
{
{
index_t
offset
=
0
;
index_t
offset
=
0
;
...
@@ -67,7 +67,7 @@ struct NativeTensorDescriptor
...
@@ -67,7 +67,7 @@ struct NativeTensorDescriptor
return
offset
;
return
offset
;
}
}
__host__
__device__
static
constexpr
index_t
GetOffsetDiff
(
Index
idx_diff
)
__host__
__device__
static
constexpr
index_t
GetOffsetDiff
(
const
Index
&
idx_diff
)
{
{
index_t
offset_diff
=
0
;
index_t
offset_diff
=
0
;
...
@@ -96,28 +96,65 @@ struct NativeTensorDescriptor
...
@@ -96,28 +96,65 @@ struct NativeTensorDescriptor
}
}
};
};
#if 0
// LowerTensorDescriptor
// LowerTensorDescriptor
// Transforms: std::tuple<DimensionTransforms...>
// Transforms: Tuple<DimensionTransforms...>
// LowerDimensionIds: std::tuple<Sequence<...>>
// LowerDimensionIds: Tuple<Sequence<...>>
// UpperDimensionIds: std::tuple<Sequence<...>>
// UpperDimensionIds: Tuple<Sequence<...>>
template <class LowTensorDescriptor, class Transforms, class LowDimensionIds, class UpDimensionIds>
template
<
typename
LowTensorDescriptor
,
typename
Transforms
,
typename
LowDimensionIds
,
typename
UpDimensionIds
>
struct
TransformedTensorDescriptor
struct
TransformedTensorDescriptor
{
{
using
type
=
TransformedTensorDescriptor
;
using
type
=
TransformedTensorDescriptor
;
static constexpr index_t nDimUp = GetUpperNumOfDimension();
static
constexpr
index_t
nTransform
=
Transforms
::
Size
();
static constexpr index_t nDimLow = GetLowerNumOfDimension();
struct
lambda_merge_sequences
{
template
<
typename
...
Seqs
>
__host__
__device__
constexpr
auto
operator
()(
Seqs
...
seqs
)
const
{
return
merge_sequences
(
seqs
...);
}
};
__host__
__device__
static
constexpr
auto
GetNumOfLowerDimension
()
{
// Here, we assume all lower-dimensions are active
// TODO: sanity-check all lower-dimension are indeed active
static constexpr index_t nTransform = Transforms::GetSize();
using
duplicated_low_active_dims
=
decltype
(
unpack
(
lambda_merge_sequences
{},
LowDimensionIds
{}));
using
low_active_dims
=
typename
sequence_unique_sort
<
duplicated_low_active_dims
,
math
::
less
<
index_t
>
,
math
::
equal
<
index_t
>>::
type
;
return
low_active_dims
::
Size
();
}
__host__
__device__
static
constexpr
auto
GetNumOfUpperDimension
()
{
using
duplicated_up_active_dims
=
decltype
(
unpack
(
lambda_merge_sequences
{},
UpDimensionIds
{}));
using
up_active_dims
=
typename
sequence_unique_sort
<
duplicated_up_active_dims
,
math
::
less
<
index_t
>
,
math
::
equal
<
index_t
>>::
type
;
return
up_active_dims
::
Size
();
}
static
constexpr
index_t
nDimUp
=
GetNumOfUpperDimension
();
static
constexpr
index_t
nDimLow
=
GetNumOfLowerDimension
();
using
UpperIndex
=
MultiIndex
<
nDimUp
>
;
using
UpperIndex
=
MultiIndex
<
nDimUp
>
;
using
LowerIndex
=
MultiIndex
<
nDimLow
>
;
using
LowerIndex
=
MultiIndex
<
nDimLow
>
;
__host__ __device__
static
constexpr TransformedTensorDescriptor()
__host__
__device__
constexpr
TransformedTensorDescriptor
()
{
{
static_assert(nTransform == Transforms::GetSize() &&
static_assert
(
nTransform
==
Transforms
::
Size
()
&&
nTransform
==
LowDimensionIds
::
Size
()
&&
nTransform == LowDimensionIds::GetSize() &&
nTransform
==
UpDimensionIds
::
Size
(),
nTransform == UpDimensionIds::GetSize(),
"wrong! # of transformations not the same"
);
"wrong! # of transformations not the same"
);
// TODO: sanity check: LowDimensionIds should include all low-dimensions,
// TODO: sanity check: LowDimensionIds should include all low-dimensions,
...
@@ -128,33 +165,17 @@ struct TransformedTensorDescriptor
...
@@ -128,33 +165,17 @@ struct TransformedTensorDescriptor
// a low-dimension should be associated with only one transformation
// a low-dimension should be associated with only one transformation
}
}
__host__ __device__ static constexpr auto GetNumOfLowerDimension()
{
// Here, we assume all lower-dimensions are active
// TODO: sanity-check all lower-dimension are indeed active
constexpr auto low_active_dims = unique_sort_sequence(
merge_tuple_of_sequences(LowDimensionIds{}), math::less<index_t>{});
return low_active_dims.GetSize();
}
__host__ __device__ static constexpr auto GetNumOfUpperDimension()
{
constexpr auto up_active_dims =
unique_sort_sequence(merge_tuple_of_sequences(UpDimensionIds{}), math::less<index_t>{});
return up_active_dims.GetSize();
}
__host__
__device__
static
constexpr
auto
GetNumOfDimension
()
__host__
__device__
static
constexpr
auto
GetNumOfDimension
()
{
{
return
GetNumOfUpperDimension
();
return
GetNumOfUpperDimension
();
}
}
__host__ __device__ static constexpr auto GetLengths()
#if 0
__host__ __device__ static constexpr auto GetUpperLengths()
{
{
struct lambda_get_upper_lengths
struct lambda_get_upper_lengths
{
{
template <
class
Transform>
template <
typename
Transform>
__host__ __device__ constexpr auto operator()(Transform tran) const
__host__ __device__ constexpr auto operator()(Transform tran) const
{
{
return tran.GetUpperLengths();
return tran.GetUpperLengths();
...
@@ -173,6 +194,7 @@ struct TransformedTensorDescriptor
...
@@ -173,6 +194,7 @@ struct TransformedTensorDescriptor
using sort_dimension_ids =
using sort_dimension_ids =
sequence_unique_sort<decltype(all_upper_dimension_ids), math::less<index_t>>;
sequence_unique_sort<decltype(all_upper_dimension_ids), math::less<index_t>>;
constexpr auto sorted_upper_dimension_ids = typename sort_dimension_ids::type;
constexpr auto sorted_upper_dimension_ids = typename sort_dimension_ids::type;
constexpr auto sorted2unsorted_map = typename sort_dimension_ids::sorted2unsorted_map_type;
constexpr auto sorted2unsorted_map = typename sort_dimension_ids::sorted2unsorted_map_type;
...
@@ -182,46 +204,48 @@ struct TransformedTensorDescriptor
...
@@ -182,46 +204,48 @@ struct TransformedTensorDescriptor
return sorted_upper_lengths;
return sorted_upper_lengths;
}
}
__host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); }
#endif
__host__
__device__
static
constexpr
auto
GetLowerTensorDescriptor
()
__host__
__device__
static
constexpr
auto
GetLowerTensorDescriptor
()
{
{
return
LowTensorDescriptor
{};
return
LowTensorDescriptor
{};
}
}
__host__ __device__ static constexpr
i
ndex
_t
GetLowerIndex(UpperIndex idx_up)
__host__
__device__
static
constexpr
LowerI
ndex
GetLowerIndex
(
const
UpperIndex
&
idx_up
)
{
{
LowerIndex
idx_low
;
LowerIndex
idx_low
;
static_for
<
0
,
nTransform
,
1
>
{}([
&
](
auto
itran
)
{
static_for
<
0
,
nTransform
,
1
>
{}([
&
](
auto
itran
)
{
constexpr auto tran = Transforms
::Ge
t(itran);
constexpr
auto
tran
=
Transforms
{}.
A
t
(
itran
);
constexpr
auto idx_low_part = pick_array_element(idx_low, LowDimensionIds
::Ge
t(itran));
auto
idx_low_part
=
pick_array_element
(
idx_low
,
LowDimensionIds
{}.
A
t
(
itran
));
const
expr
auto idx_up_part
= pick_array_element(idx_up, UpDimensionIds
::Ge
t(itran));
const
auto
idx_up_part
=
pick_array_element
(
idx_up
,
UpDimensionIds
{}.
A
t
(
itran
));
// this assume each lower (single) index is only assocaited with one transformation,
// this assume each lower (single) index is only assocaited with one transformation,
// which is required for index transformation, and has been checked during constructor
// which is required for index transformation, and has been checked during constructor
// of TransformedTensorDescriptor
// of TransformedTensorDescriptor
idx_low_part = tran.GetLowerIndex(idx_up_part);
idx_low_part
=
tran
.
GetLowerIndex
(
to_array
(
idx_up_part
)
)
;
});
});
return
idx_low
;
return
idx_low
;
}
}
__host__ __device__ static constexpr
i
ndex
_t
GetLowerIndexDiff(UpperIndex idx_up_diff,
__host__
__device__
static
constexpr
LowerI
ndex
GetLowerIndexDiff
(
const
UpperIndex
&
idx_up_diff
,
LowerIndex idx_low_old)
const
LowerIndex
&
idx_low_old
)
{
{
LowerIndex
idx_low_diff
;
LowerIndex
idx_low_diff
;
static_for
<
0
,
nTransform
,
1
>
{}([
&
](
auto
itran
)
{
static_for
<
0
,
nTransform
,
1
>
{}([
&
](
auto
itran
)
{
constexpr auto tran = Transforms::
Ge
t(itran);
constexpr
auto
tran
=
Transforms
::
A
t
(
itran
);
const
expr
auto idx_up_diff_part =
const
auto
idx_up_diff_part
=
pick_array_element(idx_up_diff, UpDimensionIds::
Ge
t(itran));
pick_array_element
(
idx_up_diff
,
UpDimensionIds
::
A
t
(
itran
));
constexpr auto idx_low_diff_part =
auto
idx_low_diff_part
=
pick_array_element
(
idx_low_diff
,
LowDimensionIds
::
At
(
itran
));
pick_array_element(idx_low_diff, LowDimensionIds::Get(itran));
const
expr
auto idx_low_old_part =
const
auto
idx_low_old_part
=
pick_array_element(idx_low_old, LowDimensionIds::
Ge
t(itran));
pick_array_element
(
idx_low_old
,
LowDimensionIds
::
A
t
(
itran
));
// this assume each lower (single) index is associated with only one transformation,
// this assume each lower (single) index is associated with only one transformation,
// which is required for index transformation, and has been checked during constructor
// which is required for index transformation, and has been checked during constructor
...
@@ -232,13 +256,14 @@ struct TransformedTensorDescriptor
...
@@ -232,13 +256,14 @@ struct TransformedTensorDescriptor
return
idx_low_diff
;
return
idx_low_diff
;
}
}
__host__ __device__ static constexpr index_t GetOffset(UpperIndex idx_up)
__host__
__device__
static
constexpr
index_t
GetOffset
(
const
UpperIndex
&
idx_up
)
{
{
return
GetLowerTensorDescriptor
().
GetOffset
(
GetLowerIndex
(
idx_up
));
return
GetLowerTensorDescriptor
().
GetOffset
(
GetLowerIndex
(
idx_up
));
}
}
#if 0
template <index_t IDim>
template <index_t IDim>
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
;
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
{
{
// not implemented
// not implemented
}
}
...
@@ -257,8 +282,8 @@ struct TransformedTensorDescriptor
...
@@ -257,8 +282,8 @@ struct TransformedTensorDescriptor
{
{
// not implemented
// not implemented
}
}
};
#endif
#endif
};
template
<
index_t
...
Lengths
,
index_t
...
Strides
>
template
<
index_t
...
Lengths
,
index_t
...
Strides
>
__host__
__device__
constexpr
auto
make_NativeTensorDescriptor
(
Sequence
<
Lengths
...
>
,
__host__
__device__
constexpr
auto
make_NativeTensorDescriptor
(
Sequence
<
Lengths
...
>
,
...
@@ -267,15 +292,28 @@ __host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence<Lengths.
...
@@ -267,15 +292,28 @@ __host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence<Lengths.
return
NativeTensorDescriptor
<
NativeDimension
<
Lengths
,
Strides
>
...
>
{};
return
NativeTensorDescriptor
<
NativeDimension
<
Lengths
,
Strides
>
...
>
{};
}
}
template
<
class
Lengths
>
template
<
typename
Lengths
>
__host__
__device__
constexpr
auto
make_NativeTensorDescriptor_packed
(
Lengths
)
__host__
__device__
constexpr
auto
make_NativeTensorDescriptor_packed
(
Lengths
)
{
{
constexpr
index_t
strides
=
reverse_inclusive_scan_sequence
(
constexpr
auto
strides
=
reverse_inclusive_scan_sequence
(
Lengths
::
PopFront
(),
math
::
multiplies
<
index_t
>
{},
Number
<
1
>
{})
Lengths
::
PopFront
(),
math
::
multiplies
<
index_t
>
{},
Number
<
1
>
{})
.
PushBack
(
Number
<
1
>
{});
.
PushBack
(
Number
<
1
>
{});
return
make_NativeTensorDescriptor
(
Lengths
{},
strides
);
return
make_NativeTensorDescriptor
(
Lengths
{},
strides
);
}
}
template
<
typename
LowTensorDescriptor
,
typename
Transforms
,
typename
LowDimensionIds
,
typename
UpDimensionIds
>
__host__
__device__
constexpr
auto
transform_tensor_descriptor
(
LowTensorDescriptor
,
Transforms
,
LowDimensionIds
,
UpDimensionIds
)
{
return
TransformedTensorDescriptor
<
LowTensorDescriptor
,
Transforms
,
LowDimensionIds
,
UpDimensionIds
>
{};
}
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp
View file @
7a7fe160
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
namespace
ck
{
namespace
ck
{
template
<
class
...
NativeDimensions
>
template
<
typename
...
NativeDimensions
>
__host__
__device__
void
print_tensor_descriptor
(
const
char
*
s
,
__host__
__device__
void
print_tensor_descriptor
(
const
char
*
s
,
NativeTensorDescriptor
<
NativeDimensions
...
>
desc
)
NativeTensorDescriptor
<
NativeDimensions
...
>
desc
)
{
{
...
...
composable_kernel/include/utility/Array.hpp
View file @
7a7fe160
...
@@ -6,48 +6,78 @@
...
@@ -6,48 +6,78 @@
namespace
ck
{
namespace
ck
{
template
<
class
TData
,
index_t
NSize
>
template
<
typename
TData
,
index_t
NSize
>
struct
Array
struct
Array
{
{
using
T
ype
=
Array
<
TData
,
NSize
>
;
using
t
ype
=
Array
<
TData
,
NSize
>
;
using
data_type
=
TData
;
using
data_type
=
TData
;
static
constexpr
index_t
nSize
=
NSize
;
index_t
mData
[
NSize
]
;
index_t
mData
[
nSize
];
__host__
__device__
explicit
constexpr
Array
()
{}
template
<
class
...
Xs
>
template
<
typename
X
,
typename
...
Xs
>
__host__
__device__
constexpr
Array
(
Xs
...
xs
)
:
mData
{
static_cast
<
TData
>
(
xs
)...}
__host__
__device__
explicit
constexpr
Array
(
X
x
,
Xs
...
xs
)
:
mData
{
static_cast
<
TData
>
(
x
),
static_cast
<
TData
>
(
xs
)...}
{
{
static_assert
(
sizeof
...(
Xs
)
+
1
==
NSize
,
"wrong! size"
);
}
}
__host__
__device__
static
constexpr
index_t
GetSize
()
{
return
NSize
;
}
#if 0
template <typename T>
__host__ __device__ explicit constexpr Array(const T& x)
{
static_assert(T::Size() == NSize, "wrong! size");
static_for<0, NSize, 1>{}([&](auto i){
mData[i] = x.At(i);
})
}
#endif
__host__
__device__
static
constexpr
index_t
Size
()
{
return
NSize
;
}
__host__
__device__
static
constexpr
index_t
GetSize
()
{
return
Size
();
}
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
TData
operator
[]
(
Number
<
I
>
)
const
__host__
__device__
constexpr
const
TData
&
At
(
Number
<
I
>
)
const
{
{
static_assert
(
I
<
NSize
,
"wrong!"
);
return
mData
[
I
];
return
mData
[
I
];
}
}
__host__
__device__
constexpr
TData
operator
[](
index_t
i
)
const
{
return
mData
[
i
];
}
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
TData
&
operator
()
(
Number
<
I
>
)
__host__
__device__
constexpr
TData
&
At
(
Number
<
I
>
)
{
{
static_assert
(
I
<
NSize
,
"wrong!"
);
return
mData
[
I
];
return
mData
[
I
];
}
}
__host__
__device__
TData
&
operator
()
(
index_t
i
)
{
return
mData
[
i
];
}
__host__
__device__
constexpr
const
TData
&
At
(
index_t
i
)
const
{
return
mData
[
i
];
}
template
<
index_t
I
>
__host__
__device__
constexpr
TData
&
At
(
index_t
i
)
{
return
mData
[
i
];
}
__host__
__device__
constexpr
void
Set
(
Number
<
I
>
,
TData
x
)
template
<
typename
I
>
__host__
__device__
constexpr
const
TData
&
operator
[](
I
i
)
const
{
{
static_assert
(
I
<
NSize
,
"wrong!"
);
return
At
(
i
);
}
mData
[
I
]
=
x
;
template
<
typename
I
>
__host__
__device__
constexpr
TData
&
operator
()(
I
i
)
{
return
At
(
i
);
}
}
__host__
__device__
constexpr
void
Set
(
index_t
I
,
TData
x
)
{
mData
[
I
]
=
x
;
}
template
<
typename
T
>
__host__
__device__
constexpr
type
&
operator
=
(
const
T
&
x
)
{
static_for
<
0
,
Size
(),
1
>
{}([
&
](
auto
i
)
{
operator
()(
i
)
=
x
[
i
];
});
return
*
this
;
}
struct
lambda_PushBack
// emulate constexpr lambda
struct
lambda_PushBack
// emulate constexpr lambda
{
{
...
@@ -63,7 +93,7 @@ struct Array
...
@@ -63,7 +93,7 @@ struct Array
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
void
operator
()(
Number
<
I
>
)
const
__host__
__device__
constexpr
void
operator
()(
Number
<
I
>
)
const
{
{
new_array
.
Set
(
Number
<
I
>
{}
,
old_array
[
I
]
)
;
new_array
(
Number
<
I
>
{}
)
=
old_array
[
I
];
}
}
};
};
...
@@ -73,71 +103,98 @@ struct Array
...
@@ -73,71 +103,98 @@ struct Array
static_for
<
0
,
NSize
,
1
>
{}(
lambda_PushBack
(
*
this
,
new_array
));
static_for
<
0
,
NSize
,
1
>
{}(
lambda_PushBack
(
*
this
,
new_array
));
new_array
.
Set
(
Number
<
NSize
>
{}
,
x
)
;
new_array
(
Number
<
NSize
>
{}
)
=
x
;
return
new_array
;
return
new_array
;
}
}
};
};
// A: Array
// A
rr
: Array
// Picks: Sequence<...>
// Picks: Sequence<...>
template
<
class
Arr
,
class
Picks
>
template
<
typename
Arr
,
typename
Picks
>
struct
ArrayElementPicker
struct
ArrayElementPicker
{
{
using
type
=
ArrayElementPicker
;
using
data_type
=
typename
Arr
::
data_type
;
using
data_type
=
typename
Arr
::
data_type
;
__host__
__device__
constexpr
ArrayElementPicker
(
Arr
&
array
)
:
mData
{
array
}
__host__
__device__
constexpr
ArrayElementPicker
()
=
delete
;
__host__
__device__
explicit
constexpr
ArrayElementPicker
(
Arr
&
array
)
:
mArray
{
array
}
{
{
constexpr
index_t
imax
=
constexpr
index_t
imax
=
accumulate_on_sequence
(
Picks
{},
math
::
maxer
<
index_t
>
{},
Number
<
0
>
{});
accumulate_on_sequence
(
Picks
{},
math
::
maxer
<
index_t
>
{},
Number
<
0
>
{});
static_assert
(
imax
<
Picks
::
Get
Size
(),
"wrong! exceeding
max id
"
);
static_assert
(
imax
<
Arr
::
Size
(),
"wrong! exceeding
# array element
"
);
}
}
__host__
__device__
static
constexpr
index_t
Get
Size
()
{
return
Picks
::
Get
Size
();
}
__host__
__device__
static
constexpr
auto
Size
()
{
return
Picks
::
Size
();
}
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
data_type
operator
[]
(
Number
<
I
>
)
const
__host__
__device__
constexpr
const
data_type
&
At
(
Number
<
I
>
)
const
{
{
constexpr
auto
IP
=
Picks
::
Get
(
Number
<
I
>
{});
static_assert
(
I
<
Size
(),
"wrong!"
);
return
mData
[
IP
];
constexpr
auto
IP
=
Picks
{}[
I
];
return
mArray
[
IP
];
}
}
__host__
__device__
constexpr
data_type
operator
[](
index_t
i
)
const
template
<
index_t
I
>
__host__
__device__
constexpr
data_type
&
At
(
Number
<
I
>
)
{
{
constexpr
index_t
ip
=
Picks
{}[
i
];
static_assert
(
I
<
Size
(),
"wrong!"
);
return
mData
[
ip
];
constexpr
auto
IP
=
Picks
{}[
I
];
return
mArray
(
IP
);
}
}
template
<
index_t
I
>
template
<
typename
I
>
__host__
__device__
data_type
&
operator
()(
Number
<
I
>
)
__host__
__device__
constexpr
const
data_type
&
operator
[](
I
i
)
const
{
{
constexpr
auto
IP
=
Picks
::
Get
(
Number
<
I
>
{});
return
At
(
i
);
return
mData
[
IP
];
}
}
__host__
__device__
data_type
&
operator
()(
index_t
i
)
template
<
typename
I
>
__host__
__device__
constexpr
data_type
&
operator
()(
I
i
)
{
{
constexpr
index_t
ip
=
Picks
{}[
i
];
return
At
(
i
);
return
mData
[
ip
];
}
}
Arr
&
mData
;
template
<
typename
T
>
__host__
__device__
constexpr
type
&
operator
=
(
const
T
&
a
)
{
static_for
<
0
,
Size
(),
1
>
{}([
&
](
auto
i
)
{
operator
()(
i
)
=
a
[
i
];
});
return
*
this
;
}
Arr
&
mArray
;
};
};
template
<
class
Arr
,
class
Picks
>
template
<
typename
Arr
,
typename
Picks
>
__host__
__device__
constexpr
auto
pick_array_element
(
Arr
&
a
,
Picks
)
__host__
__device__
constexpr
auto
pick_array_element
(
Arr
&
a
,
Picks
)
{
{
return
ArrayElementPicker
<
Arr
,
Picks
>
(
a
);
return
ArrayElementPicker
<
Arr
,
Picks
>
(
a
);
}
}
#if 1
template
<
typename
T
>
__host__
__device__
constexpr
auto
to_array
(
const
T
&
x
)
{
Array
<
typename
T
::
data_type
,
T
::
Size
()
>
y
;
static_for
<
0
,
T
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
y
.
At
(
i
)
=
x
.
At
(
i
);
});
return
y
;
}
#endif
template
<
index_t
...
Is
>
template
<
index_t
...
Is
>
__host__
__device__
constexpr
auto
sequence2array
(
Sequence
<
Is
...
>
)
__host__
__device__
constexpr
auto
sequence2array
(
Sequence
<
Is
...
>
)
{
{
return
Array
<
index_t
,
sizeof
...(
Is
)
>
{
Is
...};
return
Array
<
index_t
,
sizeof
...(
Is
)
>
{
Is
...};
}
}
template
<
class
TData
,
index_t
NSize
>
template
<
typename
TData
,
index_t
NSize
>
__host__
__device__
constexpr
auto
make_zero_array
()
__host__
__device__
constexpr
auto
make_zero_array
()
{
{
constexpr
auto
zero_sequence
=
typename
uniform_sequence_gen
<
NSize
,
0
>::
type
{};
constexpr
auto
zero_sequence
=
typename
uniform_sequence_gen
<
NSize
,
0
>::
type
{};
...
@@ -145,7 +202,7 @@ __host__ __device__ constexpr auto make_zero_array()
...
@@ -145,7 +202,7 @@ __host__ __device__ constexpr auto make_zero_array()
return
zero_array
;
return
zero_array
;
}
}
template
<
class
TData
,
index_t
NSize
,
index_t
...
IRs
>
template
<
typename
TData
,
index_t
NSize
,
index_t
...
IRs
>
__host__
__device__
constexpr
auto
reorder_array_given_new2old
(
const
Array
<
TData
,
NSize
>&
old_array
,
__host__
__device__
constexpr
auto
reorder_array_given_new2old
(
const
Array
<
TData
,
NSize
>&
old_array
,
Sequence
<
IRs
...
>
/*new2old*/
)
Sequence
<
IRs
...
>
/*new2old*/
)
{
{
...
@@ -156,7 +213,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData
...
@@ -156,7 +213,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData
return
Array
<
TData
,
NSize
>
{
old_array
[
IRs
]...};
return
Array
<
TData
,
NSize
>
{
old_array
[
IRs
]...};
}
}
template
<
class
TData
,
index_t
NSize
,
class
MapOld2New
>
template
<
typename
TData
,
index_t
NSize
,
typename
MapOld2New
>
struct
lambda_reorder_array_given_old2new
struct
lambda_reorder_array_given_old2new
{
{
const
Array
<
TData
,
NSize
>&
old_array
;
const
Array
<
TData
,
NSize
>&
old_array
;
...
@@ -173,13 +230,13 @@ struct lambda_reorder_array_given_old2new
...
@@ -173,13 +230,13 @@ struct lambda_reorder_array_given_old2new
{
{
TData
old_data
=
old_array
[
IOldDim
];
TData
old_data
=
old_array
[
IOldDim
];
constexpr
index_t
INewDim
=
MapOld2New
::
Ge
t
(
Number
<
IOldDim
>
{});
constexpr
index_t
INewDim
=
MapOld2New
::
A
t
(
Number
<
IOldDim
>
{});
new_array
.
Set
(
Number
<
INewDim
>
{}
,
old_data
)
;
new_array
(
Number
<
INewDim
>
{}
)
=
old_data
;
}
}
};
};
template
<
class
TData
,
index_t
NSize
,
index_t
...
IRs
>
template
<
typename
TData
,
index_t
NSize
,
index_t
...
IRs
>
__host__
__device__
constexpr
auto
reorder_array_given_old2new
(
const
Array
<
TData
,
NSize
>&
old_array
,
__host__
__device__
constexpr
auto
reorder_array_given_old2new
(
const
Array
<
TData
,
NSize
>&
old_array
,
Sequence
<
IRs
...
>
/*old2new*/
)
Sequence
<
IRs
...
>
/*old2new*/
)
{
{
...
@@ -195,7 +252,7 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData
...
@@ -195,7 +252,7 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData
return
new_array
;
return
new_array
;
}
}
template
<
class
TData
,
index_t
NSize
,
class
ExtractSeq
>
template
<
typename
TData
,
index_t
NSize
,
typename
ExtractSeq
>
__host__
__device__
constexpr
auto
extract_array
(
const
Array
<
TData
,
NSize
>&
old_array
,
ExtractSeq
)
__host__
__device__
constexpr
auto
extract_array
(
const
Array
<
TData
,
NSize
>&
old_array
,
ExtractSeq
)
{
{
Array
<
TData
,
ExtractSeq
::
GetSize
()
>
new_array
;
Array
<
TData
,
ExtractSeq
::
GetSize
()
>
new_array
;
...
@@ -204,12 +261,13 @@ __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_
...
@@ -204,12 +261,13 @@ __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_
static_assert
(
new_size
<=
NSize
,
"wrong! too many extract"
);
static_assert
(
new_size
<=
NSize
,
"wrong! too many extract"
);
static_for
<
0
,
new_size
,
1
>
{}([
&
](
auto
I
)
{
new_array
(
I
)
=
old_array
[
ExtractSeq
::
Ge
t
(
I
)];
});
static_for
<
0
,
new_size
,
1
>
{}([
&
](
auto
I
)
{
new_array
(
I
)
=
old_array
[
ExtractSeq
::
A
t
(
I
)];
});
return
new_array
;
return
new_array
;
}
}
template
<
class
F
,
class
X
,
class
Y
,
class
Z
>
// emulate constepxr lambda for array math
template
<
typename
F
,
typename
X
,
typename
Y
,
typename
Z
>
// emulate constepxr lambda for array
// math
struct
lambda_array_math
struct
lambda_array_math
{
{
const
F
&
f
;
const
F
&
f
;
...
@@ -226,13 +284,12 @@ struct lambda_array_math
...
@@ -226,13 +284,12 @@ struct lambda_array_math
__host__
__device__
constexpr
void
operator
()(
Number
<
IDim_
>
)
const
__host__
__device__
constexpr
void
operator
()(
Number
<
IDim_
>
)
const
{
{
constexpr
auto
IDim
=
Number
<
IDim_
>
{};
constexpr
auto
IDim
=
Number
<
IDim_
>
{};
z
(
IDim
)
=
f
(
x
[
IDim
],
y
[
IDim
]);
z
.
Set
(
IDim
,
f
(
x
[
IDim
],
y
[
IDim
]));
}
}
};
};
// Array = Array + Array
// Array = Array + Array
template
<
class
TData
,
index_t
NSize
>
template
<
typename
TData
,
index_t
NSize
>
__host__
__device__
constexpr
auto
operator
+
(
Array
<
TData
,
NSize
>
a
,
Array
<
TData
,
NSize
>
b
)
__host__
__device__
constexpr
auto
operator
+
(
Array
<
TData
,
NSize
>
a
,
Array
<
TData
,
NSize
>
b
)
{
{
Array
<
TData
,
NSize
>
result
;
Array
<
TData
,
NSize
>
result
;
...
@@ -247,7 +304,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData,
...
@@ -247,7 +304,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData,
}
}
// Array = Array - Array
// Array = Array - Array
template
<
class
TData
,
index_t
NSize
>
template
<
typename
TData
,
index_t
NSize
>
__host__
__device__
constexpr
auto
operator
-
(
Array
<
TData
,
NSize
>
a
,
Array
<
TData
,
NSize
>
b
)
__host__
__device__
constexpr
auto
operator
-
(
Array
<
TData
,
NSize
>
a
,
Array
<
TData
,
NSize
>
b
)
{
{
Array
<
TData
,
NSize
>
result
;
Array
<
TData
,
NSize
>
result
;
...
@@ -262,7 +319,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData,
...
@@ -262,7 +319,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData,
}
}
// Array += Array
// Array += Array
template
<
class
TData
,
index_t
NSize
>
template
<
typename
TData
,
index_t
NSize
>
__host__
__device__
constexpr
auto
operator
+=
(
Array
<
TData
,
NSize
>&
a
,
Array
<
TData
,
NSize
>
b
)
__host__
__device__
constexpr
auto
operator
+=
(
Array
<
TData
,
NSize
>&
a
,
Array
<
TData
,
NSize
>
b
)
{
{
a
=
a
+
b
;
a
=
a
+
b
;
...
@@ -270,14 +327,14 @@ __host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TDat
...
@@ -270,14 +327,14 @@ __host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TDat
}
}
// Array -= Array
// Array -= Array
template
<
class
TData
,
index_t
NSize
>
template
<
typename
TData
,
index_t
NSize
>
__host__
__device__
constexpr
auto
operator
-=
(
Array
<
TData
,
NSize
>&
a
,
Array
<
TData
,
NSize
>
b
)
__host__
__device__
constexpr
auto
operator
-=
(
Array
<
TData
,
NSize
>&
a
,
Array
<
TData
,
NSize
>
b
)
{
{
a
=
a
-
b
;
a
=
a
-
b
;
return
a
;
return
a
;
}
}
// Array = Array + Sequence
// Array = Array + Sequence
template
<
class
TData
,
index_t
NSize
,
index_t
...
Is
>
template
<
typename
TData
,
index_t
NSize
,
index_t
...
Is
>
__host__
__device__
constexpr
auto
operator
+
(
Array
<
TData
,
NSize
>
a
,
Sequence
<
Is
...
>
b
)
__host__
__device__
constexpr
auto
operator
+
(
Array
<
TData
,
NSize
>
a
,
Sequence
<
Is
...
>
b
)
{
{
static_assert
(
sizeof
...(
Is
)
==
NSize
,
"wrong! size not the same"
);
static_assert
(
sizeof
...(
Is
)
==
NSize
,
"wrong! size not the same"
);
...
@@ -294,7 +351,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is.
...
@@ -294,7 +351,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is.
}
}
// Array = Array - Sequence
// Array = Array - Sequence
template
<
class
TData
,
index_t
NSize
,
index_t
...
Is
>
template
<
typename
TData
,
index_t
NSize
,
index_t
...
Is
>
__host__
__device__
constexpr
auto
operator
-
(
Array
<
TData
,
NSize
>
a
,
Sequence
<
Is
...
>
b
)
__host__
__device__
constexpr
auto
operator
-
(
Array
<
TData
,
NSize
>
a
,
Sequence
<
Is
...
>
b
)
{
{
static_assert
(
sizeof
...(
Is
)
==
NSize
,
"wrong! size not the same"
);
static_assert
(
sizeof
...(
Is
)
==
NSize
,
"wrong! size not the same"
);
...
@@ -311,7 +368,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is.
...
@@ -311,7 +368,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is.
}
}
// Array = Array * Sequence
// Array = Array * Sequence
template
<
class
TData
,
index_t
NSize
,
index_t
...
Is
>
template
<
typename
TData
,
index_t
NSize
,
index_t
...
Is
>
__host__
__device__
constexpr
auto
operator
*
(
Array
<
TData
,
NSize
>
a
,
Sequence
<
Is
...
>
b
)
__host__
__device__
constexpr
auto
operator
*
(
Array
<
TData
,
NSize
>
a
,
Sequence
<
Is
...
>
b
)
{
{
static_assert
(
sizeof
...(
Is
)
==
NSize
,
"wrong! size not the same"
);
static_assert
(
sizeof
...(
Is
)
==
NSize
,
"wrong! size not the same"
);
...
@@ -328,7 +385,7 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is.
...
@@ -328,7 +385,7 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is.
}
}
// Array = Sequence - Array
// Array = Sequence - Array
template
<
class
TData
,
index_t
NSize
,
index_t
...
Is
>
template
<
typename
TData
,
index_t
NSize
,
index_t
...
Is
>
__host__
__device__
constexpr
auto
operator
-
(
Sequence
<
Is
...
>
a
,
Array
<
TData
,
NSize
>
b
)
__host__
__device__
constexpr
auto
operator
-
(
Sequence
<
Is
...
>
a
,
Array
<
TData
,
NSize
>
b
)
{
{
static_assert
(
sizeof
...(
Is
)
==
NSize
,
"wrong! size not the same"
);
static_assert
(
sizeof
...(
Is
)
==
NSize
,
"wrong! size not the same"
);
...
@@ -344,7 +401,7 @@ __host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSi
...
@@ -344,7 +401,7 @@ __host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSi
return
result
;
return
result
;
}
}
template
<
class
TData
,
index_t
NSize
,
class
Reduce
>
template
<
typename
TData
,
index_t
NSize
,
typename
Reduce
>
__host__
__device__
constexpr
TData
__host__
__device__
constexpr
TData
accumulate_on_array
(
const
Array
<
TData
,
NSize
>&
a
,
Reduce
f
,
TData
init
)
accumulate_on_array
(
const
Array
<
TData
,
NSize
>&
a
,
Reduce
f
,
TData
init
)
{
{
...
@@ -357,89 +414,5 @@ accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
...
@@ -357,89 +414,5 @@ accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
return
result
;
return
result
;
}
}
template
<
class
T
,
index_t
NSize
>
__host__
__device__
void
print_Array
(
const
char
*
s
,
Array
<
T
,
NSize
>
a
)
{
constexpr
index_t
nsize
=
a
.
GetSize
();
static_assert
(
nsize
>
0
&&
nsize
<=
10
,
"wrong!"
);
static_if
<
nsize
==
1
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u}
\n
"
,
s
,
nsize
,
a
[
0
]);
});
static_if
<
nsize
==
2
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
]);
});
static_if
<
nsize
==
3
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
]);
});
static_if
<
nsize
==
4
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
]);
});
static_if
<
nsize
==
5
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
]);
});
static_if
<
nsize
==
6
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
]);
});
static_if
<
nsize
==
7
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
]);
});
static_if
<
nsize
==
8
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
],
a
[
7
]);
});
static_if
<
nsize
==
9
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
],
a
[
7
],
a
[
8
]);
});
static_if
<
nsize
==
10
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
],
a
[
7
],
a
[
8
],
a
[
9
]);
});
}
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/utility/Sequence.hpp
View file @
7a7fe160
...
@@ -12,22 +12,22 @@ struct static_for;
...
@@ -12,22 +12,22 @@ struct static_for;
template
<
index_t
...>
template
<
index_t
...>
struct
Sequence
;
struct
Sequence
;
template
<
class
Seq
,
index_t
I
>
template
<
typename
Seq
,
index_t
I
>
struct
sequence_split
;
struct
sequence_split
;
template
<
class
>
template
<
typename
>
struct
sequence_reverse
;
struct
sequence_reverse
;
template
<
class
>
template
<
typename
>
struct
sequence_map_inverse
;
struct
sequence_map_inverse
;
template
<
class
>
template
<
typename
>
struct
is_valid_sequence_map
;
struct
is_valid_sequence_map
;
template
<
index_t
I
,
index_t
...
Is
>
template
<
index_t
I
,
index_t
...
Is
>
__host__
__device__
constexpr
auto
sequence_pop_front
(
Sequence
<
I
,
Is
...
>
);
__host__
__device__
constexpr
auto
sequence_pop_front
(
Sequence
<
I
,
Is
...
>
);
template
<
class
Seq
>
template
<
typename
Seq
>
__host__
__device__
constexpr
auto
sequence_pop_back
(
Seq
);
__host__
__device__
constexpr
auto
sequence_pop_back
(
Seq
);
template
<
index_t
...
Is
>
template
<
index_t
...
Is
>
...
@@ -38,9 +38,11 @@ struct Sequence
...
@@ -38,9 +38,11 @@ struct Sequence
static
constexpr
index_t
mSize
=
sizeof
...(
Is
);
static
constexpr
index_t
mSize
=
sizeof
...(
Is
);
__host__
__device__
static
constexpr
auto
Get
Size
()
{
return
Number
<
mSize
>
{};
}
__host__
__device__
static
constexpr
auto
Size
()
{
return
Number
<
mSize
>
{};
}
__host__
__device__
static
constexpr
index_t
GetImpl
(
index_t
I
)
__host__
__device__
static
constexpr
auto
GetSize
()
{
return
Size
();
}
__host__
__device__
static
constexpr
index_t
At
(
index_t
I
)
{
{
// the last dummy element is to prevent compiler complain about empty array, when mSize = 0
// the last dummy element is to prevent compiler complain about empty array, when mSize = 0
const
index_t
mData
[
mSize
+
1
]
=
{
Is
...,
0
};
const
index_t
mData
[
mSize
+
1
]
=
{
Is
...,
0
};
...
@@ -48,23 +50,24 @@ struct Sequence
...
@@ -48,23 +50,24 @@ struct Sequence
}
}
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
static
constexpr
auto
Ge
t
(
Number
<
I
>
)
__host__
__device__
static
constexpr
auto
A
t
(
Number
<
I
>
)
{
{
static_assert
(
I
<
mSize
,
"wrong! I too large"
);
static_assert
(
I
<
mSize
,
"wrong! I too large"
);
return
Number
<
GetImpl
(
Number
<
I
>
{}
)
>
{};
return
Number
<
At
(
I
)
>
{};
}
}
__host__
__device__
static
constexpr
auto
Get
(
index_t
I
)
{
return
GetImpl
(
I
);
}
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
auto
operator
[]
(
Number
<
I
>
)
const
__host__
__device__
static
constexpr
auto
Get
(
Number
<
I
>
)
{
{
return
Ge
t
(
Number
<
I
>
{});
return
A
t
(
Number
<
I
>
{});
}
}
// make sure I is constepxr if you want a constexpr return type
template
<
typename
I
>
__host__
__device__
constexpr
index_t
operator
[](
index_t
I
)
const
{
return
GetImpl
(
I
);
}
__host__
__device__
constexpr
auto
operator
[](
I
i
)
const
{
return
At
(
i
);
}
template
<
index_t
...
IRs
>
template
<
index_t
...
IRs
>
__host__
__device__
static
constexpr
auto
ReorderGivenNew2Old
(
Sequence
<
IRs
...
>
/*new2old*/
)
__host__
__device__
static
constexpr
auto
ReorderGivenNew2Old
(
Sequence
<
IRs
...
>
/*new2old*/
)
...
@@ -74,14 +77,14 @@ struct Sequence
...
@@ -74,14 +77,14 @@ struct Sequence
static_assert
(
is_valid_sequence_map
<
Sequence
<
IRs
...
>>::
value
,
"wrong! invalid reorder map"
);
static_assert
(
is_valid_sequence_map
<
Sequence
<
IRs
...
>>::
value
,
"wrong! invalid reorder map"
);
return
Sequence
<
Type
::
Ge
t
(
Number
<
IRs
>
{})...
>
{};
return
Sequence
<
Type
::
A
t
(
Number
<
IRs
>
{})...
>
{};
}
}
// MapOld2New is Sequence<...>
// MapOld2New is Sequence<...>
template
<
class
MapOld2New
>
template
<
typename
MapOld2New
>
__host__
__device__
static
constexpr
auto
ReorderGivenOld2New
(
MapOld2New
)
__host__
__device__
static
constexpr
auto
ReorderGivenOld2New
(
MapOld2New
)
{
{
static_assert
(
MapOld2New
::
Get
Size
()
==
Get
Size
(),
static_assert
(
MapOld2New
::
Size
()
==
Size
(),
"wrong! reorder map should have the same size as Sequence to be rerodered"
);
"wrong! reorder map should have the same size as Sequence to be rerodered"
);
static_assert
(
is_valid_sequence_map
<
MapOld2New
>::
value
,
"wrong! invalid reorder map"
);
static_assert
(
is_valid_sequence_map
<
MapOld2New
>::
value
,
"wrong! invalid reorder map"
);
...
@@ -97,13 +100,13 @@ struct Sequence
...
@@ -97,13 +100,13 @@ struct Sequence
__host__
__device__
static
constexpr
auto
Front
()
__host__
__device__
static
constexpr
auto
Front
()
{
{
static_assert
(
mSize
>
0
,
"wrong!"
);
static_assert
(
mSize
>
0
,
"wrong!"
);
return
Ge
t
(
Number
<
0
>
{});
return
A
t
(
Number
<
0
>
{});
}
}
__host__
__device__
static
constexpr
auto
Back
()
__host__
__device__
static
constexpr
auto
Back
()
{
{
static_assert
(
mSize
>
0
,
"wrong!"
);
static_assert
(
mSize
>
0
,
"wrong!"
);
return
Ge
t
(
Number
<
mSize
-
1
>
{});
return
A
t
(
Number
<
mSize
-
1
>
{});
}
}
__host__
__device__
static
constexpr
auto
PopFront
()
{
return
sequence_pop_front
(
Type
{});
}
__host__
__device__
static
constexpr
auto
PopFront
()
{
return
sequence_pop_front
(
Type
{});
}
...
@@ -137,19 +140,19 @@ struct Sequence
...
@@ -137,19 +140,19 @@ struct Sequence
template
<
index_t
...
Ns
>
template
<
index_t
...
Ns
>
__host__
__device__
static
constexpr
auto
Extract
(
Number
<
Ns
>
...)
__host__
__device__
static
constexpr
auto
Extract
(
Number
<
Ns
>
...)
{
{
return
Sequence
<
Type
::
Ge
t
(
Number
<
Ns
>
{})...
>
{};
return
Sequence
<
Type
::
A
t
(
Number
<
Ns
>
{})...
>
{};
}
}
template
<
index_t
...
Ns
>
template
<
index_t
...
Ns
>
__host__
__device__
static
constexpr
auto
Extract
(
Sequence
<
Ns
...
>
)
__host__
__device__
static
constexpr
auto
Extract
(
Sequence
<
Ns
...
>
)
{
{
return
Sequence
<
Type
::
Ge
t
(
Number
<
Ns
>
{})...
>
{};
return
Sequence
<
Type
::
A
t
(
Number
<
Ns
>
{})...
>
{};
}
}
template
<
index_t
I
,
index_t
X
>
template
<
index_t
I
,
index_t
X
>
__host__
__device__
static
constexpr
auto
Modify
(
Number
<
I
>
,
Number
<
X
>
)
__host__
__device__
static
constexpr
auto
Modify
(
Number
<
I
>
,
Number
<
X
>
)
{
{
static_assert
(
I
<
Get
Size
(),
"wrong!"
);
static_assert
(
I
<
Size
(),
"wrong!"
);
using
seq_split
=
sequence_split
<
Type
,
I
>
;
using
seq_split
=
sequence_split
<
Type
,
I
>
;
constexpr
auto
seq_left
=
typename
seq_split
::
SeqType0
{};
constexpr
auto
seq_left
=
typename
seq_split
::
SeqType0
{};
...
@@ -158,7 +161,7 @@ struct Sequence
...
@@ -158,7 +161,7 @@ struct Sequence
return
seq_left
.
PushBack
(
Number
<
X
>
{}).
PushBack
(
seq_right
);
return
seq_left
.
PushBack
(
Number
<
X
>
{}).
PushBack
(
seq_right
);
}
}
template
<
class
F
>
template
<
typename
F
>
__host__
__device__
static
constexpr
auto
Transform
(
F
f
)
__host__
__device__
static
constexpr
auto
Transform
(
F
f
)
{
{
return
Sequence
<
f
(
Is
)...
>
{};
return
Sequence
<
f
(
Is
)...
>
{};
...
@@ -166,8 +169,11 @@ struct Sequence
...
@@ -166,8 +169,11 @@ struct Sequence
};
};
// merge sequence
// merge sequence
template
<
class
,
class
>
template
<
typename
Seq
,
typename
...
Seqs
>
struct
sequence_merge
;
struct
sequence_merge
{
using
type
=
typename
sequence_merge
<
Seq
,
typename
sequence_merge
<
Seqs
...
>::
type
>::
type
;
};
template
<
index_t
...
Xs
,
index_t
...
Ys
>
template
<
index_t
...
Xs
,
index_t
...
Ys
>
struct
sequence_merge
<
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>>
struct
sequence_merge
<
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>>
...
@@ -175,8 +181,14 @@ struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
...
@@ -175,8 +181,14 @@ struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
using
type
=
Sequence
<
Xs
...,
Ys
...
>
;
using
type
=
Sequence
<
Xs
...,
Ys
...
>
;
};
};
template
<
typename
Seq
>
struct
sequence_merge
<
Seq
>
{
using
type
=
Seq
;
};
// generate sequence
// generate sequence
template
<
index_t
IBegin
,
index_t
NRemain
,
class
F
>
template
<
index_t
IBegin
,
index_t
NRemain
,
typename
F
>
struct
sequence_gen_impl
struct
sequence_gen_impl
{
{
static
constexpr
index_t
NRemainLeft
=
NRemain
/
2
;
static
constexpr
index_t
NRemainLeft
=
NRemain
/
2
;
...
@@ -188,20 +200,20 @@ struct sequence_gen_impl
...
@@ -188,20 +200,20 @@ struct sequence_gen_impl
typename
sequence_gen_impl
<
IMiddle
,
NRemainRight
,
F
>::
type
>::
type
;
typename
sequence_gen_impl
<
IMiddle
,
NRemainRight
,
F
>::
type
>::
type
;
};
};
template
<
index_t
I
,
class
F
>
template
<
index_t
I
,
typename
F
>
struct
sequence_gen_impl
<
I
,
1
,
F
>
struct
sequence_gen_impl
<
I
,
1
,
F
>
{
{
static
constexpr
index_t
Is
=
F
{}(
Number
<
I
>
{});
static
constexpr
index_t
Is
=
F
{}(
Number
<
I
>
{});
using
type
=
Sequence
<
Is
>
;
using
type
=
Sequence
<
Is
>
;
};
};
template
<
index_t
I
,
class
F
>
template
<
index_t
I
,
typename
F
>
struct
sequence_gen_impl
<
I
,
0
,
F
>
struct
sequence_gen_impl
<
I
,
0
,
F
>
{
{
using
type
=
Sequence
<>
;
using
type
=
Sequence
<>
;
};
};
template
<
index_t
NSize
,
class
F
>
template
<
index_t
NSize
,
typename
F
>
struct
sequence_gen
struct
sequence_gen
{
{
using
type
=
typename
sequence_gen_impl
<
0
,
NSize
,
F
>::
type
;
using
type
=
typename
sequence_gen_impl
<
0
,
NSize
,
F
>::
type
;
...
@@ -235,10 +247,10 @@ struct uniform_sequence_gen
...
@@ -235,10 +247,10 @@ struct uniform_sequence_gen
};
};
// reverse inclusive scan (with init) sequence
// reverse inclusive scan (with init) sequence
template
<
class
,
class
,
index_t
>
template
<
typename
,
typename
,
index_t
>
struct
sequence_reverse_inclusive_scan
;
struct
sequence_reverse_inclusive_scan
;
template
<
index_t
I
,
index_t
...
Is
,
class
Reduce
,
index_t
Init
>
template
<
index_t
I
,
index_t
...
Is
,
typename
Reduce
,
index_t
Init
>
struct
sequence_reverse_inclusive_scan
<
Sequence
<
I
,
Is
...
>
,
Reduce
,
Init
>
struct
sequence_reverse_inclusive_scan
<
Sequence
<
I
,
Is
...
>
,
Reduce
,
Init
>
{
{
using
old_scan
=
typename
sequence_reverse_inclusive_scan
<
Sequence
<
Is
...
>
,
Reduce
,
Init
>::
type
;
using
old_scan
=
typename
sequence_reverse_inclusive_scan
<
Sequence
<
Is
...
>
,
Reduce
,
Init
>::
type
;
...
@@ -248,23 +260,23 @@ struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
...
@@ -248,23 +260,23 @@ struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
using
type
=
typename
sequence_merge
<
Sequence
<
new_reduce
>
,
old_scan
>::
type
;
using
type
=
typename
sequence_merge
<
Sequence
<
new_reduce
>
,
old_scan
>::
type
;
};
};
template
<
index_t
I
,
class
Reduce
,
index_t
Init
>
template
<
index_t
I
,
typename
Reduce
,
index_t
Init
>
struct
sequence_reverse_inclusive_scan
<
Sequence
<
I
>
,
Reduce
,
Init
>
struct
sequence_reverse_inclusive_scan
<
Sequence
<
I
>
,
Reduce
,
Init
>
{
{
using
type
=
Sequence
<
Reduce
{}(
I
,
Init
)
>
;
using
type
=
Sequence
<
Reduce
{}(
I
,
Init
)
>
;
};
};
template
<
class
Reduce
,
index_t
Init
>
template
<
typename
Reduce
,
index_t
Init
>
struct
sequence_reverse_inclusive_scan
<
Sequence
<>
,
Reduce
,
Init
>
struct
sequence_reverse_inclusive_scan
<
Sequence
<>
,
Reduce
,
Init
>
{
{
using
type
=
Sequence
<>
;
using
type
=
Sequence
<>
;
};
};
// split sequence
// split sequence
template
<
class
Seq
,
index_t
I
>
template
<
typename
Seq
,
index_t
I
>
struct
sequence_split
struct
sequence_split
{
{
static
constexpr
index_t
NSize
=
Seq
{}.
Get
Size
();
static
constexpr
index_t
NSize
=
Seq
{}.
Size
();
using
range0
=
typename
arithmetic_sequence_gen
<
0
,
I
,
1
>::
type
;
using
range0
=
typename
arithmetic_sequence_gen
<
0
,
I
,
1
>::
type
;
using
range1
=
typename
arithmetic_sequence_gen
<
I
,
NSize
,
1
>::
type
;
using
range1
=
typename
arithmetic_sequence_gen
<
I
,
NSize
,
1
>::
type
;
...
@@ -274,10 +286,10 @@ struct sequence_split
...
@@ -274,10 +286,10 @@ struct sequence_split
};
};
// reverse sequence
// reverse sequence
template
<
class
Seq
>
template
<
typename
Seq
>
struct
sequence_reverse
struct
sequence_reverse
{
{
static
constexpr
index_t
NSize
=
Seq
{}.
Get
Size
();
static
constexpr
index_t
NSize
=
Seq
{}.
Size
();
using
seq_split
=
sequence_split
<
Seq
,
NSize
/
2
>
;
using
seq_split
=
sequence_split
<
Seq
,
NSize
/
2
>
;
using
type
=
typename
sequence_merge
<
using
type
=
typename
sequence_merge
<
...
@@ -297,19 +309,102 @@ struct sequence_reverse<Sequence<I0, I1>>
...
@@ -297,19 +309,102 @@ struct sequence_reverse<Sequence<I0, I1>>
using
type
=
Sequence
<
I1
,
I0
>
;
using
type
=
Sequence
<
I1
,
I0
>
;
};
};
template
<
class
Seq
,
class
Compare
>
template
<
typename
Seq
,
typename
Compare
>
struct
sequence_sort
struct
sequence_sort
{
{
// not implemented
template
<
typename
SeqLeft
,
typename
SeqRight
,
typename
MergedSeq
,
typename
Comp
>
struct
sorted_sequence_merge_impl
{
static
constexpr
bool
pick_left
=
SeqLeft
::
Front
()
<
SeqRight
::
Front
();
static
constexpr
index_t
next_value
=
pick_left
?
SeqLeft
::
Front
()
:
SeqRight
::
Front
();
using
new_merged_seq
=
decltype
(
MergedSeq
::
PushBack
(
Number
<
next_value
>
{}));
using
new_left_seq
=
typename
conditional
<
pick_left
,
decltype
(
SeqLeft
::
PopFront
()),
SeqLeft
>::
type
;
using
new_right_seq
=
typename
conditional
<
pick_left
,
SeqRight
,
decltype
(
SeqRight
::
PopFront
())
>::
type
;
using
type
=
typename
sorted_sequence_merge_impl
<
new_left_seq
,
new_right_seq
,
new_merged_seq
,
Comp
>::
type
;
};
template
<
typename
SeqLeft
,
typename
MergedSeq
,
typename
Comp
>
struct
sorted_sequence_merge_impl
<
SeqLeft
,
Sequence
<>
,
MergedSeq
,
Comp
>
{
using
type
=
typename
sequence_merge
<
MergedSeq
,
SeqLeft
>::
type
;
};
template
<
typename
SeqRight
,
typename
MergedSeq
,
typename
Comp
>
struct
sorted_sequence_merge_impl
<
Sequence
<>
,
SeqRight
,
MergedSeq
,
Comp
>
{
using
type
=
typename
sequence_merge
<
MergedSeq
,
SeqRight
>::
type
;
};
template
<
typename
Seq0
,
typename
Seq1
,
typename
Comp
>
struct
sorted_sequence_merge
{
using
type
=
typename
sorted_sequence_merge_impl
<
Seq0
,
Seq1
,
Sequence
<>
,
Comp
>::
type
;
};
using
split
=
sequence_split
<
Seq
,
Seq
::
Size
()
/
2
>
;
using
unsorted_left
=
typename
split
::
SeqType0
;
using
unsorted_right
=
typename
split
::
SeqType1
;
using
sorted_left
=
typename
sequence_sort
<
unsorted_left
,
Compare
>::
type
;
using
sorted_right
=
typename
sequence_sort
<
unsorted_right
,
Compare
>::
type
;
using
type
=
typename
sorted_sequence_merge
<
sorted_left
,
sorted_right
,
Compare
>::
type
;
};
template
<
index_t
X
,
index_t
Y
,
typename
Compare
>
struct
sequence_sort
<
Sequence
<
X
,
Y
>
,
Compare
>
{
static
constexpr
bool
x_first
=
Compare
{}(
X
,
Y
);
using
type
=
typename
conditional
<
x_first
,
Sequence
<
X
,
Y
>
,
Sequence
<
Y
,
X
>>::
type
;
};
};
template
<
class
Seq
,
class
Compare
>
template
<
index_t
X
,
typename
Compare
>
struct
sequence_sort
<
Sequence
<
X
>
,
Compare
>
{
using
type
=
Sequence
<
X
>
;
};
template
<
typename
Seq
,
typename
Less
,
typename
Equal
>
struct
sequence_unique_sort
struct
sequence_unique_sort
{
{
// not implemented
template
<
typename
WorkInputSeq
,
typename
WorkOutputSeq
,
typename
Eq
>
struct
sorted_sequence_uniquify_impl
{
static
constexpr
index_t
new_value
=
WorkInputSeq
::
Front
();
using
new_work_input_seq
=
decltype
(
WorkInputSeq
::
PopFront
());
using
new_working_output_seq
=
typename
conditional
<
new_value
==
WorkOutputSeq
::
Back
(),
WorkOutputSeq
,
decltype
(
WorkOutputSeq
::
PopBack
(
Number
<
new_value
>
{}))
>::
type
;
};
template
<
typename
WorkInputSeq
,
typename
Eq
>
struct
sorted_sequence_uniquify_impl
<
WorkInputSeq
,
Sequence
<>
,
Eq
>
{
using
type
=
WorkInputSeq
;
};
template
<
typename
SortedSeq
,
typename
Eq
>
struct
sorted_sequence_uniquify
{
using
type
=
typename
sorted_sequence_uniquify_impl
<
SortedSeq
,
Sequence
<>
,
Eq
>::
type
;
};
using
sorted_seq
=
typename
sequence_sort
<
Seq
,
Less
>::
type
;
using
type
=
typename
sorted_sequence_uniquify
<
sorted_seq
,
Equal
>::
type
;
};
};
template
<
class
Seq
>
template
<
typename
Seq
>
struct
is_valid_sequence_map
struct
is_valid_sequence_map
{
{
// not implemented yet, always return true
// not implemented yet, always return true
...
@@ -317,36 +412,35 @@ struct is_valid_sequence_map
...
@@ -317,36 +412,35 @@ struct is_valid_sequence_map
// TODO: add proper check for is_valid, something like:
// TODO: add proper check for is_valid, something like:
// static constexpr bool value =
// static constexpr bool value =
// is_same<typename arithmetic_sequence_gen<0, Seq::
Get
Size(), 1>::type,
// is_same<typename arithmetic_sequence_gen<0, Seq::Size(), 1>::type,
// typename sequence_sort<Seq>::SortedSeqType>{};
// typename sequence_sort<Seq>::SortedSeqType>{};
};
};
template
<
class
X2Y
,
class
WorkingY2X
,
index_t
XBegin
,
index_t
XRemain
>
template
<
typename
X2Y
,
typename
WorkingY2X
,
index_t
XBegin
,
index_t
XRemain
>
struct
sequence_map_inverse_impl
struct
sequence_map_inverse_impl
{
{
private:
private:
static
constexpr
auto
new_y2x
=
static
constexpr
auto
new_y2x
=
WorkingY2X
::
Modify
(
X2Y
::
At
(
Number
<
XBegin
>
{}),
Number
<
XBegin
>
{});
WorkingY2X
::
Modify
(
X2Y
::
Get
(
Number
<
XBegin
>
{}),
Number
<
XBegin
>
{});
public:
public:
using
type
=
using
type
=
typename
sequence_map_inverse_impl
<
X2Y
,
decltype
(
new_y2x
),
XBegin
+
1
,
XRemain
-
1
>::
type
;
typename
sequence_map_inverse_impl
<
X2Y
,
decltype
(
new_y2x
),
XBegin
+
1
,
XRemain
-
1
>::
type
;
};
};
template
<
class
X2Y
,
class
WorkingY2X
,
index_t
XBegin
>
template
<
typename
X2Y
,
typename
WorkingY2X
,
index_t
XBegin
>
struct
sequence_map_inverse_impl
<
X2Y
,
WorkingY2X
,
XBegin
,
0
>
struct
sequence_map_inverse_impl
<
X2Y
,
WorkingY2X
,
XBegin
,
0
>
{
{
using
type
=
WorkingY2X
;
using
type
=
WorkingY2X
;
};
};
template
<
class
X2Y
>
template
<
typename
X2Y
>
struct
sequence_map_inverse
struct
sequence_map_inverse
{
{
using
type
=
using
type
=
typename
sequence_map_inverse_impl
<
X2Y
,
typename
sequence_map_inverse_impl
<
X2Y
,
typename
uniform_sequence_gen
<
X2Y
::
Get
Size
(),
0
>::
type
,
typename
uniform_sequence_gen
<
X2Y
::
Size
(),
0
>::
type
,
0
,
0
,
X2Y
::
Get
Size
()
>::
type
;
X2Y
::
Size
()
>::
type
;
};
};
template
<
index_t
...
Xs
,
index_t
...
Ys
>
template
<
index_t
...
Xs
,
index_t
...
Ys
>
...
@@ -457,20 +551,26 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
...
@@ -457,20 +551,26 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
return
Sequence
<
Is
...
>
{};
return
Sequence
<
Is
...
>
{};
}
}
template
<
class
Seq
>
template
<
typename
Seq
>
__host__
__device__
constexpr
auto
sequence_pop_back
(
Seq
)
__host__
__device__
constexpr
auto
sequence_pop_back
(
Seq
)
{
{
static_assert
(
Seq
::
Get
Size
()
>
0
,
"wrong! cannot pop an empty Sequence!"
);
static_assert
(
Seq
::
Size
()
>
0
,
"wrong! cannot pop an empty Sequence!"
);
return
sequence_pop_front
(
Seq
::
Reverse
()).
Reverse
();
return
sequence_pop_front
(
Seq
::
Reverse
()).
Reverse
();
}
}
template
<
class
F
,
index_t
...
Xs
>
template
<
typename
F
,
index_t
...
Xs
>
__host__
__device__
constexpr
auto
transform_sequences
(
F
f
,
Sequence
<
Xs
...
>
)
__host__
__device__
constexpr
auto
transform_sequences
(
F
f
,
Sequence
<
Xs
...
>
)
{
{
return
Sequence
<
f
(
Xs
)...
>
{};
return
Sequence
<
f
(
Xs
)...
>
{};
}
}
template
<
class
F
,
index_t
...
Xs
,
index_t
...
Ys
>
// Function-style wrapper around the sequence_merge metafunction: takes any number of
// sequence objects (only their types are used) and returns a value of the merged type.
template <typename... Seqs>
__host__ __device__ constexpr auto merge_sequences(Seqs...)
{
    using merged_type = typename sequence_merge<Seqs...>::type;

    return merged_type{};
}
template
<
typename
F
,
index_t
...
Xs
,
index_t
...
Ys
>
__host__
__device__
constexpr
auto
transform_sequences
(
F
f
,
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
)
__host__
__device__
constexpr
auto
transform_sequences
(
F
f
,
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
)
{
{
static_assert
(
Sequence
<
Xs
...
>::
mSize
==
Sequence
<
Ys
...
>::
mSize
,
"Dim not the same"
);
static_assert
(
Sequence
<
Xs
...
>::
mSize
==
Sequence
<
Ys
...
>::
mSize
,
"Dim not the same"
);
...
@@ -478,7 +578,7 @@ __host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Seq
...
@@ -478,7 +578,7 @@ __host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Seq
return
Sequence
<
f
(
Xs
,
Ys
)...
>
{};
return
Sequence
<
f
(
Xs
,
Ys
)...
>
{};
}
}
template
<
class
F
,
index_t
...
Xs
,
index_t
...
Ys
,
index_t
...
Zs
>
template
<
typename
F
,
index_t
...
Xs
,
index_t
...
Ys
,
index_t
...
Zs
>
__host__
__device__
constexpr
auto
__host__
__device__
constexpr
auto
transform_sequences
(
F
f
,
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
,
Sequence
<
Zs
...
>
)
transform_sequences
(
F
f
,
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
,
Sequence
<
Zs
...
>
)
{
{
...
@@ -489,19 +589,19 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
...
@@ -489,19 +589,19 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
return
Sequence
<
f
(
Xs
,
Ys
,
Zs
)...
>
{};
return
Sequence
<
f
(
Xs
,
Ys
,
Zs
)...
>
{};
}
}
template
<
class
Seq
,
class
Reduce
,
index_t
Init
>
template
<
typename
Seq
,
typename
Reduce
,
index_t
Init
>
__host__
__device__
constexpr
auto
reverse_inclusive_scan_sequence
(
Seq
,
Reduce
,
Number
<
Init
>
)
__host__
__device__
constexpr
auto
reverse_inclusive_scan_sequence
(
Seq
,
Reduce
,
Number
<
Init
>
)
{
{
return
typename
sequence_reverse_inclusive_scan
<
Seq
,
Reduce
,
Init
>::
type
{};
return
typename
sequence_reverse_inclusive_scan
<
Seq
,
Reduce
,
Init
>::
type
{};
}
}
template
<
class
Seq
,
class
Reduce
,
index_t
Init
>
template
<
typename
Seq
,
typename
Reduce
,
index_t
Init
>
__host__
__device__
constexpr
auto
inclusive_scan_sequence
(
Seq
,
Reduce
,
Number
<
Init
>
)
__host__
__device__
constexpr
auto
inclusive_scan_sequence
(
Seq
,
Reduce
,
Number
<
Init
>
)
{
{
return
reverse_inclusive_scan_sequence
(
Seq
{}.
Reverse
(),
Reduce
{},
Number
<
Init
>
{}).
Reverse
();
return
reverse_inclusive_scan_sequence
(
Seq
{}.
Reverse
(),
Reduce
{},
Number
<
Init
>
{}).
Reverse
();
}
}
template
<
class
Seq
,
class
Reduce
>
template
<
typename
Seq
,
typename
Reduce
>
struct
lambda_accumulate_on_sequence
struct
lambda_accumulate_on_sequence
{
{
const
Reduce
&
f
;
const
Reduce
&
f
;
...
@@ -512,14 +612,14 @@ struct lambda_accumulate_on_sequence
...
@@ -512,14 +612,14 @@ struct lambda_accumulate_on_sequence
{
{
}
}
template
<
class
IDim
>
template
<
typename
IDim
>
__host__
__device__
constexpr
index_t
operator
()(
IDim
)
const
__host__
__device__
constexpr
index_t
operator
()(
IDim
)
const
{
{
return
result
=
f
(
result
,
Seq
::
Ge
t
(
IDim
{}));
return
result
=
f
(
result
,
Seq
::
A
t
(
IDim
{}));
}
}
};
};
template
<
class
Seq
,
class
Reduce
,
index_t
Init
>
template
<
typename
Seq
,
typename
Reduce
,
index_t
Init
>
__host__
__device__
constexpr
index_t
__host__
__device__
constexpr
index_t
accumulate_on_sequence
(
Seq
,
Reduce
f
,
Number
<
Init
>
/*initial_value*/
)
accumulate_on_sequence
(
Seq
,
Reduce
f
,
Number
<
Init
>
/*initial_value*/
)
{
{
...
@@ -530,41 +630,5 @@ accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
...
@@ -530,41 +630,5 @@ accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
return
result
;
return
result
;
}
}
template
<
index_t
...
Xs
>
__host__
__device__
void
print_Sequence
(
const
char
*
s
,
Sequence
<
Xs
...
>
)
{
constexpr
index_t
nsize
=
Sequence
<
Xs
...
>::
GetSize
();
static_assert
(
nsize
<=
10
,
"wrong!"
);
static_if
<
nsize
==
0
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
1
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
2
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
3
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
4
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
5
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
6
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
7
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
8
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
9
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
10
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
}
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/utility/array_helper.hpp
0 → 100644
View file @
7a7fe160
#ifndef CK_ARRAY_HELPER_HPP
#define CK_ARRAY_HELPER_HPP
#include "Array.hpp"
namespace
ck
{
// Debug helper: prints the label `s`, the element count and every element of `a`.
// Exactly one static_if branch is enabled for a given length, so each call emits its
// output through a single printf. Lengths outside 1..10 are rejected at compile time.
// NOTE(review): every element is formatted with %u, so T is presumably an unsigned
// 32-bit integer type -- confirm against callers.
template <typename T, index_t NSize>
__host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
{
    constexpr index_t length = a.GetSize();

    static_assert(length > 0 && length <= 10, "wrong!");

    static_if<length == 1>{}([&](auto) {
        printf("%s size %u, {%u}\n", s, length, a[0]);
    });

    static_if<length == 2>{}([&](auto) {
        printf("%s size %u, {%u %u}\n", s, length, a[0], a[1]);
    });

    static_if<length == 3>{}([&](auto) {
        printf("%s size %u, {%u %u %u}\n", s, length, a[0], a[1], a[2]);
    });

    static_if<length == 4>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u}\n", s, length, a[0], a[1], a[2], a[3]);
    });

    static_if<length == 5>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u}\n", s, length, a[0], a[1], a[2], a[3], a[4]);
    });

    static_if<length == 6>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u}\n",
               s,
               length,
               a[0],
               a[1],
               a[2],
               a[3],
               a[4],
               a[5]);
    });

    static_if<length == 7>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u}\n",
               s,
               length,
               a[0],
               a[1],
               a[2],
               a[3],
               a[4],
               a[5],
               a[6]);
    });

    static_if<length == 8>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u}\n",
               s,
               length,
               a[0],
               a[1],
               a[2],
               a[3],
               a[4],
               a[5],
               a[6],
               a[7]);
    });

    static_if<length == 9>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n",
               s,
               length,
               a[0],
               a[1],
               a[2],
               a[3],
               a[4],
               a[5],
               a[6],
               a[7],
               a[8]);
    });

    static_if<length == 10>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n",
               s,
               length,
               a[0],
               a[1],
               a[2],
               a[3],
               a[4],
               a[5],
               a[6],
               a[7],
               a[8],
               a[9]);
    });
}
}
// namespace ck
#endif
\ No newline at end of file
composable_kernel/include/utility/common_header.hpp
View file @
7a7fe160
...
@@ -4,14 +4,19 @@
...
@@ -4,14 +4,19 @@
#include "config.hpp"
#include "config.hpp"
#include "utility.hpp"
#include "utility.hpp"
#include "integral_constant.hpp"
#include "integral_constant.hpp"
#include "number.hpp"
#include "type.hpp"
#include "tuple.hpp"
#include "tuple.hpp"
#include "math.hpp"
#include "math.hpp"
#include "vector_type.hpp"
#include "vector_type.hpp"
#include "Sequence.hpp"
#include "Sequence.hpp"
#include "sequence_helper.hpp"
#include "Array.hpp"
#include "Array.hpp"
#include "array_helper.hpp"
#include "functional.hpp"
#include "functional.hpp"
#include "functional2.hpp"
#include "functional2.hpp"
#include "functional3.hpp"
#include "functional3.hpp"
#include "functional4.hpp"
#if CK_USE_AMD_INLINE_ASM
#if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp"
#include "amd_inline_asm.hpp"
...
...
composable_kernel/include/utility/functional.hpp
View file @
7a7fe160
...
@@ -3,9 +3,11 @@
...
@@ -3,9 +3,11 @@
#include "integral_constant.hpp"
#include "integral_constant.hpp"
#include "Sequence.hpp"
#include "Sequence.hpp"
#include "type.hpp"
namespace
ck
{
namespace
ck
{
// TODO: right? wrong?
struct
forwarder
struct
forwarder
{
{
template
<
typename
T
>
template
<
typename
T
>
...
@@ -17,7 +19,7 @@ struct forwarder
...
@@ -17,7 +19,7 @@ struct forwarder
struct
swallow
struct
swallow
{
{
template
<
class
...
Ts
>
template
<
typename
...
Ts
>
__host__
__device__
constexpr
swallow
(
Ts
&&
...)
__host__
__device__
constexpr
swallow
(
Ts
&&
...)
{
{
}
}
...
@@ -32,7 +34,7 @@ struct static_if<true>
...
@@ -32,7 +34,7 @@ struct static_if<true>
{
{
using
Type
=
static_if
<
true
>
;
using
Type
=
static_if
<
true
>
;
template
<
class
F
>
template
<
typename
F
>
__host__
__device__
constexpr
auto
operator
()(
F
f
)
const
__host__
__device__
constexpr
auto
operator
()(
F
f
)
const
{
{
// This is a trick for compiler:
// This is a trick for compiler:
...
@@ -43,7 +45,7 @@ struct static_if<true>
...
@@ -43,7 +45,7 @@ struct static_if<true>
return
Type
{};
return
Type
{};
}
}
template
<
class
F
>
template
<
typename
F
>
__host__
__device__
static
constexpr
auto
Else
(
F
)
__host__
__device__
static
constexpr
auto
Else
(
F
)
{
{
return
Type
{};
return
Type
{};
...
@@ -55,13 +57,13 @@ struct static_if<false>
...
@@ -55,13 +57,13 @@ struct static_if<false>
{
{
using
Type
=
static_if
<
false
>
;
using
Type
=
static_if
<
false
>
;
template
<
class
F
>
template
<
typename
F
>
__host__
__device__
constexpr
auto
operator
()(
F
)
const
__host__
__device__
constexpr
auto
operator
()(
F
)
const
{
{
return
Type
{};
return
Type
{};
}
}
template
<
class
F
>
template
<
typename
F
>
__host__
__device__
static
constexpr
auto
Else
(
F
f
)
__host__
__device__
static
constexpr
auto
Else
(
F
f
)
{
{
// This is a trick for compiler:
// This is a trick for compiler:
...
@@ -73,5 +75,23 @@ struct static_if<false>
...
@@ -73,5 +75,23 @@ struct static_if<false>
}
}
};
};
// Compile-time type selection: conditional<predicate, X, Y>::type is X when predicate
// is true and Y when it is false. Only the two specializations below are defined.
template <bool predicate, typename X, typename Y>
struct conditional;

// predicate == true: select the first type.
template <typename X, typename Y>
struct conditional<true, X, Y>
{
    using type = X;
};

// predicate == false: select the second type.
template <typename X, typename Y>
struct conditional<false, X, Y>
{
    using type = Y;
};

// Shorthand for typename conditional<predicate, X, Y>::type.
template <bool predicate, typename X, typename Y>
using conditional_t = typename conditional<predicate, X, Y>::type;
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/utility/functional2.hpp
View file @
7a7fe160
...
@@ -6,6 +6,8 @@
...
@@ -6,6 +6,8 @@
namespace
ck
{
namespace
ck
{
namespace
detail
{
template
<
class
>
template
<
class
>
struct
static_for_impl
;
struct
static_for_impl
;
...
@@ -19,6 +21,8 @@ struct static_for_impl<Sequence<Is...>>
...
@@ -19,6 +21,8 @@ struct static_for_impl<Sequence<Is...>>
}
}
};
};
}
// namespace detail
// F signature: F(Number<Iter>)
// F signature: F(Number<Iter>)
template
<
index_t
NBegin
,
index_t
NEnd
,
index_t
Increment
>
template
<
index_t
NBegin
,
index_t
NEnd
,
index_t
Increment
>
struct
static_for
struct
static_for
...
@@ -33,7 +37,8 @@ struct static_for
...
@@ -33,7 +37,8 @@ struct static_for
template
<
class
F
>
template
<
class
F
>
__host__
__device__
constexpr
void
operator
()(
F
f
)
const
__host__
__device__
constexpr
void
operator
()(
F
f
)
const
{
{
static_for_impl
<
typename
arithmetic_sequence_gen
<
NBegin
,
NEnd
,
Increment
>::
type
>
{}(
f
);
detail
::
static_for_impl
<
typename
arithmetic_sequence_gen
<
NBegin
,
NEnd
,
Increment
>::
type
>
{}(
f
);
}
}
};
};
...
...
composable_kernel/include/utility/functional3.hpp
View file @
7a7fe160
...
@@ -8,20 +8,7 @@
...
@@ -8,20 +8,7 @@
namespace
ck
{
namespace
ck
{
template
<
class
>
namespace
detail
{
struct
is_static
:
integral_constant
<
bool
,
false
>
{
};
template
<
class
T
,
T
X
>
struct
is_static
<
integral_constant
<
T
,
X
>>
:
integral_constant
<
bool
,
true
>
{
};
template
<
index_t
...
Is
>
struct
is_static
<
Sequence
<
Is
...
>>
:
integral_constant
<
bool
,
true
>
{
};
// RemainLengths: Sequence<...>
// RemainLengths: Sequence<...>
// Orders: Sequence<...>
// Orders: Sequence<...>
...
@@ -58,29 +45,6 @@ struct static_ford_impl<Sequence<>, Orders>
...
@@ -58,29 +45,6 @@ struct static_ford_impl<Sequence<>, Orders>
}
}
};
};
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which static_ford will loop over each
// dimension
template
<
class
Lengths
,
class
Orders
=
typename
arithmetic_sequence_gen
<
0
,
Lengths
::
GetSize
(),
1
>
::
type
>
struct
static_ford
{
__host__
__device__
constexpr
static_ford
()
{
static_assert
(
Lengths
::
GetSize
()
>
0
,
"wrong! Lengths is empty"
);
static_assert
(
Lengths
::
GetSize
()
==
Orders
::
GetSize
(),
"wrong! inconsistent size"
);
}
// F signature: F(Sequence<...> multi_id)
// multi_id is the unordered multi-index
template
<
class
F
>
__host__
__device__
constexpr
void
operator
()(
F
f
)
const
{
constexpr
auto
ordered_lengths
=
Lengths
::
ReorderGivenNew2Old
(
Orders
{});
static_ford_impl
<
decltype
(
ordered_lengths
),
Orders
>
{}(
f
,
Sequence
<>
{});
}
};
// RemainLengths: Sequence<...>
// RemainLengths: Sequence<...>
// Orders: Sequence<...>
// Orders: Sequence<...>
template
<
class
RemainLengths
,
class
Orders
>
template
<
class
RemainLengths
,
class
Orders
>
...
@@ -117,6 +81,31 @@ struct ford_impl<Sequence<>, Orders>
...
@@ -117,6 +81,31 @@ struct ford_impl<Sequence<>, Orders>
}
}
};
};
}
// namespace detail
// Compile-time N-dimensional loop.
// Lengths is Sequence<...>: the length of each dimension of the loop nest.
// Orders is Sequence<...>: the order of dimensions in which static_ford loops; when
// omitted it is the arithmetic sequence 0, 1, ..., Lengths::GetSize() - 1.
template <typename Lengths,
          typename Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
struct static_ford
{
    // The constructor only validates the template arguments.
    __host__ __device__ constexpr static_ford()
    {
        static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
    }

    // F signature: F(Sequence<...> multi_id)
    // multi_id is the unordered multi-index
    template <typename F>
    __host__ __device__ constexpr void operator()(F f) const
    {
        // Lengths rearranged into looping order, then handed to the recursive
        // implementation together with an empty starting multi-index.
        constexpr auto lengths_in_loop_order = Lengths::ReorderGivenNew2Old(Orders{});

        using impl_type = detail::static_ford_impl<decltype(lengths_in_loop_order), Orders>;

        impl_type{}(f, Sequence<>{});
    }
};
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which ford will loop over each
// Orders is Sequence<...>, it is the order of dimension in which ford will loop over each
// dimension
// dimension
...
@@ -139,7 +128,8 @@ struct ford
...
@@ -139,7 +128,8 @@ struct ford
for
(
index_t
i
=
0
;
i
<
ordered_lengths
.
Front
();
++
i
)
for
(
index_t
i
=
0
;
i
<
ordered_lengths
.
Front
();
++
i
)
{
{
ford_impl
<
decltype
(
ordered_lengths
.
PopFront
()),
Orders
>
{}(
f
,
Array
<
index_t
,
1
>
{
i
});
detail
::
ford_impl
<
decltype
(
ordered_lengths
.
PopFront
()),
Orders
>
{}(
f
,
Array
<
index_t
,
1
>
{
i
});
}
}
}
}
};
};
...
...
composable_kernel/include/utility/functional4.hpp
0 → 100644
View file @
7a7fe160
#ifndef CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP
#include "Sequence.hpp"
#include "tuple.hpp"
#include "Array.hpp"
namespace
ck
{
namespace
detail
{
template
<
typename
Indices
>
struct
unpack_impl
;
template
<
index_t
...
Is
>
struct
unpack_impl
<
Sequence
<
Is
...
>>
{
template
<
typename
F
,
typename
X
>
__host__
__device__
constexpr
auto
operator
()(
F
f
,
const
X
&
x
)
const
{
return
f
(
x
.
At
(
Number
<
Is
>
{})...);
}
};
}
// namespace detail
// Expands the elements of x into the argument list of f: calls
// f(x.At(Number<0>{}), ..., x.At(Number<X::Size() - 1>{})) and returns the result.
// X must provide a static Size() usable as a template argument and an At(Number<I>) accessor.
template <typename F, typename X>
__host__ __device__ constexpr auto unpack(F f, const X& x)
{
    using all_indices = typename arithmetic_sequence_gen<0, X::Size(), 1>::type;

    return detail::unpack_impl<all_indices>{}(f, x);
}
}
// namespace ck
#endif
composable_kernel/include/utility/integral_constant.hpp
View file @
7a7fe160
...
@@ -13,54 +13,5 @@ struct integral_constant
...
@@ -13,54 +13,5 @@ struct integral_constant
__host__
__device__
constexpr
value_type
operator
()()
const
noexcept
{
return
value
;
}
__host__
__device__
constexpr
value_type
operator
()()
const
noexcept
{
return
value
;
}
};
};
template
<
class
X
,
class
Y
>
struct
is_same
:
public
integral_constant
<
bool
,
false
>
{
};
template
<
class
X
>
struct
is_same
<
X
,
X
>
:
public
integral_constant
<
bool
,
true
>
{
};
template
<
class
T
>
using
remove_cv_t
=
typename
std
::
remove_cv
<
T
>::
type
;
template
<
index_t
N
>
using
Number
=
integral_constant
<
index_t
,
N
>
;
template
<
index_t
X
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
+
(
Number
<
X
>
,
Number
<
Y
>
)
{
return
Number
<
X
+
Y
>
{};
}
template
<
index_t
X
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
-
(
Number
<
X
>
,
Number
<
Y
>
)
{
static_assert
(
Y
<=
X
,
"wrong!"
);
return
Number
<
X
-
Y
>
{};
}
template
<
index_t
X
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
*
(
Number
<
X
>
,
Number
<
Y
>
)
{
return
Number
<
X
*
Y
>
{};
}
template
<
index_t
X
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
/
(
Number
<
X
>
,
Number
<
Y
>
)
{
static_assert
(
Y
>
0
,
"wrong!"
);
return
Number
<
X
/
Y
>
{};
}
template
<
index_t
X
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
%
(
Number
<
X
>
,
Number
<
Y
>
)
{
static_assert
(
Y
>
0
,
"wrong!"
);
return
Number
<
X
%
Y
>
{};
}
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/utility/math.hpp
View file @
7a7fe160
...
@@ -104,6 +104,18 @@ __host__ __device__ constexpr T lcm(T x, Ts... xs)
...
@@ -104,6 +104,18 @@ __host__ __device__ constexpr T lcm(T x, Ts... xs)
return
max
(
x
,
xs
...);
return
max
(
x
,
xs
...);
}
}
// Binary predicate functor: returns true when its two arguments of type T compare
// equal with operator==.
template <typename T>
struct equal
{
    __host__ __device__ constexpr bool operator()(T lhs, T rhs) const { return lhs == rhs; }
};
// Binary predicate functor: returns true when the first argument of type T is
// strictly less than the second according to operator<.
template <typename T>
struct less
{
    __host__ __device__ constexpr bool operator()(T lhs, T rhs) const { return lhs < rhs; }
};
}
// namespace math
}
// namespace math
}
// namespace ck
}
// namespace ck
...
...
composable_kernel/include/utility/number.hpp
0 → 100644
View file @
7a7fe160
#ifndef CK_NUMBER_HPP
#define CK_NUMBER_HPP
#include "integral_constant.hpp"
namespace
ck
{
// Number<N> is the integral_constant carrying the index_t value N; it lets an index
// be passed as a type (e.g. as a function argument) instead of as a runtime value.
template <index_t N>
using Number = integral_constant<index_t, N>;
// Compile-time addition: Number<X> + Number<Y> yields Number<X + Y>.
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
{
    constexpr index_t sum = X + Y;

    return Number<sum>{};
}
// Compile-time subtraction: Number<X> - Number<Y> yields Number<X - Y>.
// Subtracting a larger value from a smaller one is rejected at compile time.
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
{
    static_assert(Y <= X, "wrong!");

    constexpr index_t difference = X - Y;

    return Number<difference>{};
}
// Compile-time multiplication: Number<X> * Number<Y> yields Number<X * Y>.
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
{
    constexpr index_t product = X * Y;

    return Number<product>{};
}
// Compile-time integer division: Number<X> / Number<Y> yields Number<X / Y>.
// A divisor that is not greater than zero is rejected at compile time.
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
{
    static_assert(Y > 0, "wrong!");

    constexpr index_t quotient = X / Y;

    return Number<quotient>{};
}
// Compile-time remainder: Number<X> % Number<Y> yields Number<X % Y>.
// A divisor that is not greater than zero is rejected at compile time.
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
{
    static_assert(Y > 0, "wrong!");

    constexpr index_t remainder = X % Y;

    return Number<remainder>{};
}
}
// namespace ck
#endif
composable_kernel/include/utility/sequence_helper.hpp
0 → 100644
View file @
7a7fe160
#ifndef CK_SEQUENCE_HELPER_HPP
#define CK_SEQUENCE_HELPER_HPP
#include "Sequence.hpp"
namespace
ck
{
// Debug helper: prints the label `s`, the number of elements and all values of a
// compile-time Sequence. Exactly one static_if branch is enabled for a given length,
// so each call emits its output through a single printf whose format string has one
// %u per element. Sequences longer than 10 elements are rejected at compile time.
template <index_t... Xs>
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
{
    constexpr index_t length = Sequence<Xs...>::Size();

    static_assert(length <= 10, "wrong!");

    static_if<length == 0>{}([&](auto) {
        printf("%s size %u, {}\n", s, length, Xs...);
    });

    static_if<length == 1>{}([&](auto) {
        printf("%s size %u, {%u}\n", s, length, Xs...);
    });

    static_if<length == 2>{}([&](auto) {
        printf("%s size %u, {%u %u}\n", s, length, Xs...);
    });

    static_if<length == 3>{}([&](auto) {
        printf("%s size %u, {%u %u %u}\n", s, length, Xs...);
    });

    static_if<length == 4>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u}\n", s, length, Xs...);
    });

    static_if<length == 5>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u}\n", s, length, Xs...);
    });

    static_if<length == 6>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u}\n", s, length, Xs...);
    });

    static_if<length == 7>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, length, Xs...);
    });

    static_if<length == 8>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, length, Xs...);
    });

    static_if<length == 9>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, length, Xs...);
    });

    static_if<length == 10>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, length, Xs...);
    });
}
}
// namespace ck
#endif
composable_kernel/include/utility/tuple.hpp
View file @
7a7fe160
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#define CK_TUPLE_HPP
#define CK_TUPLE_HPP
#include "integral_constant.hpp"
#include "integral_constant.hpp"
#include "type.hpp"
#include "Sequence.hpp"
#include "Sequence.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -16,6 +17,8 @@ struct TupleElementKey
...
@@ -16,6 +17,8 @@ struct TupleElementKey
template
<
typename
Key
,
typename
Data
>
template
<
typename
Key
,
typename
Data
>
struct
TupleElement
struct
TupleElement
{
{
__host__
__device__
explicit
constexpr
TupleElement
()
:
mData
()
{}
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
explicit
constexpr
TupleElement
(
T
&&
v
)
:
mData
(
static_cast
<
T
&&>
(
v
))
__host__
__device__
explicit
constexpr
TupleElement
(
T
&&
v
)
:
mData
(
static_cast
<
T
&&>
(
v
))
{
{
...
@@ -48,6 +51,12 @@ struct TupleImpl;
...
@@ -48,6 +51,12 @@ struct TupleImpl;
template
<
index_t
...
Is
,
typename
...
Xs
>
template
<
index_t
...
Is
,
typename
...
Xs
>
struct
TupleImpl
<
Sequence
<
Is
...
>
,
Xs
...
>
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
...
struct
TupleImpl
<
Sequence
<
Is
...
>
,
Xs
...
>
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
...
{
{
#if 1
__host__
__device__
explicit
constexpr
TupleImpl
()
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
()...
{
}
#endif
template
<
typename
...
Ys
>
template
<
typename
...
Ys
>
__host__
__device__
explicit
constexpr
TupleImpl
(
Ys
&&
...
ys
)
__host__
__device__
explicit
constexpr
TupleImpl
(
Ys
&&
...
ys
)
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
(
static_cast
<
Ys
&&>
(
ys
))...
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
(
static_cast
<
Ys
&&>
(
ys
))...
...
@@ -97,5 +106,28 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
...
@@ -97,5 +106,28 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
}
}
};
};
// Builds a Tuple from the given arguments. Each element type is the argument type
// with its reference and top-level const/volatile qualifiers removed; each argument
// is perfectly forwarded into the Tuple constructor (static_cast<Xs&&> is the
// forwarding cast, written the same way as in the TupleImpl constructor).
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
    using tuple_type = Tuple<remove_cv_t<remove_reference_t<Xs>>...>;

    return tuple_type(static_cast<Xs&&>(xs)...);
}
namespace
detail
{
// Applies f to every element x.At(Number<Idx>{}) selected by the index sequence and
// collects the results, in index order, into a new tuple created with make_tuple.
template <typename X, typename F, index_t... Idx>
__host__ __device__ constexpr auto transpose_tuple_impl(X& x, F f, Sequence<Idx...>)
{
    return make_tuple(f(x.At(Number<Idx>{}))...);
}
}
// namespace detail
// Applies f to each of the X::Size() elements of x and returns the results as a new
// tuple (the work is done by detail::transpose_tuple_impl over indices 0..Size()-1).
// x is taken by non-const lvalue reference.
template <typename X, typename F>
__host__ __device__ constexpr auto transpose_tuple(X& x, F f)
{
    using all_indices = typename arithmetic_sequence_gen<0, X::Size(), 1>::type;

    return detail::transpose_tuple_impl(x, f, all_indices{});
}
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/utility/type.hpp
0 → 100644
View file @
7a7fe160
#ifndef CK_TYPE_HPP
#define CK_TYPE_HPP
#include "integral_constant.hpp"
#include "Sequence.hpp"
namespace
ck
{
// Type-equality trait. The primary template derives from integral_constant<bool, false>;
// the partial specialization for two identical types derives from
// integral_constant<bool, true>. (Base classes of a struct are public by default.)
template <typename X, typename Y>
struct is_same : integral_constant<bool, false>
{
};

template <typename X>
struct is_same<X, X> : integral_constant<bool, true>
{
};
// Trait marking types whose value is encoded in the type itself.
// It is false for an arbitrary type and true for the two specializations below:
// integral_constant<T, X> and Sequence<Is...>.
template <typename>
struct is_static : public integral_constant<bool, false>
{
};

template <typename T, T X>
struct is_static<integral_constant<T, X>> : public integral_constant<bool, true>
{
};

template <index_t... Is>
struct is_static<Sequence<Is...>> : public integral_constant<bool, true>
{
};
// Alias for T with any lvalue- or rvalue-reference removed (same type as
// typename std::remove_reference<T>::type).
template <typename T>
using remove_reference_t = std::remove_reference_t<T>;
// Alias for T with its top-level const and volatile qualifiers removed (same type as
// typename std::remove_cv<T>::type).
template <typename T>
using remove_cv_t = std::remove_cv_t<T>;
}
// namespace ck
#endif
driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp
View file @
7a7fe160
...
@@ -115,8 +115,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc,
...
@@ -115,8 +115,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc,
constexpr
index_t
OutThreadCopyDataPerAccess_N
=
4
;
constexpr
index_t
OutThreadCopyDataPerAccess_N
=
4
;
#endif
#endif
#if 0 // debug
constexpr index_t GridSize =
constexpr index_t GridSize =
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
#else
constexpr
index_t
GridSize
=
1
;
#endif
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment