Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
ca42e910
You need to sign in or sign up before continuing.
Commit
ca42e910
authored
Sep 10, 2019
by
Chao Liu
Browse files
adding merge transform
parent
7a7fe160
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
551 additions
and
253 deletions
+551
-253
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp
..._convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp
+44
-14
composable_kernel/include/tensor_description/dimension.hpp
composable_kernel/include/tensor_description/dimension.hpp
+2
-2
composable_kernel/include/tensor_description/multi_index_transform.hpp
...rnel/include/tensor_description/multi_index_transform.hpp
+76
-37
composable_kernel/include/tensor_description/tensor_descriptor.hpp
...e_kernel/include/tensor_description/tensor_descriptor.hpp
+101
-63
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp
...l/include/tensor_description/tensor_descriptor_helper.hpp
+57
-2
composable_kernel/include/utility/array.hpp
composable_kernel/include/utility/array.hpp
+3
-4
composable_kernel/include/utility/array_helper.hpp
composable_kernel/include/utility/array_helper.hpp
+3
-3
composable_kernel/include/utility/common_header.hpp
composable_kernel/include/utility/common_header.hpp
+2
-2
composable_kernel/include/utility/functional.hpp
composable_kernel/include/utility/functional.hpp
+1
-1
composable_kernel/include/utility/functional2.hpp
composable_kernel/include/utility/functional2.hpp
+1
-1
composable_kernel/include/utility/functional3.hpp
composable_kernel/include/utility/functional3.hpp
+2
-2
composable_kernel/include/utility/functional4.hpp
composable_kernel/include/utility/functional4.hpp
+2
-2
composable_kernel/include/utility/math.hpp
composable_kernel/include/utility/math.hpp
+1
-0
composable_kernel/include/utility/sequence.hpp
composable_kernel/include/utility/sequence.hpp
+242
-108
composable_kernel/include/utility/sequence_helper.hpp
composable_kernel/include/utility/sequence_helper.hpp
+2
-2
composable_kernel/include/utility/tuple.hpp
composable_kernel/include/utility/tuple.hpp
+7
-7
composable_kernel/include/utility/type.hpp
composable_kernel/include/utility/type.hpp
+3
-1
driver/src/driver.cpp
driver/src/driver.cpp
+2
-2
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp
View file @
ca42e910
...
@@ -528,34 +528,64 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
...
@@ -528,34 +528,64 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
#elif
1
#elif
1
// create a native tensor descriptor
// create a native tensor descriptor
constexpr
auto
in_c_h_w_n_global_desc
=
constexpr
auto
in_c_h_w_n_global_desc
=
make_
N
ative
T
ensor
D
escriptor
(
InGlobalDesc
::
GetLengths
(),
InGlobalDesc
::
GetStrides
());
make_
n
ative
_t
ensor
_d
escriptor
(
InGlobalDesc
::
GetLengths
(),
InGlobalDesc
::
GetStrides
());
constexpr
index_t
C
=
in_c_h_w_n_global_desc
.
GetLength
(
I0
);
constexpr
index_t
C
=
in_c_h_w_n_global_desc
.
GetLength
(
I0
);
constexpr
index_t
Hi
=
in_c_h_w_n_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Hi
=
in_c_h_w_n_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Wi
=
in_c_h_w_n_global_desc
.
GetLength
(
I2
);
constexpr
index_t
Wi
=
in_c_h_w_n_global_desc
.
GetLength
(
I2
);
constexpr
index_t
N
=
in_c_h_w_n_global_desc
.
GetLength
(
I3
);
constexpr
index_t
N
=
in_c_h_w_n_global_desc
.
GetLength
(
I3
);
constexpr
auto
pad_h_w
=
Pad
<
Sequence
<
Hi
,
Wi
>
,
LowerPads
,
UpperPads
>
{};
// transformation: {c, h, w, n} --> {n, c, hp, wp}
constexpr
auto
pass_c
=
PassThrough
<
C
>
{};
// {h, w} --> {hp, wp}, {c} --> {c}, {n} --> {n}
constexpr
auto
pass_n
=
PassThrough
<
N
>
{};
constexpr
auto
in_n_c_hp_wp_global_desc
=
transform_tensor_descriptor
(
in_c_h_w_n_global_desc
,
make_tuple
(
Pad
<
Sequence
<
Hi
,
Wi
>
,
LowerPads
,
UpperPads
>
{},
PassThrough
<
C
>
{},
PassThrough
<
N
>
{}),
make_tuple
(
Sequence
<
1
,
2
>
{},
Sequence
<
0
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
2
,
3
>
{},
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
constexpr
auto
trans
=
make_tuple
(
pass_c
,
pad_h_w
,
pass_n
);
#if 1
constexpr
auto
lower_dim_groups
=
// transformation: {n, c, hp, wp} --> {c, b}
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{});
// {n, hp, wp} --> {b}, {c} --> {c}
constexpr
auto
upper_dim_groups
=
constexpr
auto
in_c_b_global_desc
=
transform_tensor_descriptor
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{});
in_n_c_hp_wp_global_desc
,
make_tuple
(
Merge
<
decltype
(
in_n_c_hp_wp_global_desc
.
GetLengths
(
I0
,
I2
,
I3
))
>
{},
constexpr
auto
in_c_h_w_n_padded_global_desc
=
transform_tensor_descriptor
(
PassThrough
<
in_n_c_hp_wp_global_desc
.
GetLength
(
I1
)
>
{}),
in_c_h_w_n_global_desc
,
trans
,
lower_dim_groups
,
upper_dim_groups
);
make_tuple
(
Sequence
<
0
,
2
,
3
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
#endif
#if 1
if
(
get_thread_local_1d_id
()
==
0
&&
get_block_1d_id
()
==
0
)
if
(
get_thread_local_1d_id
()
==
0
&&
get_block_1d_id
()
==
0
)
{
{
// 0
print_tensor_descriptor
(
"in_c_h_w_n_global_desc"
,
in_c_h_w_n_global_desc
);
print_tensor_descriptor
(
"in_c_h_w_n_global_desc"
,
in_c_h_w_n_global_desc
);
printf
(
"offset: %lu
\n
"
,
in_c_h_w_n_global_desc
.
GetOffset
({
1
,
2
,
3
,
4
}));
// 1
print_tensor_descriptor
(
"in_n_c_hp_wp_global_desc"
,
in_n_c_hp_wp_global_desc
);
// 2
print_tensor_descriptor
(
"in_c_b_global_desc"
,
in_c_b_global_desc
);
constexpr
auto
idx2
=
MultiIndex
<
2
>
{
1
,
4
*
(
16
*
16
)
+
5
*
16
+
6
};
auto
idx1
=
in_c_b_global_desc
.
CalculateLowerIndex
(
idx2
);
auto
idx0
=
in_c_b_global_desc
.
GetLowerTensorDescriptor
().
CalculateLowerIndex
(
idx1
);
printf
(
"padded offset: %lu
\n
"
,
in_c_h_w_n_padded_global_desc
.
GetOffset
({
1
,
4
,
5
,
4
}));
print_array
(
"idx2: "
,
idx2
);
print_array
(
"idx1: "
,
idx1
);
print_array
(
"idx0: "
,
idx0
);
printf
(
"in_c_b_global_desc offset: %lu
\n
"
,
in_c_b_global_desc
.
CalculateOffset
(
idx2
));
}
#else
{
index_t
c
=
static_cast
<
index_t
>
(
threadIdx
.
x
);
index_t
h
=
static_cast
<
index_t
>
(
threadIdx
.
y
);
index_t
w
=
static_cast
<
index_t
>
(
threadIdx
.
z
);
p_out_global
[
0
]
=
in_n_c_h_w_padded_global_desc
.
CalculateOffset
({
1
,
c
,
h
,
w
});
}
}
#endif
#endif
#endif
}
}
#endif
#endif
...
...
composable_kernel/include/tensor_description/dimension.hpp
View file @
ca42e910
...
@@ -18,9 +18,9 @@ struct NativeDimension
...
@@ -18,9 +18,9 @@ struct NativeDimension
__host__
__device__
static
constexpr
auto
GetStride
()
{
return
Number
<
Stride
>
{};
}
__host__
__device__
static
constexpr
auto
GetStride
()
{
return
Number
<
Stride
>
{};
}
__host__
__device__
static
constexpr
index_t
Get
Offset
(
index_t
i
)
{
return
i
*
Stride
;
}
__host__
__device__
static
constexpr
index_t
Calculate
Offset
(
index_t
i
)
{
return
i
*
Stride
;
}
__host__
__device__
static
constexpr
index_t
Get
OffsetDiff
(
index_t
i_diff
)
__host__
__device__
static
constexpr
index_t
Calculate
OffsetDiff
(
index_t
i_diff
)
{
{
return
i_diff
*
Stride
;
return
i_diff
*
Stride
;
}
}
...
...
composable_kernel/include/tensor_description/multi_index_transform.hpp
View file @
ca42e910
...
@@ -22,9 +22,12 @@ struct PassThrough
...
@@ -22,9 +22,12 @@ struct PassThrough
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
{
return
Sequence
<
Length
>
{};
}
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
{
return
Sequence
<
Length
>
{};
}
__host__
__device__
static
constexpr
auto
GetLowerIndex
(
UpperIndex
idx_up
)
{
return
idx_up
;
}
__host__
__device__
static
constexpr
auto
CalculateLowerIndex
(
UpperIndex
idx_up
)
{
return
idx_up
;
}
__host__
__device__
static
constexpr
auto
Get
LowerIndexDiff
(
UpperIndex
idx_up_diff
)
__host__
__device__
static
constexpr
auto
Calculate
LowerIndexDiff
(
UpperIndex
idx_up_diff
)
{
{
return
idx_up_diff
;
return
idx_up_diff
;
}
}
...
@@ -36,7 +39,7 @@ struct PassThrough
...
@@ -36,7 +39,7 @@ struct PassThrough
template
<
typename
LowLengths
,
typename
LeftPads
,
typename
RightPads
>
template
<
typename
LowLengths
,
typename
LeftPads
,
typename
RightPads
>
struct
Pad
struct
Pad
{
{
static
constexpr
index_t
nDim
=
LowLengths
::
Get
Size
();
static
constexpr
index_t
nDim
=
LowLengths
::
Size
();
using
LowerIndex
=
MultiIndex
<
nDim
>
;
using
LowerIndex
=
MultiIndex
<
nDim
>
;
using
UpperIndex
=
MultiIndex
<
nDim
>
;
using
UpperIndex
=
MultiIndex
<
nDim
>
;
...
@@ -52,12 +55,12 @@ struct Pad
...
@@ -52,12 +55,12 @@ struct Pad
return
GetLowerLengths
()
+
LeftPads
{}
+
RightPads
{};
return
GetLowerLengths
()
+
LeftPads
{}
+
RightPads
{};
}
}
__host__
__device__
static
constexpr
auto
Get
LowerIndex
(
UpperIndex
idx_up
)
__host__
__device__
static
constexpr
auto
Calculate
LowerIndex
(
UpperIndex
idx_up
)
{
{
return
idx_up
-
LeftPads
{};
return
idx_up
-
LeftPads
{};
}
}
__host__
__device__
static
constexpr
auto
Get
LowerIndexDiff
(
UpperIndex
idx_up_diff
)
__host__
__device__
static
constexpr
auto
Calculate
LowerIndexDiff
(
UpperIndex
idx_up_diff
)
{
{
return
idx_up_diff
;
return
idx_up_diff
;
}
}
...
@@ -65,21 +68,20 @@ struct Pad
...
@@ -65,21 +68,20 @@ struct Pad
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
true
;
}
};
};
#if 0
// LowLengths: Sequence<...>
// LowLengths: Sequence<...>
template
<
typename
LowLengths
>
template
<
typename
LowLengths
>
struct
Merge
struct
Merge
{
{
static constexpr index_t nDimLow = LowLengths::
Get
Size();
static
constexpr
index_t
nDimLow
=
LowLengths
::
Size
();
static
constexpr
index_t
nDimUp
=
1
;
static
constexpr
index_t
nDimUp
=
1
;
using
LowerIndex
=
MultiIndex
<
nDimLow
>
;
using
LowerIndex
=
MultiIndex
<
nDimLow
>
;
using
UpperIndex
=
MultiIndex
<
nDimUp
>
;
using
UpperIndex
=
MultiIndex
<
nDimUp
>
;
__host__ __device__ static constexpr auto GetNumOfUpperDimension(){return Number<nDimUp>{}};
__host__
__device__
static
constexpr
auto
GetNumOfLowerDimension
()
{
return
Number
<
nDimLow
>
{};
}
__host__
__device__
static
constexpr
auto
GetNumOfLowerDimension
()
{
return
Number
<
nDimLow
>
{};
}
__host__
__device__
static
constexpr
auto
GetNumOfUpperDimension
()
{
return
Number
<
nDimUp
>
{};
}
__host__
__device__
static
constexpr
auto
GetLowerLengths
()
{
return
LowLengths
{};
}
__host__
__device__
static
constexpr
auto
GetLowerLengths
()
{
return
LowLengths
{};
}
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
...
@@ -88,18 +90,56 @@ struct Merge
...
@@ -88,18 +90,56 @@ struct Merge
GetLowerLengths
(),
math
::
multiplies
<
index_t
>
{},
Number
<
1
>
{})
>
{};
GetLowerLengths
(),
math
::
multiplies
<
index_t
>
{},
Number
<
1
>
{})
>
{};
}
}
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
// emulate constexpr lambda
template
<
typename
PseudoLowStrides
>
struct
lambda_CalculateLowerIndex
{
index_t
&
itmp
;
LowerIndex
&
idx_low
;
__host__
__device__
explicit
constexpr
lambda_CalculateLowerIndex
(
index_t
&
itmp_
,
LowerIndex
&
idx_low_
)
:
itmp
(
itmp_
),
idx_low
(
idx_low_
)
{
}
template
<
typename
IDim
>
__host__
__device__
constexpr
void
operator
()(
IDim
idim
)
const
{
constexpr
index_t
stride
=
PseudoLowStrides
::
At
(
idim
);
idx_low
(
idim
)
=
itmp
/
stride
;
itmp
-=
idx_low
[
idim
]
*
stride
;
}
};
__host__
__device__
static
constexpr
auto
CalculateLowerIndex
(
const
UpperIndex
&
idx_up
)
{
{
LowerIndex
idx_low
;
LowerIndex
idx_low
;
// not implemeneted
index_t
itmp
=
idx_up
[
0
];
constexpr
auto
pseudo_low_strides
=
reverse_inclusive_scan_sequence
(
GetLowerLengths
().
PopFront
(),
math
::
multiplies
<
index_t
>
{},
Number
<
1
>
{})
.
PushBack
(
Number
<
1
>
{});
// calculate index in each of the dimensions in the order of their dimension
#if 1
static_for
<
0
,
nDimLow
-
1
,
1
>
{}(
lambda_CalculateLowerIndex
<
decltype
(
pseudo_low_strides
)
>
(
itmp
,
idx_low
));
idx_low
(
nDimLow
-
1
)
=
itmp
/
pseudo_low_strides
[
nDimLow
-
1
];
#else
static_for
<
0
,
nDimLow
,
1
>
{}(
lambda_CalculateLowerIndex
<
decltype
(
pseudo_low_strides
)
>
(
itmp
,
idx_low
));
#endif
return
idx_low
;
return
idx_low
;
}
}
// idx_low_diff depends on idx_low_old, so idx_low need to be up-to-date
// idx_low_diff depends on idx_low_old, so idx_low need to be up-to-date
__host__ __device__ static constexpr auto
Get
LowerIndexDiff(UpperIndex idx_up_diff,
__host__
__device__
static
constexpr
auto
Calculate
LowerIndexDiff
(
const
UpperIndex
&
idx_up_diff
,
LowerIndex idx_low_old)
const
LowerIndex
&
idx_low_old
)
{
{
LowerIndex
idx_low_diff
;
LowerIndex
idx_low_diff
;
...
@@ -110,49 +150,48 @@ struct Merge
...
@@ -110,49 +150,48 @@ struct Merge
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
false
;
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
false
;
}
};
};
#endif
// UpLengths: Sequence<...>
// UpLengths: Sequence<...>
template
<
index_t
LowLength
,
typename
UpLengths
>
template
<
typename
UpLengths
>
struct
Unmerge
struct
Unmerge
{
{
static
constexpr
index_t
nDimLow
=
1
;
static
constexpr
index_t
nDimLow
=
1
;
static
constexpr
index_t
nDimUp
=
UpLengths
::
Get
Size
();
static
constexpr
index_t
nDimUp
=
UpLengths
::
Size
();
using
UpperIndex
=
MultiIndex
<
nDimUp
>
;
using
LowerIndex
=
MultiIndex
<
nDimLow
>
;
using
LowerIndex
=
MultiIndex
<
nDimLow
>
;
using
UpperIndex
=
MultiIndex
<
nDimUp
>
;
__host__
__device__
constexpr
Unmerge
()
__host__
__device__
static
constexpr
auto
GetNumOfLowerDimension
()
{
return
Number
<
nDimLow
>
{};
}
{
static_assert
(
LowLength
==
accumulate_on_sequence
(
UpLengths
{},
math
::
multiplies
<
index_t
>
{},
Number
<
1
>
{}),
"wrong! UpLengths need to be "
);
}
__host__
__device__
static
constexpr
auto
GetNumOfUpperDimension
()
{
return
Number
<
nDimUp
>
{};
}
__host__
__device__
static
constexpr
auto
GetNumOfUpperDimension
()
{
return
Number
<
nDimUp
>
{};
}
__host__
__device__
static
constexpr
auto
GetNumOfLowerDimension
()
{
return
Number
<
nDimLow
>
{};
}
__host__
__device__
static
constexpr
auto
GetLowerLengths
()
{
constexpr
index_t
low_length
=
accumulate_on_sequence
(
UpLengths
{},
math
::
multiplies
<
index_t
>
{},
Number
<
1
>
{});
__host__
__device__
static
constexpr
auto
GetLowerLengths
()
{
return
Sequence
<
LowLength
>
{};
}
return
Sequence
<
low_length
>
{};
}
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
{
return
UpLengths
{};
}
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
{
return
UpLengths
{};
}
__host__
__device__
static
constexpr
auto
Get
LowerIndex
(
UpperIndex
idx_up
)
__host__
__device__
static
constexpr
auto
Calculate
LowerIndex
(
const
UpperIndex
&
idx_up
)
{
{
constexpr
auto
scans
=
typename
sequence_reverse_inclusive_scan
<
UpLengths
,
math
::
multiplies
<
index_t
>
,
1
>::
type
{};
LowerIndex
idx_low
{
0
};
LowerIndex
idx_low
{
0
};
static_for
<
0
,
nDimUp
,
1
>
{}([
&
](
auto
idim
)
{
idx_low
(
0
)
+=
idx_up
[
idim
]
*
scans
[
idim
];
});
constexpr
auto
pseudo_up_strides
=
typename
sequence_reverse_inclusive_scan
<
UpLengths
,
math
::
multiplies
<
index_t
>
,
1
>::
type
{};
static_for
<
0
,
nDimUp
,
1
>
{}(
[
&
](
auto
idim
)
{
idx_low
(
0
)
+=
idx_up
[
idim
]
*
pseudo_up_strides
[
idim
];
});
return
idx_low
;
return
idx_low
;
}
}
__host__
__device__
static
constexpr
auto
Get
LowerIndexDiff
(
UpperIndex
idx_up_diff
)
__host__
__device__
static
constexpr
auto
Calculate
LowerIndexDiff
(
const
UpperIndex
&
idx_up_diff
)
{
{
return
Get
LowerIndex
(
idx_up_diff
);
return
Calculate
LowerIndex
(
idx_up_diff
);
}
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
true
;
}
...
@@ -165,12 +204,12 @@ template <index_t LowLength, typename UpLengths, typename Coefficients>
...
@@ -165,12 +204,12 @@ template <index_t LowLength, typename UpLengths, typename Coefficients>
struct
Embed
struct
Embed
{
{
static
constexpr
index_t
nDimLow
=
1
;
static
constexpr
index_t
nDimLow
=
1
;
static
constexpr
index_t
nDimUp
=
UpLengths
::
Get
Size
();
static
constexpr
index_t
nDimUp
=
UpLengths
::
Size
();
using
LowerIndex
=
MultiIndex
<
nDimLow
>
;
using
LowerIndex
=
MultiIndex
<
nDimLow
>
;
using
UpperIndex
=
MultiIndex
<
nDimUp
>
;
using
UpperIndex
=
MultiIndex
<
nDimUp
>
;
__host__
__device__
constexpr
Embed
()
__host__
__device__
explicit
constexpr
Embed
()
{
{
static_assert
(
UpLengths
::
GetSize
()
==
nDimUp
&&
Coefficients
::
GetSize
()
==
nDimUp
+
1
,
static_assert
(
UpLengths
::
GetSize
()
==
nDimUp
&&
Coefficients
::
GetSize
()
==
nDimUp
+
1
,
"wrong! # of dimensions not consistent"
);
"wrong! # of dimensions not consistent"
);
...
@@ -191,7 +230,7 @@ struct Embed
...
@@ -191,7 +230,7 @@ struct Embed
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
{
return
UpLengths
{};
}
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
{
return
UpLengths
{};
}
__host__
__device__
static
constexpr
auto
Get
LowerIndex
(
UpperIndex
idx_up
)
__host__
__device__
static
constexpr
auto
Calculate
LowerIndex
(
const
UpperIndex
&
idx_up
)
{
{
LowerIndex
idx_low
(
Coefficients
{}[
nDimUp
]);
LowerIndex
idx_low
(
Coefficients
{}[
nDimUp
]);
...
@@ -201,7 +240,7 @@ struct Embed
...
@@ -201,7 +240,7 @@ struct Embed
return
idx_low
;
return
idx_low
;
}
}
__host__
__device__
static
constexpr
auto
Get
LowerIndexDiff
(
UpperIndex
idx_up_diff
)
__host__
__device__
static
constexpr
auto
Calculate
LowerIndexDiff
(
const
UpperIndex
&
idx_up_diff
)
{
{
LowerIndex
idx_low_diff
{
0
};
LowerIndex
idx_low_diff
{
0
};
...
...
composable_kernel/include/tensor_description/tensor_descriptor.hpp
View file @
ca42e910
...
@@ -18,47 +18,53 @@ struct NativeTensorDescriptor
...
@@ -18,47 +18,53 @@ struct NativeTensorDescriptor
__host__
__device__
static
constexpr
auto
GetNumOfDimension
()
{
return
Number
<
nDim
>
{};
}
__host__
__device__
static
constexpr
auto
GetNumOfDimension
()
{
return
Number
<
nDim
>
{};
}
struct
lambda_GetLength
template
<
index_t
IDim
>
__host__
__device__
static
constexpr
auto
GetLength
(
Number
<
IDim
>
)
{
{
template
<
typename
IDim
>
return
mDimensions
.
At
(
Number
<
IDim
>
{}).
GetLength
();
__host__
__device__
constexpr
auto
operator
()(
IDim
)
const
}
{
return
GetLength
(
IDim
{});
}
};
__host__
__device__
static
constexpr
auto
GetLengths
()
template
<
index_t
IDim
>
__host__
__device__
static
constexpr
auto
GetStride
(
Number
<
IDim
>
)
{
{
return
typename
sequence_gen
<
nDim
,
lambda_GetLength
>::
type
{}
;
return
mDimensions
.
At
(
Number
<
IDim
>
{}).
GetStride
()
;
}
}
struct
lambda_GetStride
template
<
index_t
...
IDims
>
__host__
__device__
static
constexpr
auto
GetLengths
(
Sequence
<
IDims
...
>
)
{
{
template
<
typename
IDim
>
return
Sequence
<
GetLength
(
Number
<
IDims
>
{})...
>
{};
__host__
__device__
constexpr
auto
operator
()(
IDim
)
const
}
{
return
GetStride
(
IDim
{});
}
};
__host__
__device__
static
constexpr
auto
GetStrides
()
template
<
index_t
...
IDims
>
__host__
__device__
static
constexpr
auto
GetStrides
(
Sequence
<
IDims
...
>
)
{
{
return
typename
sequence_gen
<
nDim
,
lambda_GetStride
>::
type
{};
return
Sequence
<
GetStride
(
Number
<
IDims
>
{})...
>
{};
}
}
template
<
index_t
IDim
>
template
<
index_t
IDim
,
index_t
...
IDims
>
__host__
__device__
static
constexpr
auto
GetLength
(
Number
<
IDim
>
)
__host__
__device__
static
constexpr
auto
GetLength
s
(
Number
<
IDim
>
,
Number
<
IDims
>
...
)
{
{
return
mDimensions
.
At
(
Number
<
IDim
>
{}).
GetLength
(
);
return
GetLengths
(
Sequence
<
IDim
,
IDims
...
>
{}
);
}
}
template
<
index_t
IDim
>
template
<
index_t
IDim
,
index_t
...
IDims
>
__host__
__device__
static
constexpr
auto
GetStride
(
Number
<
IDim
>
)
__host__
__device__
static
constexpr
auto
GetStride
s
(
Number
<
IDim
>
,
Number
<
IDims
>
...
)
{
{
return
mDimensions
.
At
(
Number
<
IDim
>
{}).
GetStride
(
);
return
GetStrides
(
Sequence
<
IDim
,
IDims
...
>
{}
);
}
}
__host__
__device__
static
constexpr
index_t
GetOffset
(
const
Index
&
idx
)
__host__
__device__
static
constexpr
auto
GetLengths
()
{
return
GetLengths
(
typename
arithmetic_sequence_gen
<
0
,
nDim
,
1
>::
type
{});
}
__host__
__device__
static
constexpr
auto
GetStrides
()
{
return
GetStrides
(
typename
arithmetic_sequence_gen
<
0
,
nDim
,
1
>::
type
{});
}
__host__
__device__
static
constexpr
index_t
CalculateOffset
(
const
Index
&
idx
)
{
{
index_t
offset
=
0
;
index_t
offset
=
0
;
...
@@ -67,7 +73,7 @@ struct NativeTensorDescriptor
...
@@ -67,7 +73,7 @@ struct NativeTensorDescriptor
return
offset
;
return
offset
;
}
}
__host__
__device__
static
constexpr
index_t
Get
OffsetDiff
(
const
Index
&
idx_diff
)
__host__
__device__
static
constexpr
index_t
Calculate
OffsetDiff
(
const
Index
&
idx_diff
)
{
{
index_t
offset_diff
=
0
;
index_t
offset_diff
=
0
;
...
@@ -161,8 +167,10 @@ struct TransformedTensorDescriptor
...
@@ -161,8 +167,10 @@ struct TransformedTensorDescriptor
// UpDimensionIds should include all up-dimensions
// UpDimensionIds should include all up-dimensions
// TODO: sanity check: while a up-dimension could be associated with multille
// TODO: sanity check: while a up-dimension could be associated with multille
// transformation,
// transformation, a low-dimension should be associated with only one transformation
// a low-dimension should be associated with only one transformation
// TODO: sanity-check: GetLowerLengths of each transform should be consistent with lengths
// of lower-tensor-descriptor
}
}
__host__
__device__
static
constexpr
auto
GetNumOfDimension
()
__host__
__device__
static
constexpr
auto
GetNumOfDimension
()
...
@@ -170,49 +178,78 @@ struct TransformedTensorDescriptor
...
@@ -170,49 +178,78 @@ struct TransformedTensorDescriptor
return
GetNumOfUpperDimension
();
return
GetNumOfUpperDimension
();
}
}
#if 0
__host__
__device__
static
constexpr
auto
GetLowerTensorDescriptor
()
__host__ __device__ static constexpr auto GetUpperLengths()
{
return
LowTensorDescriptor
{};
}
__host__
__device__
static
constexpr
auto
GetLowerLengths
()
{
{
struct lambda_get_upper_lengths
return
GetLowerTensorDescriptor
().
GetLengths
();
}
struct
lambda_GetUpperLengths
{
template
<
typename
Transform
>
__host__
__device__
constexpr
auto
operator
()(
const
Transform
&
tran
)
const
{
{
template <typename Transform>
return
tran
.
GetUpperLengths
();
__host__ __device__ constexpr auto operator()(Transform tran) const
}
{
};
return tran.GetUpperLengths();
}
};
constexpr auto tuple_of_upper_lengths =
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
transform_tuple(Transforms, lambda_get_upper_lengths{});
{
constexpr
auto
tuple_of_up_lengths
=
transform_tuple
(
lambda_GetUpperLengths
{},
Transforms
{});
constexpr auto
all_upper
_lengths =
merge_tuple_of
_sequences
(
tuple_of_up
per
_lengths);
constexpr
auto
mingled_up
_lengths
=
unpack
(
lambda_merge
_sequences
{},
tuple_of_up_lengths
);
constexpr auto all_upper_dimension_ids = merge_tuple_of_sequences(UpDimensionIds{});
constexpr
auto
mingled_up_dimension_ids
=
unpack
(
lambda_merge_sequences
{},
UpDimensionIds
{});
// TODO: sanity-check
all_upper
_dimension_ids contain all upper-dimensions
// TODO: sanity-check
mingled_up
_dimension_ids contain all upper-dimensions
// TODO: sanity-check
all_upper
_lengths have no conflicting upper-length
// TODO: sanity-check
mingled_up
_lengths have no conflicting upper-length
using sort_dimension_ids =
// sort by upper-dimension-ids
sequence_unique_sort<decltype(all_upper_dimension_ids), math::less<index_t>>;
using
sort_up_dimension_ids
=
sequence_unique_sort
<
decltype
(
mingled_up_dimension_ids
),
math
::
less
<
index_t
>
,
math
::
equal
<
index_t
>>
;
constexpr auto sorted_upper_dimension_ids = typename sort_dimension_ids::type;
// sanity-check sorted-upper-dimension-ids should be Sequence<0, 1, ... nDimUp-1>
constexpr auto sorted2unsorted_map = typename sort_dimension_ids::sorted2unsorted_map_type;
static_assert
(
is_same
<
typename
sort_up_dimension_ids
::
type
,
typename
arithmetic_sequence_gen
<
0
,
nDimUp
,
1
>::
type
>
{},
"wrong! UpDimensionIds is not configured correctly"
);
constexpr auto sorted_upper_lengths =
constexpr
auto
sorted2unsorted_map
=
typename
sort_up_dimension_ids
::
sorted2unsorted_map
{};
sequence_element_pick(all_upper_lengths, sorted2unsorted_map);
return sorted_upper_lengths;
constexpr
auto
sorted_up_lengths
=
pick_sequence_elements
(
mingled_up_lengths
,
sorted2unsorted_map
);
return
sorted_up_lengths
;
}
}
__host__
__device__
static
constexpr
auto
GetLengths
()
{
return
GetUpperLengths
();
}
__host__
__device__
static
constexpr
auto
GetLengths
()
{
return
GetUpperLengths
();
}
#endif
__host__
__device__
static
constexpr
auto
GetLowerTensorDescriptor
()
template
<
index_t
IDim
>
__host__
__device__
static
constexpr
auto
GetLength
(
Number
<
IDim
>
)
{
{
return
LowTensorDescriptor
{};
return
GetLengths
()[
IDim
];
}
template
<
index_t
...
IDims
>
__host__
__device__
static
constexpr
auto
GetLengths
(
Sequence
<
IDims
...
>
)
{
return
Sequence
<
GetLength
(
Number
<
IDims
>
{})...
>
{};
}
template
<
index_t
IDim
,
index_t
...
IDims
>
__host__
__device__
static
constexpr
auto
GetLengths
(
Number
<
IDim
>
,
Number
<
IDims
>
...)
{
return
GetLengths
(
Sequence
<
IDim
,
IDims
...
>
{});
}
}
__host__
__device__
static
constexpr
LowerIndex
GetLowerIndex
(
const
UpperIndex
&
idx_up
)
// TODO: right now return value is constexpr because use of non-constepxr lambda
__host__
__device__
static
constexpr
LowerIndex
CalculateLowerIndex
(
const
UpperIndex
&
idx_up
)
{
{
LowerIndex
idx_low
;
LowerIndex
idx_low
;
...
@@ -225,14 +262,15 @@ struct TransformedTensorDescriptor
...
@@ -225,14 +262,15 @@ struct TransformedTensorDescriptor
// this assume each lower (single) index is only assocaited with one transformation,
// this assume each lower (single) index is only assocaited with one transformation,
// which is required for index transformation, and has been checked during constructor
// which is required for index transformation, and has been checked during constructor
// of TransformedTensorDescriptor
// of TransformedTensorDescriptor
idx_low_part
=
tran
.
Get
LowerIndex
(
to_array
(
idx_up_part
));
idx_low_part
=
tran
.
Calculate
LowerIndex
(
to_array
(
idx_up_part
));
});
});
return
idx_low
;
return
idx_low
;
}
}
__host__
__device__
static
constexpr
LowerIndex
GetLowerIndexDiff
(
const
UpperIndex
&
idx_up_diff
,
// TODO: right now return value is constexpr because use of non-constepxr lambda
const
LowerIndex
&
idx_low_old
)
__host__
__device__
static
constexpr
LowerIndex
CalculateLowerIndexDiff
(
const
UpperIndex
&
idx_up_diff
,
const
LowerIndex
&
idx_low_old
)
{
{
LowerIndex
idx_low_diff
;
LowerIndex
idx_low_diff
;
...
@@ -250,15 +288,15 @@ struct TransformedTensorDescriptor
...
@@ -250,15 +288,15 @@ struct TransformedTensorDescriptor
// this assume each lower (single) index is associated with only one transformation,
// this assume each lower (single) index is associated with only one transformation,
// which is required for index transformation, and has been checked during constructor
// which is required for index transformation, and has been checked during constructor
// of TransformedTensorDescriptor
// of TransformedTensorDescriptor
idx_low_diff_part
=
tran
.
Get
LowerIndex
(
idx_up_diff_part
,
idx_low_old_part
);
idx_low_diff_part
=
tran
.
Calculate
LowerIndex
(
idx_up_diff_part
,
idx_low_old_part
);
});
});
return
idx_low_diff
;
return
idx_low_diff
;
}
}
__host__
__device__
static
constexpr
index_t
Get
Offset
(
const
UpperIndex
&
idx_up
)
__host__
__device__
static
constexpr
index_t
Calculate
Offset
(
const
UpperIndex
&
idx_up
)
{
{
return
GetLowerTensorDescriptor
().
GetOffset
(
Get
LowerIndex
(
idx_up
));
return
GetLowerTensorDescriptor
().
CalculateOffset
(
Calculate
LowerIndex
(
idx_up
));
}
}
#if 0
#if 0
...
@@ -286,14 +324,14 @@ struct TransformedTensorDescriptor
...
@@ -286,14 +324,14 @@ struct TransformedTensorDescriptor
};
};
template
<
index_t
...
Lengths
,
index_t
...
Strides
>
template
<
index_t
...
Lengths
,
index_t
...
Strides
>
__host__
__device__
constexpr
auto
make_
N
ative
T
ensor
D
escriptor
(
Sequence
<
Lengths
...
>
,
__host__
__device__
constexpr
auto
make_
n
ative
_t
ensor
_d
escriptor
(
Sequence
<
Lengths
...
>
,
Sequence
<
Strides
...
>
)
Sequence
<
Strides
...
>
)
{
{
return
NativeTensorDescriptor
<
NativeDimension
<
Lengths
,
Strides
>
...
>
{};
return
NativeTensorDescriptor
<
NativeDimension
<
Lengths
,
Strides
>
...
>
{};
}
}
template
<
typename
Lengths
>
template
<
typename
Lengths
>
__host__
__device__
constexpr
auto
make_
N
ative
T
ensor
D
escriptor_packed
(
Lengths
)
__host__
__device__
constexpr
auto
make_
n
ative
_t
ensor
_d
escriptor_packed
(
Lengths
)
{
{
constexpr
auto
strides
=
reverse_inclusive_scan_sequence
(
constexpr
auto
strides
=
reverse_inclusive_scan_sequence
(
Lengths
::
PopFront
(),
math
::
multiplies
<
index_t
>
{},
Number
<
1
>
{})
Lengths
::
PopFront
(),
math
::
multiplies
<
index_t
>
{},
Number
<
1
>
{})
...
...
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp
View file @
ca42e910
...
@@ -7,12 +7,19 @@
...
@@ -7,12 +7,19 @@
namespace
ck
{
namespace
ck
{
template
<
typename
...
NativeDimensions
>
template
<
typename
...
NativeDimensions
>
__host__
__device__
void
print_tensor_descriptor
(
const
char
*
s
,
__host__
__device__
void
NativeTensorDescriptor
<
NativeDimensions
...
>
desc
)
print_tensor_descriptor
(
const
char
*
s
,
const
NativeTensorDescriptor
<
NativeDimensions
...
>
&
desc
)
{
{
print_tensor_descriptor_impl
(
s
,
desc
.
GetLengths
(),
desc
.
GetStrides
());
print_tensor_descriptor_impl
(
s
,
desc
.
GetLengths
(),
desc
.
GetStrides
());
}
}
template
<
typename
...
Ts
>
__host__
__device__
void
print_tensor_descriptor
(
const
char
*
s
,
const
TransformedTensorDescriptor
<
Ts
...
>&
desc
)
{
print_tensor_descriptor_impl
(
s
,
desc
.
GetLengths
());
}
template
<
index_t
...
Lengths
,
index_t
...
Strides
>
template
<
index_t
...
Lengths
,
index_t
...
Strides
>
__host__
__device__
void
__host__
__device__
void
print_tensor_descriptor_impl
(
const
char
*
s
,
Sequence
<
Lengths
...
>
,
Sequence
<
Strides
...
>
)
print_tensor_descriptor_impl
(
const
char
*
s
,
Sequence
<
Lengths
...
>
,
Sequence
<
Strides
...
>
)
...
@@ -113,5 +120,53 @@ print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strid
...
@@ -113,5 +120,53 @@ print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strid
});
});
}
}
template
<
index_t
...
Lengths
>
__host__
__device__
void
print_tensor_descriptor_impl
(
const
char
*
s
,
Sequence
<
Lengths
...
>
)
{
constexpr
index_t
nDim
=
sizeof
...(
Lengths
);
static_assert
(
nDim
>
0
&&
nDim
<=
12
,
"wrong!"
);
static_if
<
nDim
==
1
>
{}([
&
](
auto
)
{
printf
(
"%s dim %u, lengths {%u}
\n
"
,
s
,
nDim
,
Lengths
...);
});
static_if
<
nDim
==
2
>
{}(
[
&
](
auto
)
{
printf
(
"%s dim %u, lengths {%u %u}
\n
"
,
s
,
nDim
,
Lengths
...);
});
static_if
<
nDim
==
3
>
{}(
[
&
](
auto
)
{
printf
(
"%s dim %u, lengths {%u %u %u}
\n
"
,
s
,
nDim
,
Lengths
...);
});
static_if
<
nDim
==
4
>
{}(
[
&
](
auto
)
{
printf
(
"%s dim %u, lengths {%u %u %u %u}
\n
"
,
s
,
nDim
,
Lengths
...);
});
static_if
<
nDim
==
5
>
{}(
[
&
](
auto
)
{
printf
(
"%s dim %u, lengths {%u %u %u %u %u}
\n
"
,
s
,
nDim
,
Lengths
...);
});
static_if
<
nDim
==
6
>
{}(
[
&
](
auto
)
{
printf
(
"%s dim %u, lengths {%u %u %u %u %u %u},
\n
"
,
s
,
nDim
,
Lengths
...);
});
static_if
<
nDim
==
7
>
{}(
[
&
](
auto
)
{
printf
(
"%s dim %u, lengths {%u %u %u %u %u %u %u}
\n
"
,
s
,
nDim
,
Lengths
...);
});
static_if
<
nDim
==
8
>
{}([
&
](
auto
)
{
printf
(
"%s dim %u, lengths {%u %u %u %u %u %u %u %u}
\n
"
,
s
,
nDim
,
Lengths
...);
});
static_if
<
nDim
==
9
>
{}([
&
](
auto
)
{
printf
(
"%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}
\n
"
,
s
,
nDim
,
Lengths
...);
});
static_if
<
nDim
==
10
>
{}([
&
](
auto
)
{
printf
(
"%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}
\n
"
,
s
,
nDim
,
Lengths
...);
});
static_if
<
nDim
==
11
>
{}([
&
](
auto
)
{
printf
(
"%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}
\n
"
,
s
,
nDim
,
Lengths
...);
});
static_if
<
nDim
==
12
>
{}([
&
](
auto
)
{
printf
(
"%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}
\n
"
,
s
,
nDim
,
Lengths
...);
});
}
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/utility/
A
rray.hpp
→
composable_kernel/include/utility/
a
rray.hpp
View file @
ca42e910
#ifndef CK_ARRAY_HPP
#ifndef CK_ARRAY_HPP
#define CK_ARRAY_HPP
#define CK_ARRAY_HPP
#include "
S
equence.hpp"
#include "
s
equence.hpp"
#include "functional2.hpp"
#include "functional2.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -17,7 +17,7 @@ struct Array
...
@@ -17,7 +17,7 @@ struct Array
__host__
__device__
explicit
constexpr
Array
()
{}
__host__
__device__
explicit
constexpr
Array
()
{}
template
<
typename
X
,
typename
...
Xs
>
template
<
typename
X
,
typename
...
Xs
>
__host__
__device__
explicit
constexpr
Array
(
X
x
,
Xs
...
xs
)
__host__
__device__
constexpr
Array
(
X
x
,
Xs
...
xs
)
:
mData
{
static_cast
<
TData
>
(
x
),
static_cast
<
TData
>
(
xs
)...}
:
mData
{
static_cast
<
TData
>
(
x
),
static_cast
<
TData
>
(
xs
)...}
{
{
static_assert
(
sizeof
...(
Xs
)
+
1
==
NSize
,
"wrong! size"
);
static_assert
(
sizeof
...(
Xs
)
+
1
==
NSize
,
"wrong! size"
);
...
@@ -176,7 +176,6 @@ __host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
...
@@ -176,7 +176,6 @@ __host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
return
ArrayElementPicker
<
Arr
,
Picks
>
(
a
);
return
ArrayElementPicker
<
Arr
,
Picks
>
(
a
);
}
}
#if 1
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
constexpr
auto
to_array
(
const
T
&
x
)
__host__
__device__
constexpr
auto
to_array
(
const
T
&
x
)
{
{
...
@@ -186,8 +185,8 @@ __host__ __device__ constexpr auto to_array(const T& x)
...
@@ -186,8 +185,8 @@ __host__ __device__ constexpr auto to_array(const T& x)
return
y
;
return
y
;
}
}
#endif
// TODO: remove this
template
<
index_t
...
Is
>
template
<
index_t
...
Is
>
__host__
__device__
constexpr
auto
sequence2array
(
Sequence
<
Is
...
>
)
__host__
__device__
constexpr
auto
sequence2array
(
Sequence
<
Is
...
>
)
{
{
...
...
composable_kernel/include/utility/array_helper.hpp
View file @
ca42e910
#ifndef CK_ARRAY_HELPER_HPP
#ifndef CK_ARRAY_HELPER_HPP
#define CK_ARRAY_HELPER_HPP
#define CK_ARRAY_HELPER_HPP
#include "
A
rray.hpp"
#include "
a
rray.hpp"
namespace
ck
{
namespace
ck
{
template
<
typename
T
,
index_t
NSize
>
template
<
typename
T
,
index_t
NSize
>
__host__
__device__
void
print_
A
rray
(
const
char
*
s
,
Array
<
T
,
NSize
>
a
)
__host__
__device__
void
print_
a
rray
(
const
char
*
s
,
Array
<
T
,
NSize
>
a
)
{
{
constexpr
index_t
nsize
=
a
.
GetSize
();
constexpr
index_t
nsize
=
a
.
GetSize
();
...
@@ -90,4 +90,4 @@ __host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
...
@@ -90,4 +90,4 @@ __host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
}
}
}
// namespace ck
}
// namespace ck
#endif
#endif
\ No newline at end of file
composable_kernel/include/utility/common_header.hpp
View file @
ca42e910
...
@@ -9,9 +9,9 @@
...
@@ -9,9 +9,9 @@
#include "tuple.hpp"
#include "tuple.hpp"
#include "math.hpp"
#include "math.hpp"
#include "vector_type.hpp"
#include "vector_type.hpp"
#include "
S
equence.hpp"
#include "
s
equence.hpp"
#include "sequence_helper.hpp"
#include "sequence_helper.hpp"
#include "
A
rray.hpp"
#include "
a
rray.hpp"
#include "array_helper.hpp"
#include "array_helper.hpp"
#include "functional.hpp"
#include "functional.hpp"
#include "functional2.hpp"
#include "functional2.hpp"
...
...
composable_kernel/include/utility/functional.hpp
View file @
ca42e910
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#define CK_FUNCTIONAL_HPP
#define CK_FUNCTIONAL_HPP
#include "integral_constant.hpp"
#include "integral_constant.hpp"
#include "
S
equence.hpp"
#include "
s
equence.hpp"
#include "type.hpp"
#include "type.hpp"
namespace
ck
{
namespace
ck
{
...
...
composable_kernel/include/utility/functional2.hpp
View file @
ca42e910
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#define CK_FUNCTIONAL2_HPP
#define CK_FUNCTIONAL2_HPP
#include "functional.hpp"
#include "functional.hpp"
#include "
S
equence.hpp"
#include "
s
equence.hpp"
namespace
ck
{
namespace
ck
{
...
...
composable_kernel/include/utility/functional3.hpp
View file @
ca42e910
...
@@ -3,8 +3,8 @@
...
@@ -3,8 +3,8 @@
#include "functional.hpp"
#include "functional.hpp"
#include "functional2.hpp"
#include "functional2.hpp"
#include "
S
equence.hpp"
#include "
s
equence.hpp"
#include "
A
rray.hpp"
#include "
a
rray.hpp"
namespace
ck
{
namespace
ck
{
...
...
composable_kernel/include/utility/functional4.hpp
View file @
ca42e910
#ifndef CK_FUNCTIONAL4_HPP
#ifndef CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP
#include "
S
equence.hpp"
#include "
s
equence.hpp"
#include "tuple.hpp"
#include "tuple.hpp"
#include "
A
rray.hpp"
#include "
a
rray.hpp"
namespace
ck
{
namespace
ck
{
...
...
composable_kernel/include/utility/math.hpp
View file @
ca42e910
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include "config.hpp"
#include "config.hpp"
#include "integral_constant.hpp"
#include "integral_constant.hpp"
#include "type.hpp"
namespace
ck
{
namespace
ck
{
namespace
math
{
namespace
math
{
...
...
composable_kernel/include/utility/
S
equence.hpp
→
composable_kernel/include/utility/
s
equence.hpp
View file @
ca42e910
...
@@ -2,7 +2,9 @@
...
@@ -2,7 +2,9 @@
#define CK_SEQUENCE_HPP
#define CK_SEQUENCE_HPP
#include "integral_constant.hpp"
#include "integral_constant.hpp"
#include "type.hpp"
#include "functional.hpp"
#include "functional.hpp"
#include "math.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -155,8 +157,8 @@ struct Sequence
...
@@ -155,8 +157,8 @@ struct Sequence
static_assert
(
I
<
Size
(),
"wrong!"
);
static_assert
(
I
<
Size
(),
"wrong!"
);
using
seq_split
=
sequence_split
<
Type
,
I
>
;
using
seq_split
=
sequence_split
<
Type
,
I
>
;
constexpr
auto
seq_left
=
typename
seq_split
::
SeqT
ype
0
{};
constexpr
auto
seq_left
=
typename
seq_split
::
left_t
ype
{};
constexpr
auto
seq_right
=
typename
seq_split
::
SeqT
ype
1
{}.
PopFront
();
constexpr
auto
seq_right
=
typename
seq_split
::
right_t
ype
{}.
PopFront
();
return
seq_left
.
PushBack
(
Number
<
X
>
{}).
PushBack
(
seq_right
);
return
seq_left
.
PushBack
(
Number
<
X
>
{}).
PushBack
(
seq_right
);
}
}
...
@@ -188,34 +190,34 @@ struct sequence_merge<Seq>
...
@@ -188,34 +190,34 @@ struct sequence_merge<Seq>
};
};
// generate sequence
// generate sequence
template
<
index_t
IBegin
,
index_t
NRemain
,
typename
F
>
template
<
index_t
NSize
,
typename
F
>
struct
sequence_gen
_impl
struct
sequence_gen
{
{
static
constexpr
index_t
NRemainLeft
=
NRemain
/
2
;
template
<
index_t
IBegin
,
index_t
NRemain
,
typename
G
>
static
constexpr
index_t
NRemainRight
=
NRemain
-
NRemainLeft
;
struct
sequence_gen_impl
static
constexpr
index_t
IMiddle
=
IBegin
+
NRemainLeft
;
{
static
constexpr
index_t
NRemainLeft
=
NRemain
/
2
;
static
constexpr
index_t
NRemainRight
=
NRemain
-
NRemainLeft
;
static
constexpr
index_t
IMiddle
=
IBegin
+
NRemainLeft
;
using
type
=
using
type
=
typename
sequence_merge
<
typename
sequence_merge
<
typename
sequence_gen_impl
<
IBegin
,
NRemainLeft
,
F
>::
type
,
typename
sequence_gen_impl
<
IBegin
,
NRemainLeft
,
G
>::
type
,
typename
sequence_gen_impl
<
IMiddle
,
NRemainRight
,
F
>::
type
>::
type
;
typename
sequence_gen_impl
<
IMiddle
,
NRemainRight
,
G
>::
type
>::
type
;
};
};
template
<
index_t
I
,
typename
F
>
template
<
index_t
I
,
typename
G
>
struct
sequence_gen_impl
<
I
,
1
,
F
>
struct
sequence_gen_impl
<
I
,
1
,
G
>
{
{
static
constexpr
index_t
Is
=
F
{}(
Number
<
I
>
{});
static
constexpr
index_t
Is
=
G
{}(
Number
<
I
>
{});
using
type
=
Sequence
<
Is
>
;
using
type
=
Sequence
<
Is
>
;
};
};
template
<
index_t
I
,
typename
F
>
template
<
index_t
I
,
typename
G
>
struct
sequence_gen_impl
<
I
,
0
,
F
>
struct
sequence_gen_impl
<
I
,
0
,
G
>
{
{
using
type
=
Sequence
<>
;
using
type
=
Sequence
<>
;
};
};
template
<
index_t
NSize
,
typename
F
>
struct
sequence_gen
{
using
type
=
typename
sequence_gen_impl
<
0
,
NSize
,
F
>::
type
;
using
type
=
typename
sequence_gen_impl
<
0
,
NSize
,
F
>::
type
;
};
};
...
@@ -281,8 +283,8 @@ struct sequence_split
...
@@ -281,8 +283,8 @@ struct sequence_split
using
range0
=
typename
arithmetic_sequence_gen
<
0
,
I
,
1
>::
type
;
using
range0
=
typename
arithmetic_sequence_gen
<
0
,
I
,
1
>::
type
;
using
range1
=
typename
arithmetic_sequence_gen
<
I
,
NSize
,
1
>::
type
;
using
range1
=
typename
arithmetic_sequence_gen
<
I
,
NSize
,
1
>::
type
;
using
SeqT
ype
0
=
decltype
(
Seq
::
Extract
(
range0
{}));
using
left_t
ype
=
decltype
(
Seq
::
Extract
(
range0
{}));
using
SeqT
ype
1
=
decltype
(
Seq
::
Extract
(
range1
{}));
using
right_t
ype
=
decltype
(
Seq
::
Extract
(
range1
{}));
};
};
// reverse sequence
// reverse sequence
...
@@ -293,8 +295,8 @@ struct sequence_reverse
...
@@ -293,8 +295,8 @@ struct sequence_reverse
using
seq_split
=
sequence_split
<
Seq
,
NSize
/
2
>
;
using
seq_split
=
sequence_split
<
Seq
,
NSize
/
2
>
;
using
type
=
typename
sequence_merge
<
using
type
=
typename
sequence_merge
<
typename
sequence_reverse
<
typename
seq_split
::
SeqT
ype
1
>::
type
,
typename
sequence_reverse
<
typename
seq_split
::
right_t
ype
>::
type
,
typename
sequence_reverse
<
typename
seq_split
::
SeqT
ype
0
>::
type
>::
type
;
typename
sequence_reverse
<
typename
seq_split
::
left_t
ype
>::
type
>::
type
;
};
};
template
<
index_t
I
>
template
<
index_t
I
>
...
@@ -309,138 +311,264 @@ struct sequence_reverse<Sequence<I0, I1>>
...
@@ -309,138 +311,264 @@ struct sequence_reverse<Sequence<I0, I1>>
using
type
=
Sequence
<
I1
,
I0
>
;
using
type
=
Sequence
<
I1
,
I0
>
;
};
};
template
<
typename
Seq
,
typename
Compare
>
template
<
typename
Values
,
typename
Ids
,
typename
Compare
>
struct
sequence_sort
struct
sequence_sort
_impl
{
{
template
<
typename
SeqLeft
,
typename
SeqRight
,
typename
MergedSeq
,
typename
Comp
>
template
<
typename
LeftValues
,
typename
LeftIds
,
typename
RightValues
,
typename
RightIds
,
typename
MergedValues
,
typename
MergedIds
,
typename
Comp
>
struct
sorted_sequence_merge_impl
struct
sorted_sequence_merge_impl
{
{
static
constexpr
bool
pick_left
=
SeqLeft
::
Front
()
<
SeqRight
::
Front
();
static
constexpr
bool
choose_left
=
LeftValues
::
Front
()
<
RightValues
::
Front
();
static
constexpr
index_t
next_value
=
pick_left
?
SeqLeft
::
Front
()
:
SeqRight
::
Front
();
static
constexpr
index_t
chosen_value
=
using
new_merged_seq
=
decltype
(
MergedSeq
::
PushBack
(
Number
<
next_value
>
{}));
choose_left
?
LeftValues
::
Front
()
:
RightValues
::
Front
();
static
constexpr
index_t
chosen_id
=
choose_left
?
LeftIds
::
Front
()
:
RightIds
::
Front
();
using
new_left_seq
=
typename
conditional
<
pick_left
,
decltype
(
SeqLeft
::
PopFront
()),
SeqLeft
>::
type
;
using
new_merged_values
=
decltype
(
MergedValues
::
PushBack
(
Number
<
chosen_value
>
{}));
using
new_right_seq
=
using
new_merged_ids
=
decltype
(
MergedIds
::
PushBack
(
Number
<
chosen_id
>
{}));
typename
conditional
<
pick_left
,
SeqRight
,
decltype
(
SeqRight
::
PopFront
())
>::
type
;
using
new_left_values
=
using
type
=
typename
conditional
<
choose_left
,
decltype
(
LeftValues
::
PopFront
()),
LeftValues
>::
type
;
typename
sorted_sequence_merge_impl
<
new_left_seq
,
new_right_seq
,
new_merged_seq
,
Comp
>::
using
new_left_ids
=
type
;
typename
conditional
<
choose_left
,
decltype
(
LeftIds
::
PopFront
()),
LeftIds
>::
type
;
using
new_right_values
=
typename
conditional
<
choose_left
,
RightValues
,
decltype
(
RightValues
::
PopFront
())
>::
type
;
using
new_right_ids
=
typename
conditional
<
choose_left
,
RightIds
,
decltype
(
RightIds
::
PopFront
())
>::
type
;
using
merge
=
sorted_sequence_merge_impl
<
new_left_values
,
new_left_ids
,
new_right_values
,
new_right_ids
,
new_merged_values
,
new_merged_ids
,
Comp
>
;
// this is output
using
merged_values
=
typename
merge
::
merged_values
;
using
merged_ids
=
typename
merge
::
merged_ids
;
};
};
template
<
typename
SeqLeft
,
typename
MergedSeq
,
typename
Comp
>
template
<
typename
LeftValues
,
struct
sorted_sequence_merge_impl
<
SeqLeft
,
Sequence
<>
,
MergedSeq
,
Comp
>
typename
LeftIds
,
typename
MergedValues
,
typename
MergedIds
,
typename
Comp
>
struct
sorted_sequence_merge_impl
<
LeftValues
,
LeftIds
,
Sequence
<>
,
Sequence
<>
,
MergedValues
,
MergedIds
,
Comp
>
{
{
using
type
=
typename
sequence_merge
<
MergedSeq
,
SeqLeft
>::
type
;
using
merged_values
=
typename
sequence_merge
<
MergedValues
,
LeftValues
>::
type
;
using
merged_ids
=
typename
sequence_merge
<
MergedIds
,
LeftIds
>::
type
;
};
};
template
<
typename
SeqRight
,
typename
MergedSeq
,
typename
Comp
>
template
<
typename
RightValues
,
struct
sorted_sequence_merge_impl
<
Sequence
<>
,
SeqRight
,
MergedSeq
,
Comp
>
typename
RightIds
,
typename
MergedValues
,
typename
MergedIds
,
typename
Comp
>
struct
sorted_sequence_merge_impl
<
Sequence
<>
,
Sequence
<>
,
RightValues
,
RightIds
,
MergedValues
,
MergedIds
,
Comp
>
{
{
using
type
=
typename
sequence_merge
<
MergedSeq
,
SeqRight
>::
type
;
using
merged_values
=
typename
sequence_merge
<
MergedValues
,
RightValues
>::
type
;
using
merged_ids
=
typename
sequence_merge
<
MergedIds
,
RightIds
>::
type
;
};
};
template
<
typename
Seq0
,
typename
Seq1
,
typename
Comp
>
template
<
typename
LeftValues
,
typename
LeftIds
,
typename
RightValues
,
typename
RightIds
,
typename
Comp
>
struct
sorted_sequence_merge
struct
sorted_sequence_merge
{
{
using
type
=
typename
sorted_sequence_merge_impl
<
Seq0
,
Seq1
,
Sequence
<>
,
Comp
>::
type
;
using
merge
=
sorted_sequence_merge_impl
<
LeftValues
,
LeftIds
,
RightValues
,
RightIds
,
Sequence
<>
,
Sequence
<>
,
Comp
>
;
using
merged_values
=
typename
merge
::
merged_values
;
using
merged_ids
=
typename
merge
::
merged_ids
;
};
};
using
split
=
sequence_split
<
Seq
,
Seq
::
Size
()
/
2
>
;
static
constexpr
index_t
nsize
=
Values
::
Size
();
using
unsorted_left
=
typename
split
::
SeqType0
;
using
unsorted_right
=
typename
split
::
SeqType1
;
using
split_unsorted_values
=
sequence_split
<
Values
,
nsize
/
2
>
;
using
split_unsorted_ids
=
sequence_split
<
Ids
,
nsize
/
2
>
;
using
sorted_left
=
typename
sequence_sort
<
unsorted_left
,
Compare
>::
type
;
using
left_unsorted_values
=
typename
split_unsorted_values
::
left_type
;
using
sorted_right
=
typename
sequence_sort
<
unsorted_right
,
Compare
>::
type
;
using
left_unsorted_ids
=
typename
split_unsorted_ids
::
left_type
;
using
left_sort
=
sequence_sort_impl
<
left_unsorted_values
,
left_unsorted_ids
,
Compare
>
;
using
left_sorted_values
=
typename
left_sort
::
sorted_values
;
using
left_sorted_ids
=
typename
left_sort
::
sorted_ids
;
using
type
=
typename
sorted_sequence_merge
<
sorted_left
,
sorted_right
,
Compare
>::
type
;
using
right_unsorted_values
=
typename
split_unsorted_values
::
right_type
;
using
right_unsorted_ids
=
typename
split_unsorted_ids
::
right_type
;
using
right_sort
=
sequence_sort_impl
<
right_unsorted_values
,
right_unsorted_ids
,
Compare
>
;
using
right_sorted_values
=
typename
right_sort
::
sorted_values
;
using
right_sorted_ids
=
typename
right_sort
::
sorted_ids
;
using
merged_sorted
=
sorted_sequence_merge
<
left_sorted_values
,
left_sorted_ids
,
right_sorted_values
,
right_sorted_ids
,
Compare
>
;
using
sorted_values
=
typename
merged_sorted
::
merged_values
;
using
sorted_ids
=
typename
merged_sorted
::
merged_ids
;
};
};
template
<
index_t
X
,
index_t
Y
,
typename
Compare
>
template
<
index_t
Value
X
,
index_t
ValueY
,
index_t
IdX
,
index_t
Id
Y
,
typename
Compare
>
struct
sequence_sort
<
Sequence
<
X
,
Y
>
,
Compare
>
struct
sequence_sort
_impl
<
Sequence
<
ValueX
,
ValueY
>
,
Sequence
<
IdX
,
Id
Y
>
,
Compare
>
{
{
static
constexpr
bool
x_first
=
Compare
{}(
X
,
Y
);
static
constexpr
bool
choose_x
=
Compare
{}(
ValueX
,
ValueY
);
using
sorted_values
=
typename
conditional
<
choose_x
,
Sequence
<
ValueX
,
ValueY
>
,
Sequence
<
ValueY
,
ValueX
>>::
type
;
using
sorted_ids
=
typename
conditional
<
choose_x
,
Sequence
<
IdX
,
IdY
>
,
Sequence
<
IdY
,
IdX
>>::
type
;
};
using
type
=
typename
conditional
<
x_first
,
Sequence
<
X
,
Y
>
,
Sequence
<
Y
,
X
>>::
type
;
// Recursion terminator for sequence_sort_impl: a sequence holding a single
// value is already sorted, so both outputs are the inputs unchanged.
//   sorted_values - the one-element value sequence, as given.
//   sorted_ids    - the one-element id sequence, as given.
// Compare is accepted only so the specialization matches the primary template;
// it is never invoked here.
template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
{
    using sorted_ids    = Sequence<Id>;
    using sorted_values = Sequence<Value>;
};
};
template
<
index_t
X
,
typename
Compare
>
template
<
typename
Values
,
typename
Compare
>
struct
sequence_sort
<
Sequence
<
X
>
,
Compare
>
struct
sequence_sort
{
{
using
type
=
Sequence
<
X
>
;
using
unsorted_ids
=
typename
arithmetic_sequence_gen
<
0
,
Values
::
Size
(),
1
>::
type
;
using
sort
=
sequence_sort_impl
<
Values
,
unsorted_ids
,
Compare
>
;
// this is output
using
type
=
typename
sort
::
sorted_values
;
using
sorted2unsorted_map
=
typename
sort
::
sorted_ids
;
};
};
template
<
typename
Seq
,
typename
Less
,
typename
Equal
>
template
<
typename
Values
,
typename
Less
,
typename
Equal
>
struct
sequence_unique_sort
struct
sequence_unique_sort
{
{
template
<
typename
WorkInputSeq
,
typename
WorkOutputSeq
,
typename
Eq
>
template
<
typename
RemainValues
,
typename
RemainIds
,
typename
UniquifiedValues
,
typename
UniquifiedIds
,
typename
Eq
>
struct
sorted_sequence_uniquify_impl
struct
sorted_sequence_uniquify_impl
{
{
static
constexpr
index_t
new_value
=
WorkInputSeq
::
Front
();
static
constexpr
index_t
current_value
=
RemainValues
::
Front
();
using
new_work_input_seq
=
decltype
(
WorkInputSeq
::
PopFront
());
static
constexpr
index_t
current_id
=
RemainIds
::
Front
();
static
constexpr
bool
is_unique_value
=
(
current_value
!=
UniquifiedValues
::
Back
());
using
new_remain_values
=
decltype
(
RemainValues
::
PopFront
());
using
new_remain_ids
=
decltype
(
RemainIds
::
PopFront
());
using
new_uniquified_values
=
typename
conditional
<
is_unique_value
,
decltype
(
UniquifiedValues
::
PushBack
(
Number
<
current_value
>
{})),
UniquifiedValues
>::
type
;
using
new_working_output_seq
=
using
new_uniquified_ids
=
typename
conditional
<
new_value
==
WorkOutputSeq
::
Back
(),
typename
conditional
<
is_unique_value
,
WorkOutputSeq
,
decltype
(
UniquifiedIds
::
PushBack
(
Number
<
current_id
>
{})),
decltype
(
WorkOutputSeq
::
PopBack
(
Number
<
new_value
>
{}))
>::
type
;
UniquifiedIds
>::
type
;
using
uniquify
=
sorted_sequence_uniquify_impl
<
new_remain_values
,
new_remain_ids
,
new_uniquified_values
,
new_uniquified_ids
,
Eq
>
;
// this is output
using
uniquified_values
=
typename
uniquify
::
uniquified_values
;
using
uniquified_ids
=
typename
uniquify
::
uniquified_ids
;
};
};
template
<
typename
WorkInputSeq
,
typename
Eq
>
template
<
typename
UniquifiedValues
,
typename
UniquifiedIds
,
typename
Eq
>
struct
sorted_sequence_uniquify_impl
<
WorkInputSeq
,
Sequence
<>
,
Eq
>
struct
sorted_sequence_uniquify_impl
<
Sequence
<>
,
Sequence
<>
,
UniquifiedValues
,
UniquifiedIds
,
Eq
>
{
{
using
type
=
WorkInputSeq
;
using
uniquified_values
=
UniquifiedValues
;
using
uniquified_ids
=
UniquifiedIds
;
};
};
template
<
typename
Sorted
Seq
,
typename
Eq
>
template
<
typename
Sorted
Values
,
typename
SortedIds
,
typename
Eq
>
struct
sorted_sequence_uniquify
struct
sorted_sequence_uniquify
{
{
using
type
=
typename
sorted_sequence_uniquify_impl
<
SortedSeq
,
Sequence
<>
,
Eq
>::
type
;
using
uniquify
=
sorted_sequence_uniquify_impl
<
decltype
(
SortedValues
::
PopFront
()),
decltype
(
SortedIds
::
PopFront
()),
Sequence
<
SortedValues
::
Front
()
>
,
Sequence
<
SortedIds
::
Front
()
>
,
Eq
>
;
using
uniquified_values
=
typename
uniquify
::
uniquified_values
;
using
uniquified_ids
=
typename
uniquify
::
uniquified_ids
;
};
};
using
sorted_seq
=
typename
sequence_sort
<
Seq
,
Less
>::
type
;
using
sort
=
sequence_sort
<
Values
,
Less
>
;
using
sorted_values
=
typename
sort
::
type
;
using
sorted_ids
=
typename
sort
::
sorted2unsorted_map
;
using
type
=
typename
sorted_sequence_uniquify
<
sorted_seq
,
Equal
>::
type
;
using
uniquify
=
sorted_sequence_uniquify
<
sorted_values
,
sorted_ids
,
Equal
>
;
// this is output
using
type
=
typename
uniquify
::
uniquified_values
;
using
sorted2unsorted_map
=
typename
uniquify
::
uniquified_ids
;
};
};
template
<
typename
Seq
>
template
<
typename
Seq
Map
>
struct
is_valid_sequence_map
struct
is_valid_sequence_map
{
{
// not implemented yet, always return true
static
constexpr
bool
value
=
static
constexpr
integral_constant
<
bool
,
true
>
value
=
integral_constant
<
bool
,
true
>
{};
is_same
<
typename
arithmetic_sequence_gen
<
0
,
SeqMap
::
Size
(),
1
>::
type
,
typename
sequence_sort
<
SeqMap
,
math
::
less
<
index_t
>>::
type
>
{};
// TODO: add proper check for is_valid, something like:
// static constexpr bool value =
// is_same<typename arithmetic_sequence_gen<0, Seq::Size(), 1>::type,
// typename sequence_sort<Seq>::SortedSeqType>{};
};
};
template
<
typename
X2Y
,
typename
WorkingY2X
,
index_t
XBegin
,
index_t
XRemain
>
template
<
typename
SeqMap
>
struct
sequence_map_inverse
_impl
struct
sequence_map_inverse
{
{
private:
template
<
typename
X2Y
,
typename
WorkingY2X
,
index_t
XBegin
,
index_t
XRemain
>
static
constexpr
auto
new_y2x
=
WorkingY2X
::
Modify
(
X2Y
::
At
(
Number
<
XBegin
>
{}),
Number
<
XBegin
>
{});
struct
sequence_map_inverse_impl
{
static
constexpr
auto
new_y2x
=
WorkingY2X
::
Modify
(
X2Y
::
At
(
Number
<
XBegin
>
{}),
Number
<
XBegin
>
{});
public:
using
type
=
using
type
=
typename
sequence_map_inverse_impl
<
X2Y
,
decltype
(
new_y2x
),
XBegin
+
1
,
XRemain
-
1
>::
typename
sequence_map_inverse_impl
<
X2Y
,
decltype
(
new_y2x
),
XBegin
+
1
,
XRemain
-
1
>::
type
;
type
;
};
};
template
<
typename
X2Y
,
typename
WorkingY2X
,
index_t
XBegin
>
template
<
typename
X2Y
,
typename
WorkingY2X
,
index_t
XBegin
>
struct
sequence_map_inverse_impl
<
X2Y
,
WorkingY2X
,
XBegin
,
0
>
struct
sequence_map_inverse_impl
<
X2Y
,
WorkingY2X
,
XBegin
,
0
>
{
{
using
type
=
WorkingY2X
;
using
type
=
WorkingY2X
;
};
};
template
<
typename
X2Y
>
struct
sequence_map_inverse
{
using
type
=
using
type
=
typename
sequence_map_inverse_impl
<
X2Y
,
typename
sequence_map_inverse_impl
<
SeqMap
,
typename
uniform_sequence_gen
<
X2Y
::
Size
(),
0
>::
type
,
typename
uniform_sequence_gen
<
SeqMap
::
Size
(),
0
>::
type
,
0
,
0
,
X2Y
::
Size
()
>::
type
;
SeqMap
::
Size
()
>::
type
;
};
};
template
<
index_t
...
Xs
,
index_t
...
Ys
>
template
<
index_t
...
Xs
,
index_t
...
Ys
>
...
@@ -601,6 +729,12 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I
...
@@ -601,6 +729,12 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I
return
reverse_inclusive_scan_sequence
(
Seq
{}.
Reverse
(),
Reduce
{},
Number
<
Init
>
{}).
Reverse
();
return
reverse_inclusive_scan_sequence
(
Seq
{}.
Reverse
(),
Reduce
{},
Number
<
Init
>
{}).
Reverse
();
}
}
// Gathers elements of Seq at the positions listed in the second argument.
// For every index I in Is..., the result holds Seq::At(Number<I>{}) at the same
// position, so the returned Sequence has sizeof...(Is) entries. Both arguments
// are used for their types only; their values are never read.
template <typename Seq, index_t... Is>
__host__ __device__ constexpr auto pick_sequence_elements(Seq, Sequence<Is...>)
{
    using picked_elements = Sequence<Seq::At(Number<Is>{})...>;

    return picked_elements{};
}
template
<
typename
Seq
,
typename
Reduce
>
template
<
typename
Seq
,
typename
Reduce
>
struct
lambda_accumulate_on_sequence
struct
lambda_accumulate_on_sequence
{
{
...
...
composable_kernel/include/utility/sequence_helper.hpp
View file @
ca42e910
#ifndef CK_SEQUENCE_HELPER_HPP
#ifndef CK_SEQUENCE_HELPER_HPP
#define CK_SEQUENCE_HELPER_HPP
#define CK_SEQUENCE_HELPER_HPP
#include "
S
equence.hpp"
#include "
s
equence.hpp"
namespace
ck
{
namespace
ck
{
template
<
index_t
...
Xs
>
template
<
index_t
...
Xs
>
__host__
__device__
void
print_
S
equence
(
const
char
*
s
,
Sequence
<
Xs
...
>
)
__host__
__device__
void
print_
s
equence
(
const
char
*
s
,
Sequence
<
Xs
...
>
)
{
{
constexpr
index_t
nsize
=
Sequence
<
Xs
...
>::
Size
();
constexpr
index_t
nsize
=
Sequence
<
Xs
...
>::
Size
();
...
...
composable_kernel/include/utility/tuple.hpp
View file @
ca42e910
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#include "integral_constant.hpp"
#include "integral_constant.hpp"
#include "type.hpp"
#include "type.hpp"
#include "
S
equence.hpp"
#include "
s
equence.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -114,19 +114,19 @@ __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
...
@@ -114,19 +114,19 @@ __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
namespace
detail
{
namespace
detail
{
template
<
typename
X
,
typename
F
,
index_t
...
Is
>
template
<
typename
F
,
typename
X
,
index_t
...
Is
>
__host__
__device__
constexpr
auto
trans
pose
_tuple_impl
(
X
&
x
,
F
f
,
Sequence
<
Is
...
>
)
__host__
__device__
constexpr
auto
trans
form
_tuple_impl
(
F
f
,
const
X
&
x
,
Sequence
<
Is
...
>
)
{
{
return
make_tuple
(
f
(
x
.
At
(
Number
<
Is
>
{}))...);
return
make_tuple
(
f
(
x
.
At
(
Number
<
Is
>
{}))...);
}
}
}
// namespace detail
}
// namespace detail
template
<
typename
X
,
typename
F
>
template
<
typename
F
,
typename
X
>
__host__
__device__
constexpr
auto
trans
pose
_tuple
(
X
&
x
,
F
f
)
__host__
__device__
constexpr
auto
trans
form
_tuple
(
F
f
,
const
X
&
x
)
{
{
return
detail
::
trans
pose
_tuple_impl
(
return
detail
::
trans
form
_tuple_impl
(
x
,
f
,
typename
arithmetic_sequence_gen
<
0
,
X
::
Size
(),
1
>::
type
{});
f
,
x
,
typename
arithmetic_sequence_gen
<
0
,
X
::
Size
(),
1
>::
type
{});
}
}
}
// namespace ck
}
// namespace ck
...
...
composable_kernel/include/utility/type.hpp
View file @
ca42e910
...
@@ -2,10 +2,12 @@
...
@@ -2,10 +2,12 @@
#define CK_TYPE_HPP
#define CK_TYPE_HPP
#include "integral_constant.hpp"
#include "integral_constant.hpp"
#include "Sequence.hpp"
namespace
ck
{
namespace
ck
{
template
<
index_t
...
Is
>
struct
Sequence
;
template
<
typename
X
,
typename
Y
>
template
<
typename
X
,
typename
Y
>
struct
is_same
:
public
integral_constant
<
bool
,
false
>
struct
is_same
:
public
integral_constant
<
bool
,
false
>
{
{
...
...
driver/src/driver.cpp
View file @
ca42e910
...
@@ -84,8 +84,8 @@ int main(int argc, char* argv[])
...
@@ -84,8 +84,8 @@ int main(int argc, char* argv[])
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
constexpr
index_t
HPad
=
2
;
constexpr
index_t
HPad
=
3
;
constexpr
index_t
WPad
=
2
;
constexpr
index_t
WPad
=
3
;
#elif 1
#elif 1
// 3x3, 34x34
// 3x3, 34x34
constexpr
index_t
N
=
64
;
constexpr
index_t
N
=
64
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment