yangql / composable_kernel-1 · Commits

Commit 0c05f427, authored Sep 05, 2019 by Chao Liu

    adding dimension transformation

parent bd44e639

Showing 6 changed files with 255 additions and 277 deletions:
composable_kernel/include/tensor_description/dimension_transform.hpp   +0 −217
composable_kernel/include/tensor_description/tensor_descriptor.hpp     +155 −31
composable_kernel/include/utility/Array.hpp                            +50 −0
composable_kernel/include/utility/Sequence.hpp                         +44 −0
composable_kernel/include/utility/functional2.hpp                      +0 −29
composable_kernel/include/utility/math.hpp                             +6 −0
composable_kernel/include/tensor_description/dimension_transform.hpp (deleted, 100644 → 0)
```cpp
#ifndef CK_DIMENSION_TRANSFORM_HPP
#define CK_DIMENSION_TRANSFORM_HPP

#include "common_header.hpp"

namespace ck {

template <index_t N>
using MultiIndex = Array<index_t, N>;

// LowLengths: Sequence<...>
template <class LowLengths>
struct PassThrough
{
    static constexpr index_t nDim = LowLengths::GetSize();

    using LowerId = MultiIndex<nDim>;
    using UpperId = LowerId;

    __host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDim>{}; }

    __host__ __device__ static constexpr auto GetUpperNumOfDimension()
    {
        return GetLowerNumOfDimension();
    }

    __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }

    __host__ __device__ static constexpr auto GetUpperLengths() { return GetLowerLengths(); }

    __host__ __device__ static constexpr auto GetLowerId(UpperId id_up) { return id_up; }

    __host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
    {
        return id_up_diff;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
};

// LowLengths: Sequence<...>
template <class LowLengths, class LeftPads, class RightPads>
struct Pad
{
    static constexpr index_t nDim = LowLengths::GetSize();

    using LowerId = MultiIndex<nDim>;
    using UpperId = LowerId;

    __host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDim>{}; }

    __host__ __device__ static constexpr auto GetUpperNumOfDimension()
    {
        return GetLowerNumOfDimension();
    }

    __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }

    __host__ __device__ static constexpr auto GetUpperLengths()
    {
        return GetLowerLengths() + LeftPads{} + RightPads{};
    }

    __host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
    {
        return id_up - LeftPads{};
    }

    __host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
    {
        return id_up_diff;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
};

// LowLengths: Sequence<...>
template <class LowLengths>
struct Merge
{
    static constexpr index_t nDimLow = LowLengths::GetSize();
    static constexpr index_t nDimUp  = 1;

    using LowerId = MultiIndex<nDimLow>;
    using UpperId = MultiIndex<nDimUp>;

    __host__ __device__ static constexpr auto GetUpperNumOfDimension() { return Number<nDimUp>{}; }

    __host__ __device__ static constexpr auto GetLowerNumOfDimension()
    {
        return Number<nDimLow>{};
    }

    __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }

    __host__ __device__ static constexpr auto GetUpperLengths()
    {
        return Sequence<accumulate_on_sequence(
            GetLowerLengths(), math::multiplies<index_t>{}, Number<1>{})>{};
    }

    __host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
    {
        LowerId id_low;
        // not implemented
        return id_low;
    }

    // id_low_diff depends on id_low_old, so id_low needs to be up-to-date
    __host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff,
                                                             LowerId id_low_old)
    {
        LowerId id_low_diff;
        // not implemented
        return id_low_diff;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
};

// UpLengths: Sequence<...>
template <index_t LowLength, class UpLengths>
struct Unmerge
{
    static constexpr index_t nDimLow = 1;
    static constexpr index_t nDimUp  = UpLengths::GetSize();

    using LowerId = MultiIndex<nDimLow>;
    using UpperId = MultiIndex<nDimUp>;

    __host__ __device__ constexpr Unmerge()
    {
        static_assert(LowLength == accumulate_on_sequence(
                                       UpLengths{}, math::multiplies<index_t>{}, Number<1>{}),
                      "wrong! UpLengths need to be ");
    }

    __host__ __device__ static constexpr auto GetUpperNumOfDimension() { return Number<nDimUp>{}; }

    __host__ __device__ static constexpr auto GetLowerNumOfDimension()
    {
        return Number<nDimLow>{};
    }

    __host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<LowLength>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }

    __host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
    {
        constexpr auto scans = typename sequence_reverse_inclusive_scan<UpLengths,
                                                                        math::multiplies<index_t>,
                                                                        1>::type{};

        LowerId id_low{0};

        static_for<0, nDimUp, 1>{}([&](auto idim) { id_low[0] += id_up[idim] * scans[idim]; });

        return id_low;
    }

    __host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
    {
        return GetLowerId(id_up_diff);
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
};

// UpLengths: Sequence<...>
// Coefficients: Sequence<...>
// id_low = coefficients[0, ...nDimUp-1] * id_up[0, ...nDimUp-1] + coefficients[nDimUp]
template <index_t LowLength, class UpLengths, class Coefficients>
struct Embed
{
    static constexpr index_t nDimLow = 1;
    static constexpr index_t nDimUp  = UpLengths::GetSize();

    using LowerId = MultiIndex<nDimLow>;
    using UpperId = MultiIndex<nDimUp>;

    static constexpr auto mCoefficients = Coefficients{};

    __host__ __device__ constexpr Embed()
    {
        static_assert(UpLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
                      "wrong! # of dimensions not consistent");

        constexpr index_t low_id_max =
            Coefficients{}.Back() + accumulate_on_sequence(UpLengths{} * Coefficients{}.PopBack(),
                                                           math::plus<index_t>{},
                                                           Number<0>{});

        static_assert(low_id_max < LowLength, "wrong! lower-id will go out of range");
    }

    __host__ __device__ static constexpr auto GetUpperNumOfDimension() { return Number<nDimUp>{}; }

    __host__ __device__ static constexpr auto GetLowerNumOfDimension()
    {
        return Number<nDimLow>{};
    }

    __host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<LowLength>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }

    __host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
    {
        LowerId id_low{mCoefficients[nDimUp]};

        static_for<0, nDimUp, 1>{}(
            [&](auto idim) { id_low[0] += id_up[idim] * mCoefficients[idim]; });

        return id_low;
    }

    __host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
    {
        LowerId id_low_diff{0};

        static_for<0, nDimUp, 1>{}(
            [&](auto idim) { id_low_diff[0] += id_up_diff[idim] * mCoefficients[idim]; });

        return id_low_diff;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
};

} // namespace ck
#endif
```
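To make the `Embed` comment above concrete (`Unmerge` is essentially the special case where the coefficients are the row-major strides of `UpLengths` and the constant term is 0), here is a minimal standalone sketch of the same index arithmetic in plain C++. The coefficient values are made up for illustration, and none of ck's `Array`/`Sequence` machinery is assumed:

```cpp
#include <cstdio>

// Standalone sketch of the affine map documented on Embed:
//   id_low = sum_i coefficients[i] * id_up[i] + coefficients[nDimUp]
// Values below are made up for illustration.
int main()
{
    constexpr int nDimUp      = 2;
    const int coefficients[3] = {6, 1, 2}; // per-dimension strides plus a constant offset
    const int id_up[nDimUp]   = {3, 4};

    int id_low = coefficients[nDimUp]; // start from the constant term
    for(int i = 0; i < nDimUp; ++i)
        id_low += id_up[i] * coefficients[i];

    std::printf("id_low = %d\n", id_low); // 3*6 + 4*1 + 2 = 24
    return 0;
}
```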
composable_kernel/include/tensor_description/tensor_descriptor.hpp

```diff
@@ -3,75 +3,159 @@
 #include "common_header.hpp"
 #include "dimension.hpp"
 #include "multi_index_transform.hpp"
 
 namespace ck {
 
-template <class Lengths, class Strides>
+template <class... NativeDimensions>
 struct NativeTensorDescriptor
 {
     using type = NativeTensorDescriptor;
 
-    static constexpr index_t nDim = Lengths::GetSize();
+    static constexpr auto mDimensions = Tuple<NativeDimensions...>{};
+
+    static constexpr index_t nDim = sizeof...(NativeDimensions);
 
-    using Id = MultiIndex<nDim>;
+    using Index = MultiIndex<nDim>;
 
     __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
 
     __host__ __device__ static constexpr auto GetLengths()
     {
-        return Lengths{};
+        // not implemented
     }
 
     __host__ __device__ static constexpr auto GetStrides()
     {
-        return Strides{};
+        // not implemented
     }
 
-    __host__ __device__ static constexpr auto GetLength(index_t IDim) { return Lengths{}[IDim]; }
-
-    __host__ __device__ static constexpr auto GetStride(index_t IDim) { return Strides{}[IDim]; }
+    template <index_t IDim>
+    __host__ __device__ static constexpr auto GetLength(Number<IDim>)
+    {
+        return mDimensions.Get(Number<IDim>{}).GetLength();
+    }
+
+    template <index_t IDim>
+    __host__ __device__ static constexpr auto GetStride(Number<IDim>)
+    {
+        return mDimensions.Get(Number<IDim>{}).GetStride();
+    }
 
-    __host__ __device__ static constexpr index_t GetOffset(Id id)
+    __host__ __device__ static constexpr index_t GetOffset(Index idx)
     {
-        // not implemented
+        index_t offset = 0;
+
+        static_for<0, nDim, 1>{}([&](auto idim) { offset += idx[idim] * GetStride(idim); });
+
+        return offset;
     }
 
+    __host__ __device__ static constexpr index_t GetOffsetDiff(Index idx_diff)
+    {
+        index_t offset_diff = 0;
+
+        static_for<0, nDim, 1>{}(
+            [&](auto idim) { offset_diff += idx_diff[idim] * GetStride(idim); });
+
+        return offset_diff;
+    }
+
+    __host__ __device__ static constexpr auto AreUpperIndex2OffsetTransformLinear()
+    {
+        // TODO: re-implement "Sequence", so that it can take other data-type (including bool)
+        // as element
+        return uniform_sequence_gen<nDim, 1>{};
+    }
+
+    __host__ __device__ static constexpr auto GetIndependentDimensionGroups()
+    {
+        // not implemented, should return Tuple<Sequence<0>, Sequence<1>, ...>
+        return xxx;
+    }
 };
 
 // LowerTensorDescriptor
 // Transforms: std::tuple<DimensionTransforms...>
-// LowerIds: std::tuple<Sequence<...>>
-// UpperIds: std::tuple<Sequence<...>>
-template <class LowTensorDescriptor,
-          class Transforms,
-          class LowDimensionMasks,
-          class UpDimensionMasks>
+// LowerDimensionIds: std::tuple<Sequence<...>>
+// UpperDimensionIds: std::tuple<Sequence<...>>
+template <class LowTensorDescriptor,
+          class Transforms,
+          class LowDimensionIds,
+          class UpDimensionIds>
 struct TransformedTensorDescriptor
 {
     using type = TransformedTensorDescriptor;
 
-    static constexpr index_t nDimUp  = xxxx;
-    static constexpr index_t nDimLow = xxx;
+    static constexpr index_t nDimUp  = GetUpperNumOfDimension();
+    static constexpr index_t nDimLow = GetLowerNumOfDimension();
 
     static constexpr index_t nTransform = Transforms::GetSize();
 
-    using UpperId = MultiIndex<nDimUp>;
-    using LowerId = MultiIndex<nDimLow>;
+    using UpperIndex = MultiIndex<nDimUp>;
+    using LowerIndex = MultiIndex<nDimLow>;
 
     __host__ __device__ constexpr TransformedTensorDescriptor()
     {
         static_assert(nTransform == Transforms::GetSize() &&
-                          nTransform == LowDimensionMasks::GetSize() &&
-                          nTransform == UpDimensionMasks::GetSize(),
+                          nTransform == LowDimensionIds::GetSize() &&
+                          nTransform == UpDimensionIds::GetSize(),
                       "wrong! # of transformations not the same");
 
-        // TODO: sanity check: LowDimensionMasks should include all low-dimensions,
-        // UpDimensionMasks should include all up-dimensions
+        // TODO: sanity check: LowDimensionIds should include all low-dimensions,
+        // UpDimensionIds should include all up-dimensions
+
+        // TODO: sanity check: while an up-dimension could be associated with multiple
+        // transformations, a low-dimension should be associated with only one transformation
     }
 
+    __host__ __device__ static constexpr auto GetNumOfLowerDimension()
+    {
+        // Here, we assume all lower-dimensions are active
+        // TODO: sanity-check all lower-dimensions are indeed active
+        constexpr auto low_active_dims = unique_sort_sequence(
+            merge_tuple_of_sequences(LowDimensionIds{}), math::less<index_t>{});
+        return low_active_dims.GetSize();
+    }
+
+    __host__ __device__ static constexpr auto GetNumOfUpperDimension()
+    {
+        constexpr auto up_active_dims = unique_sort_sequence(
+            merge_tuple_of_sequences(UpDimensionIds{}), math::less<index_t>{});
+        return up_active_dims.GetSize();
+    }
 
     __host__ __device__ static constexpr auto GetNumOfDimension()
     {
-        // not implemented
+        return GetNumOfUpperDimension();
     }
 
     __host__ __device__ static constexpr auto GetLengths()
     {
-        // not implemented
+        struct lambda_get_upper_lengths
+        {
+            template <class Transform>
+            __host__ __device__ constexpr auto operator()(Transform tran) const
+            {
+                return tran.GetUpperLengths();
+            }
+        };
+
+        constexpr auto tuple_of_upper_lengths =
+            transform_tuple(Transforms{}, lambda_get_upper_lengths{});
+
+        constexpr auto all_upper_lengths = merge_tuple_of_sequences(tuple_of_upper_lengths);
+
+        constexpr auto all_upper_dimension_ids = merge_tuple_of_sequences(UpDimensionIds{});
+
+        // TODO: sanity-check all_upper_dimension_ids contain all upper-dimensions
+        // TODO: sanity-check all_upper_lengths have no conflicting upper-length
+
+        using sort_dimension_ids =
+            sequence_unique_sort<decltype(all_upper_dimension_ids), math::less<index_t>>;
+
+        constexpr auto sorted_upper_dimension_ids = typename sort_dimension_ids::type{};
+        constexpr auto sorted2unsorted_map =
+            typename sort_dimension_ids::sorted2unsorted_map_type{};
+
+        constexpr auto sorted_upper_lengths =
+            sequence_element_pick(all_upper_lengths, sorted2unsorted_map);
+
+        return sorted_upper_lengths;
+    }
 
     __host__ __device__ static constexpr auto GetLowerTensorDescriptor()
@@ -79,17 +163,57 @@ struct TransformedTensorDescriptor
         return LowTensorDescriptor{};
     }
 
-    __host__ __device__ static constexpr index_t GetLowerId(UpperId id_up)
+    __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
     {
-        // not implemented
+        LowerIndex idx_low;
+
+        static_for<0, nTransform, 1>{}([&](auto itran) {
+            constexpr auto tran = Transforms::Get(itran);
+
+            auto idx_low_part      = pick_array_element(idx_low, LowDimensionIds::Get(itran));
+            const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds::Get(itran));
+
+            // this assumes each lower (single) index is only associated with one transformation,
+            // which is required for index transformation, and has been checked during constructor
+            // of TransformedTensorDescriptor
+            idx_low_part = tran.GetLowerIndex(idx_up_part);
+        });
+
+        return idx_low;
+    }
+
+    __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff,
+                                                                LowerIndex idx_low_old)
+    {
+        LowerIndex idx_low_diff;
+
+        static_for<0, nTransform, 1>{}([&](auto itran) {
+            constexpr auto tran = Transforms::Get(itran);
+
+            const auto idx_up_diff_part =
+                pick_array_element(idx_up_diff, UpDimensionIds::Get(itran));
+
+            auto idx_low_diff_part =
+                pick_array_element(idx_low_diff, LowDimensionIds::Get(itran));
+
+            const auto idx_low_old_part =
+                pick_array_element(idx_low_old, LowDimensionIds::Get(itran));
+
+            // this assumes each lower (single) index is associated with only one transformation,
+            // which is required for index transformation, and has been checked during constructor
+            // of TransformedTensorDescriptor
+            idx_low_diff_part = tran.GetLowerIndex(idx_up_diff_part, idx_low_old_part);
+        });
+
+        return idx_low_diff;
+    }
 
-    __host__ __device__ static constexpr index_t GetOffset(UpperId id_up)
+    __host__ __device__ static constexpr index_t GetOffset(UpperIndex idx_up)
     {
-        return GetLowerTensorDescriptor().GetOffset(GetLowerId(id_up));
+        return GetLowerTensorDescriptor().GetOffset(GetLowerIndex(idx_up));
    }
 
-    __host__ __device__ static constexpr auto AreUpperId2OffsetLinear();
+    __host__ __device__ static constexpr auto AreUpperIndex2OffsetTransformLinear()
+    {
+        // not implemented
+    }
```
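An aside on the `...Diff` methods above: the descriptor can propagate index *differences* because, for a linear transform, the offset is a stride dot product, so offset(idx + d) − offset(idx) = offset(d); only nonlinear transforms such as Merge need the old index as an extra argument. A small standalone sketch of that identity in plain C++, with made-up strides and indices:

```cpp
#include <cassert>
#include <cstdio>

// Standalone sketch of why GetOffsetDiff can work on index differences alone:
// for a linear (stride dot product) map, offset(idx + d) - offset(idx) == offset(d).
int offset(const int idx[3], const int strides[3])
{
    int off = 0;
    for(int i = 0; i < 3; ++i)
        off += idx[i] * strides[i];
    return off;
}

int main()
{
    const int strides[3]  = {12, 4, 1}; // packed strides of a 2x3x4 tensor
    const int idx[3]      = {1, 2, 3};
    const int idx_diff[3] = {0, 1, 1};
    const int idx_new[3]  = {1, 3, 4}; // idx + idx_diff

    // the offset moves by exactly offset(idx_diff)
    assert(offset(idx_new, strides) - offset(idx, strides) == offset(idx_diff, strides));

    std::printf("offset_diff = %d\n", offset(idx_diff, strides)); // 0*12 + 1*4 + 1*1 = 5
    return 0;
}
```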
composable_kernel/include/utility/Array.hpp

```diff
@@ -79,6 +79,56 @@ struct Array
     }
 };
 
+// A: Array
+// Picks: Sequence<...>
+template <class Arr, class Picks>
+struct ArrayElementPicker
+{
+    // TData: element type of Arr
+    __host__ __device__ constexpr ArrayElementPicker(Arr& array) : mData{array}
+    {
+        constexpr index_t imax =
+            accumulate_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});
+
+        static_assert(imax < Picks::GetSize(), "wrong! exceeding max id");
+    }
+
+    __host__ __device__ static constexpr index_t GetSize() { return Picks::GetSize(); }
+
+    template <index_t I>
+    __host__ __device__ constexpr TData operator[](Number<I>) const
+    {
+        constexpr auto IP = Picks::Get(Number<I>{});
+        return mData[IP];
+    }
+
+    __host__ __device__ constexpr TData operator[](index_t i) const
+    {
+        const index_t ip = Picks{}[i];
+        return mData[ip];
+    }
+
+    template <index_t I>
+    __host__ __device__ TData& operator()(Number<I>)
+    {
+        constexpr auto IP = Picks::Get(Number<I>{});
+        return mData[IP];
+    }
+
+    __host__ __device__ TData& operator()(index_t i)
+    {
+        const index_t ip = Picks{}[i];
+        return mData[ip];
+    }
+
+    Arr& mData;
+};
+
+template <class Arr, class Picks>
+__host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
+{
+    return ArrayElementPicker<Arr, Picks>(a);
+}
+
 template <index_t... Is>
 __host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
 {
```
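`ArrayElementPicker` is a writable view over a compile-time-chosen subset of array elements, which is what lets each transform in `TransformedTensorDescriptor` update just its slice of a multi-index in place. A minimal standalone analogue in plain C++ (hypothetical names, none of ck's `Array`/`Sequence` types assumed):

```cpp
#include <array>
#include <cstdio>

// Minimal standalone analogue of ArrayElementPicker/pick_array_element:
// a view exposing elements Picks... of an array by reference, so writes
// through the view land in the original storage.
template <class Arr, int... Picks>
struct element_picker
{
    Arr& data;

    static constexpr int size() { return sizeof...(Picks); }

    int& operator()(int i)
    {
        constexpr int picks[] = {Picks...};
        return data[picks[i]];
    }
};

int main()
{
    std::array<int, 5> a = {0, 10, 20, 30, 40};

    element_picker<std::array<int, 5>, 1, 3> view{a}; // picks elements 1 and 3

    view(0) = 99; // writes a[1]
    view(1) = 77; // writes a[3]

    std::printf("%d %d\n", a[1], a[3]); // 99 77
    return 0;
}
```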
composable_kernel/include/utility/Sequence.hpp

```diff
@@ -6,6 +6,9 @@
 namespace ck {
 
+template <index_t, index_t, index_t>
+struct static_for;
+
 template <index_t...>
 struct Sequence;
@@ -294,6 +297,18 @@ struct sequence_reverse<Sequence<I0, I1>>
     using type = Sequence<I1, I0>;
 };
 
+template <class Seq, class Compare>
+struct sequence_sort
+{
+    // not implemented
+};
+
+template <class Seq, class Compare>
+struct sequence_unique_sort
+{
+    // not implemented
+};
+
 template <class Seq>
 struct is_valid_sequence_map
 {
@@ -486,6 +501,35 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I
     return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
 }
 
+template <class Seq, class Reduce>
+struct lambda_accumulate_on_sequence
+{
+    const Reduce& f;
+    index_t& result;
+
+    __host__ __device__ constexpr lambda_accumulate_on_sequence(const Reduce& f_,
+                                                                index_t& result_)
+        : f(f_), result(result_)
+    {
+    }
+
+    template <class IDim>
+    __host__ __device__ constexpr index_t operator()(IDim) const
+    {
+        return result = f(result, Seq::Get(IDim{}));
+    }
+};
+
+template <class Seq, class Reduce, index_t Init>
+__host__ __device__ constexpr index_t
+accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
+{
+    index_t result = Init;
+
+    static_for<0, Seq::mSize, 1>{}(lambda_accumulate_on_sequence<Seq, Reduce>(f, result));
+
+    return result;
+}
+
 template <index_t... Xs>
 __host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
 {
```
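The `accumulate_on_sequence` pair moving into this header is a compile-time fold: it applies a binary functor across the integers of a `Sequence`, starting from an initial value. A standalone sketch of the same fold using a C++17 fold expression instead of ck's `static_for` (hypothetical names):

```cpp
#include <cstdio>

// Standalone sketch of what accumulate_on_sequence computes: a left fold of a
// binary functor over a compile-time integer sequence, from an initial value.
template <int... Is, class Reduce>
constexpr int accumulate(Reduce f, int init)
{
    int result = init;
    ((result = f(result, Is)), ...); // apply f once per sequence element
    return result;
}

int main()
{
    constexpr auto mul = [](int a, int b) { return a * b; };
    constexpr auto max = [](int a, int b) { return a >= b ? a : b; };

    // same shapes as the ck call sites: product of lengths, max of picked ids
    std::printf("%d\n", accumulate<2, 3, 4>(mul, 1)); // 24
    std::printf("%d\n", accumulate<1, 3, 0>(max, 0)); // 3
    return 0;
}
```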
composable_kernel/include/utility/functional2.hpp

(The accumulator deleted here is the same code added to Sequence.hpp above; this commit moves it between headers.)

```diff
@@ -37,34 +37,5 @@ struct static_for
     }
 };
 
-template <class Seq, class Reduce>
-struct lambda_accumulate_on_sequence
-{
-    const Reduce& f;
-    index_t& result;
-
-    __host__ __device__ constexpr lambda_accumulate_on_sequence(const Reduce& f_,
-                                                                index_t& result_)
-        : f(f_), result(result_)
-    {
-    }
-
-    template <class IDim>
-    __host__ __device__ constexpr index_t operator()(IDim) const
-    {
-        return result = f(result, Seq::Get(IDim{}));
-    }
-};
-
-template <class Seq, class Reduce, index_t Init>
-__host__ __device__ constexpr index_t
-accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
-{
-    index_t result = Init;
-
-    static_for<0, Seq::mSize, 1>{}(lambda_accumulate_on_sequence<Seq, Reduce>(f, result));
-
-    return result;
-}
-
 } // namespace ck
 #endif
```
composable_kernel/include/utility/math.hpp

```diff
@@ -31,6 +31,12 @@ struct multiplies
     __host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
 };
 
+template <class T>
+struct maxer
+{
+    __host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
+};
+
 template <class T>
 struct integer_divide_ceiler
 {
```
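The new `maxer` functor supplies the reduction used by `ArrayElementPicker`'s bounds check above: `accumulate_on_sequence` with `maxer` and initial value 0 yields the largest picked index.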