Project: gaoqiong/composable_kernel_ROCM

Commit ca42e910, authored Sep 10, 2019 by Chao Liu
parent: 7a7fe160

    adding merge transform

Showing 18 changed files with 551 additions and 253 deletions.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp  (+44, -14)
composable_kernel/include/tensor_description/dimension.hpp  (+2, -2)
composable_kernel/include/tensor_description/multi_index_transform.hpp  (+76, -37)
composable_kernel/include/tensor_description/tensor_descriptor.hpp  (+101, -63)
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp  (+57, -2)
composable_kernel/include/utility/array.hpp  (+3, -4)
composable_kernel/include/utility/array_helper.hpp  (+3, -3)
composable_kernel/include/utility/common_header.hpp  (+2, -2)
composable_kernel/include/utility/functional.hpp  (+1, -1)
composable_kernel/include/utility/functional2.hpp  (+1, -1)
composable_kernel/include/utility/functional3.hpp  (+2, -2)
composable_kernel/include/utility/functional4.hpp  (+2, -2)
composable_kernel/include/utility/math.hpp  (+1, -0)
composable_kernel/include/utility/sequence.hpp  (+242, -108)
composable_kernel/include/utility/sequence_helper.hpp  (+2, -2)
composable_kernel/include/utility/tuple.hpp  (+7, -7)
composable_kernel/include/utility/type.hpp  (+3, -1)
driver/src/driver.cpp  (+2, -2)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp

@@ -528,34 +528,64 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
 #elif 1
         // create a native tensor descriptor
-        constexpr auto in_c_h_w_n_global_desc =
-            make_NativeTensorDescriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
+        constexpr auto in_c_h_w_n_global_desc =
+            make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());

         constexpr index_t C  = in_c_h_w_n_global_desc.GetLength(I0);
         constexpr index_t Hi = in_c_h_w_n_global_desc.GetLength(I1);
         constexpr index_t Wi = in_c_h_w_n_global_desc.GetLength(I2);
         constexpr index_t N  = in_c_h_w_n_global_desc.GetLength(I3);

-        constexpr auto pad_h_w = Pad<Sequence<Hi, Wi>, LowerPads, UpperPads>{};
-        constexpr auto pass_c  = PassThrough<C>{};
-        constexpr auto pass_n  = PassThrough<N>{};
-
-        constexpr auto trans = make_tuple(pass_c, pad_h_w, pass_n);
-
-        constexpr auto lower_dim_groups =
-            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{});
-        constexpr auto upper_dim_groups =
-            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{});
-
-        constexpr auto in_c_h_w_n_padded_global_desc = transform_tensor_descriptor(
-            in_c_h_w_n_global_desc, trans, lower_dim_groups, upper_dim_groups);
+        // transformation: {c, h, w, n} --> {n, c, hp, wp}
+        // {h, w} --> {hp, wp}, {c} --> {c}, {n} --> {n}
+        constexpr auto in_n_c_hp_wp_global_desc = transform_tensor_descriptor(
+            in_c_h_w_n_global_desc,
+            make_tuple(
+                Pad<Sequence<Hi, Wi>, LowerPads, UpperPads>{}, PassThrough<C>{}, PassThrough<N>{}),
+            make_tuple(Sequence<1, 2>{}, Sequence<0>{}, Sequence<3>{}),
+            make_tuple(Sequence<2, 3>{}, Sequence<1>{}, Sequence<0>{}));
+
+        // transformation: {n, c, hp, wp} --> {c, b}
+        // {n, hp, wp} --> {b}, {c} --> {c}
+        constexpr auto in_c_b_global_desc = transform_tensor_descriptor(
+            in_n_c_hp_wp_global_desc,
+            make_tuple(Merge<decltype(in_n_c_hp_wp_global_desc.GetLengths(I0, I2, I3))>{},
+                       PassThrough<in_n_c_hp_wp_global_desc.GetLength(I1)>{}),
+            make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}),
+            make_tuple(Sequence<1>{}, Sequence<0>{}));

 #if 1
         if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
         {
             // 0
             print_tensor_descriptor("in_c_h_w_n_global_desc", in_c_h_w_n_global_desc);
-            printf("offset: %lu\n", in_c_h_w_n_global_desc.GetOffset({1, 2, 3, 4}));
-            printf("padded offset: %lu\n",
-                   in_c_h_w_n_padded_global_desc.GetOffset({1, 4, 5, 4}));
+            // 1
+            print_tensor_descriptor("in_n_c_hp_wp_global_desc", in_n_c_hp_wp_global_desc);
+            // 2
+            print_tensor_descriptor("in_c_b_global_desc", in_c_b_global_desc);
+
+            constexpr auto idx2 = MultiIndex<2>{1, 4 * (16 * 16) + 5 * 16 + 6};
+
+            auto idx1 = in_c_b_global_desc.CalculateLowerIndex(idx2);
+            auto idx0 = in_c_b_global_desc.GetLowerTensorDescriptor().CalculateLowerIndex(idx1);
+
+            print_array("idx2: ", idx2);
+            print_array("idx1: ", idx1);
+            print_array("idx0: ", idx0);
+
+            printf("in_c_b_global_desc offset: %lu\n", in_c_b_global_desc.CalculateOffset(idx2));
         }
 #else
         {
             index_t c = static_cast<index_t>(threadIdx.x);
             index_t h = static_cast<index_t>(threadIdx.y);
             index_t w = static_cast<index_t>(threadIdx.z);

             p_out_global[0] = in_n_c_h_w_padded_global_desc.CalculateOffset({1, c, h, w});
         }
 #endif
 #endif
 }
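Note on the hunk above: the padded input descriptor is now built as a chain of two transforms, a Pad over {h, w} followed by a Merge that folds {n, hp, wp} into a single "b" dimension. The following is a minimal standalone sketch (plain C++17, with made-up lengths and pads; it is not the CK API) of what that composition computes for one point: the merged index is split back into {n, hp, wp} by div/mod, the pads are subtracted, and the result is dotted with the original {c, h, w, n} strides.

// sketch only: hypothetical lengths/strides, not taken from the commit
#include <array>
#include <cstdio>

int main()
{
    constexpr int Hi = 16, Wi = 16, N = 4;                          // assumed sizes
    constexpr std::array<int, 4> stride_chwn = {Hi * Wi * N, Wi * N, N, 1};
    constexpr int left_pad = 1;                                     // pad on h and w
    constexpr int Hp = Hi + 2 * left_pad, Wp = Wi + 2 * left_pad;

    // upper index {c, b}, where b merges {n, hp, wp}
    int c = 1, b = 2 * (Hp * Wp) + 5 * Wp + 6;

    // Merge: decompose b back into {n, hp, wp} by div/mod with "pseudo strides"
    int n  = b / (Hp * Wp);
    int hp = (b % (Hp * Wp)) / Wp;
    int wp = b % Wp;

    // Pad: map padded coordinates back to unpadded ones
    int h = hp - left_pad, w = wp - left_pad;

    // Native descriptor: the offset is a dot product with the strides
    int offset = c * stride_chwn[0] + h * stride_chwn[1] + w * stride_chwn[2] + n * stride_chwn[3];
    std::printf("n=%d hp=%d wp=%d -> offset=%d\n", n, hp, wp, offset);
}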
composable_kernel/include/tensor_description/dimension.hpp

@@ -18,9 +18,9 @@ struct NativeDimension
     __host__ __device__ static constexpr auto GetStride() { return Number<Stride>{}; }

-    __host__ __device__ static constexpr index_t GetOffset(index_t i) { return i * Stride; }
+    __host__ __device__ static constexpr index_t CalculateOffset(index_t i) { return i * Stride; }

-    __host__ __device__ static constexpr index_t GetOffsetDiff(index_t i_diff)
+    __host__ __device__ static constexpr index_t CalculateOffsetDiff(index_t i_diff)
     {
         return i_diff * Stride;
     }
composable_kernel/include/tensor_description/multi_index_transform.hpp

@@ -22,9 +22,12 @@ struct PassThrough
     __host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<Length>{}; }

-    __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) { return idx_up; }
+    __host__ __device__ static constexpr auto CalculateLowerIndex(UpperIndex idx_up) { return idx_up; }

-    __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
+    __host__ __device__ static constexpr auto CalculateLowerIndexDiff(UpperIndex idx_up_diff)
     {
         return idx_up_diff;
     }

@@ -36,7 +39,7 @@ struct PassThrough
 template <typename LowLengths, typename LeftPads, typename RightPads>
 struct Pad
 {
-    static constexpr index_t nDim = LowLengths::GetSize();
+    static constexpr index_t nDim = LowLengths::Size();

     using LowerIndex = MultiIndex<nDim>;
     using UpperIndex = MultiIndex<nDim>;

@@ -52,12 +55,12 @@ struct Pad
         return GetLowerLengths() + LeftPads{} + RightPads{};
     }

-    __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
+    __host__ __device__ static constexpr auto CalculateLowerIndex(UpperIndex idx_up)
     {
         return idx_up - LeftPads{};
     }

-    __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
+    __host__ __device__ static constexpr auto CalculateLowerIndexDiff(UpperIndex idx_up_diff)
     {
         return idx_up_diff;
     }

@@ -65,21 +68,20 @@ struct Pad
     __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
 };

-#if 0
 // LowLengths: Sequence<...>
 template <typename LowLengths>
 struct Merge
 {
-    static constexpr index_t nDimLow = LowLengths::GetSize();
+    static constexpr index_t nDimLow = LowLengths::Size();
     static constexpr index_t nDimUp  = 1;

     using LowerIndex = MultiIndex<nDimLow>;
     using UpperIndex = MultiIndex<nDimUp>;

-    __host__ __device__ static constexpr auto GetNumOfUpperDimension(){return Number<nDimUp>{}};
-
     __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
+
+    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }

     __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }

     __host__ __device__ static constexpr auto GetUpperLengths()

@@ -88,18 +90,56 @@ struct Merge
             GetLowerLengths(), math::multiplies<index_t>{}, Number<1>{})>{};
     }

-    __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
+    // emulate constexpr lambda
+    template <typename PseudoLowStrides>
+    struct lambda_CalculateLowerIndex
+    {
+        index_t& itmp;
+        LowerIndex& idx_low;
+
+        __host__ __device__ explicit constexpr lambda_CalculateLowerIndex(index_t& itmp_,
+                                                                          LowerIndex& idx_low_)
+            : itmp(itmp_), idx_low(idx_low_)
+        {
+        }
+
+        template <typename IDim>
+        __host__ __device__ constexpr void operator()(IDim idim) const
+        {
+            constexpr index_t stride = PseudoLowStrides::At(idim);
+            idx_low(idim) = itmp / stride;
+            itmp -= idx_low[idim] * stride;
+        }
+    };
+
+    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
     {
         LowerIndex idx_low;

-        // not implemeneted
+        index_t itmp = idx_up[0];
+
+        constexpr auto pseudo_low_strides =
+            reverse_inclusive_scan_sequence(
+                GetLowerLengths().PopFront(), math::multiplies<index_t>{}, Number<1>{})
+                .PushBack(Number<1>{});
+
+        // calculate index in each of the dimensions in the order of their dimension
+#if 1
+        static_for<0, nDimLow - 1, 1>{}(
+            lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));
+
+        idx_low(nDimLow - 1) = itmp / pseudo_low_strides[nDimLow - 1];
+#else
+        static_for<0, nDimLow, 1>{}(
+            lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));
+#endif

         return idx_low;
     }

     // idx_low_diff depends on idx_low_old, so idx_low need to be up-to-date
-    __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff,
-                                                                LowerIndex idx_low_old)
+    __host__ __device__ static constexpr auto CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
+                                                                      const LowerIndex& idx_low_old)
     {
         LowerIndex idx_low_diff;

@@ -110,49 +150,48 @@ struct Merge
     __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
 };
-#endif

 // UpLengths: Sequence<...>
-template <index_t LowLength, typename UpLengths>
+template <typename UpLengths>
 struct Unmerge
 {
     static constexpr index_t nDimLow = 1;
-    static constexpr index_t nDimUp  = UpLengths::GetSize();
+    static constexpr index_t nDimUp  = UpLengths::Size();

     using LowerIndex = MultiIndex<nDimLow>;
     using UpperIndex = MultiIndex<nDimUp>;

-    __host__ __device__ constexpr Unmerge()
-    {
-        static_assert(LowLength == accumulate_on_sequence(
-                                       UpLengths{}, math::multiplies<index_t>{}, Number<1>{}),
-                      "wrong! UpLengths need to be ");
-    }
+    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }

     __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }

-    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
-
-    __host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<LowLength>{}; }
+    __host__ __device__ static constexpr auto GetLowerLengths()
+    {
+        constexpr index_t low_length =
+            accumulate_on_sequence(UpLengths{}, math::multiplies<index_t>{}, Number<1>{});
+
+        return Sequence<low_length>{};
+    }

     __host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }

-    __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
+    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
     {
-        constexpr auto scans = typename sequence_reverse_inclusive_scan<
-            UpLengths, math::multiplies<index_t>, 1>::type{};
+        constexpr auto pseudo_up_strides = typename sequence_reverse_inclusive_scan<
+            UpLengths, math::multiplies<index_t>, 1>::type{};

         LowerIndex idx_low{0};

-        static_for<0, nDimUp, 1>{}([&](auto idim) { idx_low(0) += idx_up[idim] * scans[idim]; });
+        static_for<0, nDimUp, 1>{}(
+            [&](auto idim) { idx_low(0) += idx_up[idim] * pseudo_up_strides[idim]; });

         return idx_low;
     }

-    __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
+    __host__ __device__ static constexpr auto CalculateLowerIndexDiff(const UpperIndex& idx_up_diff)
     {
-        return GetLowerIndex(idx_up_diff);
+        return CalculateLowerIndex(idx_up_diff);
     }

     __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

@@ -165,12 +204,12 @@ template <index_t LowLength, typename UpLengths, typename Coefficients>
 struct Embed
 {
     static constexpr index_t nDimLow = 1;
-    static constexpr index_t nDimUp  = UpLengths::GetSize();
+    static constexpr index_t nDimUp  = UpLengths::Size();

     using LowerIndex = MultiIndex<nDimLow>;
     using UpperIndex = MultiIndex<nDimUp>;

-    __host__ __device__ constexpr Embed()
+    __host__ __device__ explicit constexpr Embed()
     {
         static_assert(UpLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
                       "wrong! # of dimensions not consistent");

@@ -191,7 +230,7 @@ struct Embed
     __host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }

-    __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
+    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
     {
         LowerIndex idx_low(Coefficients{}[nDimUp]);

@@ -201,7 +240,7 @@ struct Embed
         return idx_low;
     }

-    __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
+    __host__ __device__ static constexpr auto CalculateLowerIndexDiff(const UpperIndex& idx_up_diff)
     {
         LowerIndex idx_low_diff{0};
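The core arithmetic of the new Merge transform (and of Unmerge) is a reverse inclusive scan of the dimension lengths that plays the role of packed strides. A plain-C++ sketch of that arithmetic, using ordinary arrays and hypothetical helper names rather than the CK templates:

// sketch only: illustrates the index math, not the CK types
#include <array>
#include <cstdio>

template <std::size_t N>
constexpr std::array<int, N> pseudo_strides(const std::array<int, N>& lengths)
{
    std::array<int, N> s{};
    s[N - 1] = 1;
    for(int i = static_cast<int>(N) - 2; i >= 0; --i)
        s[i] = s[i + 1] * lengths[i + 1]; // reverse inclusive scan with multiplies
    return s;
}

// Merge: one flat upper index -> multi-dimensional lower index (div/mod peel-off)
template <std::size_t N>
constexpr std::array<int, N> merge_to_lower(int idx_up, const std::array<int, N>& lengths)
{
    const auto s = pseudo_strides(lengths);
    std::array<int, N> idx_low{};
    for(std::size_t i = 0; i < N; ++i)
    {
        idx_low[i] = idx_up / s[i];
        idx_up -= idx_low[i] * s[i];
    }
    return idx_low;
}

// Unmerge: multi-dimensional upper index -> one flat lower index (dot product)
template <std::size_t N>
constexpr int unmerge_to_lower(const std::array<int, N>& idx_up, const std::array<int, N>& lengths)
{
    const auto s = pseudo_strides(lengths);
    int idx_low = 0;
    for(std::size_t i = 0; i < N; ++i)
        idx_low += idx_up[i] * s[i];
    return idx_low;
}

int main()
{
    constexpr std::array<int, 3> lengths = {4, 18, 18};              // e.g. {n, hp, wp}
    constexpr auto multi = merge_to_lower(744, lengths);
    std::printf("{%d, %d, %d}\n", multi[0], multi[1], multi[2]);     // {2, 5, 6}
    std::printf("%d\n", unmerge_to_lower(multi, lengths));           // 744 again
}

The two directions are inverses of each other, which is why the descriptor code can freely insert a Merge above an Unmerge (or vice versa) without changing which memory element an index refers to.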
composable_kernel/include/tensor_description/tensor_descriptor.hpp

@@ -18,47 +18,53 @@ struct NativeTensorDescriptor
     __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }

-    struct lambda_GetLength
-    {
-        template <typename IDim>
-        __host__ __device__ constexpr auto operator()(IDim) const { return GetLength(IDim{}); }
-    };
-
-    __host__ __device__ static constexpr auto GetLengths()
-    {
-        return typename sequence_gen<nDim, lambda_GetLength>::type{};
-    }
-
-    struct lambda_GetStride
-    {
-        template <typename IDim>
-        __host__ __device__ constexpr auto operator()(IDim) const { return GetStride(IDim{}); }
-    };
-
-    __host__ __device__ static constexpr auto GetStrides()
-    {
-        return typename sequence_gen<nDim, lambda_GetStride>::type{};
-    }
-
     template <index_t IDim>
     __host__ __device__ static constexpr auto GetLength(Number<IDim>)
     {
         return mDimensions.At(Number<IDim>{}).GetLength();
     }

     template <index_t IDim>
     __host__ __device__ static constexpr auto GetStride(Number<IDim>)
     {
         return mDimensions.At(Number<IDim>{}).GetStride();
     }

+    template <index_t... IDims>
+    __host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
+    {
+        return Sequence<GetLength(Number<IDims>{})...>{};
+    }
+
+    template <index_t... IDims>
+    __host__ __device__ static constexpr auto GetStrides(Sequence<IDims...>)
+    {
+        return Sequence<GetStride(Number<IDims>{})...>{};
+    }
+
+    template <index_t IDim, index_t... IDims>
+    __host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
+    {
+        return GetLengths(Sequence<IDim, IDims...>{});
+    }
+
+    template <index_t IDim, index_t... IDims>
+    __host__ __device__ static constexpr auto GetStrides(Number<IDim>, Number<IDims>...)
+    {
+        return GetStrides(Sequence<IDim, IDims...>{});
+    }
+
+    __host__ __device__ static constexpr auto GetLengths()
+    {
+        return GetLengths(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
+    }
+
+    __host__ __device__ static constexpr auto GetStrides()
+    {
+        return GetStrides(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
+    }
+
-    __host__ __device__ static constexpr index_t GetOffset(const Index& idx)
+    __host__ __device__ static constexpr index_t CalculateOffset(const Index& idx)
     {
         index_t offset = 0;

@@ -67,7 +73,7 @@ struct NativeTensorDescriptor
         return offset;
     }

-    __host__ __device__ static constexpr index_t GetOffsetDiff(const Index& idx_diff)
+    __host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff)
     {
         index_t offset_diff = 0;

@@ -161,8 +167,10 @@ struct TransformedTensorDescriptor
         // UpDimensionIds should include all up-dimensions
         // TODO: sanity check: while a up-dimension could be associated with multille
-        // transformation,
-        // a low-dimension should be associated with only one transformation
+        // transformation, a low-dimension should be associated with only one transformation
+        // TODO: sanity-check: GetLowerLengths of each transform should be consistent with lengths
+        // of lower-tensor-descriptor
     }

     __host__ __device__ static constexpr auto GetNumOfDimension()

@@ -170,49 +178,78 @@ struct TransformedTensorDescriptor
         return GetNumOfUpperDimension();
     }

-#if 0
-    __host__ __device__ static constexpr auto GetUpperLengths()
-    {
-        struct lambda_get_upper_lengths
-        {
-            template <typename Transform>
-            __host__ __device__ constexpr auto operator()(Transform tran) const
-            {
-                return tran.GetUpperLengths();
-            }
-        };
-
-        constexpr auto tuple_of_upper_lengths =
-            transform_tuple(Transforms, lambda_get_upper_lengths{});
-
-        constexpr auto all_upper_lengths = merge_tuple_of_sequences(tuple_of_upper_lengths);
-
-        constexpr auto all_upper_dimension_ids = merge_tuple_of_sequences(UpDimensionIds{});
-
-        // TODO: sanity-check all_upper_dimension_ids contain all upper-dimensions
-        // TODO: sanity-check all_upper_lengths have no conflicting upper-length
-
-        using sort_dimension_ids =
-            sequence_unique_sort<decltype(all_upper_dimension_ids), math::less<index_t>>;
-
-        constexpr auto sorted_upper_dimension_ids = typename sort_dimension_ids::type;
-        // sanity-check sorted-upper-dimension-ids should be Sequence<0, 1, ... nDimUp-1>
-
-        constexpr auto sorted2unsorted_map = typename sort_dimension_ids::sorted2unsorted_map_type;
-
-        constexpr auto sorted_upper_lengths =
-            sequence_element_pick(all_upper_lengths, sorted2unsorted_map);
-
-        return sorted_upper_lengths;
-    }
-
-    __host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); }
-#endif
-
     __host__ __device__ static constexpr auto GetLowerTensorDescriptor()
     {
         return LowTensorDescriptor{};
     }

+    __host__ __device__ static constexpr auto GetLowerLengths()
+    {
+        return GetLowerTensorDescriptor().GetLengths();
+    }
+
+    struct lambda_GetUpperLengths
+    {
+        template <typename Transform>
+        __host__ __device__ constexpr auto operator()(const Transform& tran) const
+        {
+            return tran.GetUpperLengths();
+        }
+    };
+
+    __host__ __device__ static constexpr auto GetUpperLengths()
+    {
+        constexpr auto tuple_of_up_lengths =
+            transform_tuple(lambda_GetUpperLengths{}, Transforms{});
+
+        constexpr auto mingled_up_lengths = unpack(lambda_merge_sequences{}, tuple_of_up_lengths);
+
+        constexpr auto mingled_up_dimension_ids =
+            unpack(lambda_merge_sequences{}, UpDimensionIds{});
+
+        // TODO: sanity-check mingled_up_dimension_ids contain all upper-dimensions
+        // TODO: sanity-check mingled_up_lengths have no conflicting upper-length
+
+        // sort by upper-dimension-ids
+        using sort_up_dimension_ids = sequence_unique_sort<decltype(mingled_up_dimension_ids),
+                                                           math::less<index_t>,
+                                                           math::equal<index_t>>;
+
+        // sanity-check sorted-upper-dimension-ids should be Sequence<0, 1, ... nDimUp-1>
+        static_assert(is_same<typename sort_up_dimension_ids::type,
+                              typename arithmetic_sequence_gen<0, nDimUp, 1>::type>{},
+                      "wrong! UpDimensionIds is not configured correctly");
+
+        constexpr auto sorted2unsorted_map = typename sort_up_dimension_ids::sorted2unsorted_map{};
+
+        constexpr auto sorted_up_lengths =
+            pick_sequence_elements(mingled_up_lengths, sorted2unsorted_map);
+
+        return sorted_up_lengths;
+    }
+
+    __host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); }
+
+    template <index_t IDim>
+    __host__ __device__ static constexpr auto GetLength(Number<IDim>)
+    {
+        return GetLengths()[IDim];
+    }
+
+    template <index_t... IDims>
+    __host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
+    {
+        return Sequence<GetLength(Number<IDims>{})...>{};
+    }
+
+    template <index_t IDim, index_t... IDims>
+    __host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
+    {
+        return GetLengths(Sequence<IDim, IDims...>{});
+    }
+
-    __host__ __device__ static constexpr LowerIndex GetLowerIndex(const UpperIndex& idx_up)
+    // TODO: right now return value is constexpr because use of non-constepxr lambda
+    __host__ __device__ static constexpr LowerIndex CalculateLowerIndex(const UpperIndex& idx_up)
     {
         LowerIndex idx_low;

@@ -225,14 +262,15 @@ struct TransformedTensorDescriptor
             // this assume each lower (single) index is only assocaited with one transformation,
             // which is required for index transformation, and has been checked during constructor
             // of TransformedTensorDescriptor
-            idx_low_part = tran.GetLowerIndex(to_array(idx_up_part));
+            idx_low_part = tran.CalculateLowerIndex(to_array(idx_up_part));
         });

         return idx_low;
     }

-    __host__ __device__ static constexpr LowerIndex GetLowerIndexDiff(const UpperIndex& idx_up_diff,
-                                                                      const LowerIndex& idx_low_old)
+    // TODO: right now return value is constexpr because use of non-constepxr lambda
+    __host__ __device__ static constexpr LowerIndex
+    CalculateLowerIndexDiff(const UpperIndex& idx_up_diff, const LowerIndex& idx_low_old)
     {
         LowerIndex idx_low_diff;

@@ -250,15 +288,15 @@ struct TransformedTensorDescriptor
             // this assume each lower (single) index is associated with only one transformation,
             // which is required for index transformation, and has been checked during constructor
            // of TransformedTensorDescriptor
-            idx_low_diff_part = tran.GetLowerIndex(idx_up_diff_part, idx_low_old_part);
+            idx_low_diff_part = tran.CalculateLowerIndex(idx_up_diff_part, idx_low_old_part);
         });

         return idx_low_diff;
     }

-    __host__ __device__ static constexpr index_t GetOffset(const UpperIndex& idx_up)
+    __host__ __device__ static constexpr index_t CalculateOffset(const UpperIndex& idx_up)
     {
-        return GetLowerTensorDescriptor().GetOffset(GetLowerIndex(idx_up));
+        return GetLowerTensorDescriptor().CalculateOffset(CalculateLowerIndex(idx_up));
     }

 #if 0

@@ -286,14 +324,14 @@ struct TransformedTensorDescriptor
 };

 template <index_t... Lengths, index_t... Strides>
-__host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence<Lengths...>,
-                                                               Sequence<Strides...>)
+__host__ __device__ constexpr auto make_native_tensor_descriptor(Sequence<Lengths...>,
+                                                                 Sequence<Strides...>)
 {
     return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{};
 }

 template <typename Lengths>
-__host__ __device__ constexpr auto make_NativeTensorDescriptor_packed(Lengths)
+__host__ __device__ constexpr auto make_native_tensor_descriptor_packed(Lengths)
 {
     constexpr auto strides = reverse_inclusive_scan_sequence(
         Lengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
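The renamed CalculateOffset above makes the layering explicit: a transformed descriptor pushes an upper index through its transforms to a lower index, then asks the lower descriptor for the offset, and this recurses until a native (strided) descriptor is reached. A simplified sketch of that composition, with hypothetical 2-D types rather than the CK class templates:

// sketch only: shows the Calculate* chaining pattern, not the real descriptors
#include <array>
#include <cstdio>

struct NativeDesc2D // lowest level: plain strides
{
    std::array<int, 2> strides;
    constexpr int CalculateOffset(const std::array<int, 2>& idx) const
    {
        return idx[0] * strides[0] + idx[1] * strides[1];
    }
};

struct PadTransform2D // upper {hp, wp} -> lower {h, w}
{
    std::array<int, 2> left_pads;
    constexpr std::array<int, 2> CalculateLowerIndex(const std::array<int, 2>& idx_up) const
    {
        return {idx_up[0] - left_pads[0], idx_up[1] - left_pads[1]};
    }
};

struct PaddedDesc2D // transformed descriptor: one transform on top of a lower descriptor
{
    PadTransform2D transform;
    NativeDesc2D lower;
    constexpr int CalculateOffset(const std::array<int, 2>& idx_up) const
    {
        return lower.CalculateOffset(transform.CalculateLowerIndex(idx_up));
    }
};

int main()
{
    constexpr PaddedDesc2D desc{{{1, 1}}, {{34, 1}}};
    std::printf("%d\n", desc.CalculateOffset({5, 7})); // (5-1)*34 + (7-1)*1 = 142
}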
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp

@@ -7,12 +7,19 @@
 namespace ck {

 template <typename... NativeDimensions>
-__host__ __device__ void print_tensor_descriptor(const char* s,
-                                                 NativeTensorDescriptor<NativeDimensions...> desc)
+__host__ __device__ void
+print_tensor_descriptor(const char* s, const NativeTensorDescriptor<NativeDimensions...>& desc)
 {
     print_tensor_descriptor_impl(s, desc.GetLengths(), desc.GetStrides());
 }

+template <typename... Ts>
+__host__ __device__ void
+print_tensor_descriptor(const char* s, const TransformedTensorDescriptor<Ts...>& desc)
+{
+    print_tensor_descriptor_impl(s, desc.GetLengths());
+}
+
 template <index_t... Lengths, index_t... Strides>
 __host__ __device__ void
 print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strides...>)

@@ -113,5 +120,53 @@ print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strid
     });
 }

+template <index_t... Lengths>
+__host__ __device__ void print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>)
+{
+    constexpr index_t nDim = sizeof...(Lengths);
+
+    static_assert(nDim > 0 && nDim <= 12, "wrong!");
+
+    static_if<nDim == 1>{}([&](auto) { printf("%s dim %u, lengths {%u}\n", s, nDim, Lengths...); });
+
+    static_if<nDim == 2>{}(
+        [&](auto) { printf("%s dim %u, lengths {%u %u}\n", s, nDim, Lengths...); });
+
+    static_if<nDim == 3>{}(
+        [&](auto) { printf("%s dim %u, lengths {%u %u %u}\n", s, nDim, Lengths...); });
+
+    static_if<nDim == 4>{}(
+        [&](auto) { printf("%s dim %u, lengths {%u %u %u %u}\n", s, nDim, Lengths...); });
+
+    static_if<nDim == 5>{}(
+        [&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u}\n", s, nDim, Lengths...); });
+
+    static_if<nDim == 6>{}(
+        [&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u},\n", s, nDim, Lengths...); });
+
+    static_if<nDim == 7>{}(
+        [&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u %u}\n", s, nDim, Lengths...); });
+
+    static_if<nDim == 8>{}([&](auto) {
+        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
+    });
+
+    static_if<nDim == 9>{}([&](auto) {
+        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
+    });
+
+    static_if<nDim == 10>{}([&](auto) {
+        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
+    });
+
+    static_if<nDim == 11>{}([&](auto) {
+        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
+    });
+
+    static_if<nDim == 12>{}([&](auto) {
+        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
+    });
+}
+
 } // namespace ck

 #endif
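The new lengths-only overload dispatches on rank with a static_if chain because printf needs a distinct literal format string per rank. As a side note on the design choice, in an ordinary C++17 host-only context the same information could be printed with a fold expression and no per-rank branches; this is only a hedged sketch of that alternative, not how the device-side helper works:

// sketch only: host-side C++17 alternative, assuming fold expressions are available
#include <cstdio>

template <int... Lengths>
void print_lengths(const char* s)
{
    std::printf("%s dim %d, lengths {", s, static_cast<int>(sizeof...(Lengths)));
    ((std::printf("%d ", Lengths)), ...); // fold over the length pack
    std::printf("}\n");
}

int main() { print_lengths<64, 8, 34, 34>("example_desc"); }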
composable_kernel/include/utility/Array.hpp → composable_kernel/include/utility/array.hpp

 #ifndef CK_ARRAY_HPP
 #define CK_ARRAY_HPP

-#include "Sequence.hpp"
+#include "sequence.hpp"
 #include "functional2.hpp"

 namespace ck {

@@ -17,7 +17,7 @@ struct Array
     __host__ __device__ explicit constexpr Array() {}

     template <typename X, typename... Xs>
-    __host__ __device__ explicit constexpr Array(X x, Xs... xs)
+    __host__ __device__ constexpr Array(X x, Xs... xs)
         : mData{static_cast<TData>(x), static_cast<TData>(xs)...}
     {
         static_assert(sizeof...(Xs) + 1 == NSize, "wrong! size");

@@ -176,7 +176,6 @@ __host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
     return ArrayElementPicker<Arr, Picks>(a);
 }

-#if 1
 template <typename T>
 __host__ __device__ constexpr auto to_array(const T& x)
 {

@@ -186,8 +185,8 @@ __host__ __device__ constexpr auto to_array(const T& x)
     return y;
 }

-#endif
-
+// TODO: remove this
 template <index_t... Is>
 __host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
 {
composable_kernel/include/utility/array_helper.hpp

 #ifndef CK_ARRAY_HELPER_HPP
 #define CK_ARRAY_HELPER_HPP

-#include "Array.hpp"
+#include "array.hpp"

 namespace ck {

 template <typename T, index_t NSize>
-__host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
+__host__ __device__ void print_array(const char* s, Array<T, NSize> a)
 {
     constexpr index_t nsize = a.GetSize();

@@ -90,4 +90,4 @@ __host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
 }

 } // namespace ck

 #endif
\ No newline at end of file
composable_kernel/include/utility/common_header.hpp

@@ -9,9 +9,9 @@
 #include "tuple.hpp"
 #include "math.hpp"
 #include "vector_type.hpp"
-#include "Sequence.hpp"
+#include "sequence.hpp"
 #include "sequence_helper.hpp"
-#include "Array.hpp"
+#include "array.hpp"
 #include "array_helper.hpp"
 #include "functional.hpp"
 #include "functional2.hpp"
composable_kernel/include/utility/functional.hpp

@@ -2,7 +2,7 @@
 #define CK_FUNCTIONAL_HPP

 #include "integral_constant.hpp"
-#include "Sequence.hpp"
+#include "sequence.hpp"
 #include "type.hpp"

 namespace ck {
composable_kernel/include/utility/functional2.hpp

@@ -2,7 +2,7 @@
 #define CK_FUNCTIONAL2_HPP

 #include "functional.hpp"
-#include "Sequence.hpp"
+#include "sequence.hpp"

 namespace ck {
composable_kernel/include/utility/functional3.hpp

@@ -3,8 +3,8 @@
 #include "functional.hpp"
 #include "functional2.hpp"
-#include "Sequence.hpp"
-#include "Array.hpp"
+#include "sequence.hpp"
+#include "array.hpp"

 namespace ck {
composable_kernel/include/utility/functional4.hpp

 #ifndef CK_FUNCTIONAL4_HPP
 #define CK_FUNCTIONAL4_HPP

-#include "Sequence.hpp"
+#include "sequence.hpp"
 #include "tuple.hpp"
-#include "Array.hpp"
+#include "array.hpp"

 namespace ck {
composable_kernel/include/utility/math.hpp

@@ -3,6 +3,7 @@
 #include "config.hpp"
 #include "integral_constant.hpp"
+#include "type.hpp"

 namespace ck {
 namespace math {
composable_kernel/include/utility/Sequence.hpp → composable_kernel/include/utility/sequence.hpp

@@ -2,7 +2,9 @@
 #define CK_SEQUENCE_HPP

 #include "integral_constant.hpp"
+#include "type.hpp"
 #include "functional.hpp"
+#include "math.hpp"

 namespace ck {

@@ -155,8 +157,8 @@ struct Sequence
         static_assert(I < Size(), "wrong!");

         using seq_split = sequence_split<Type, I>;
-        constexpr auto seq_left  = typename seq_split::SeqType0{};
-        constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront();
+        constexpr auto seq_left  = typename seq_split::left_type{};
+        constexpr auto seq_right = typename seq_split::right_type{}.PopFront();

         return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
     }

@@ -188,34 +190,34 @@ struct sequence_merge<Seq>
 };

 // generate sequence
-template <index_t IBegin, index_t NRemain, typename F>
-struct sequence_gen_impl
+template <index_t NSize, typename F>
+struct sequence_gen
 {
-    static constexpr index_t NRemainLeft  = NRemain / 2;
-    static constexpr index_t NRemainRight = NRemain - NRemainLeft;
-    static constexpr index_t IMiddle      = IBegin + NRemainLeft;
-
-    using type =
-        typename sequence_merge<typename sequence_gen_impl<IBegin, NRemainLeft, F>::type,
-                                typename sequence_gen_impl<IMiddle, NRemainRight, F>::type>::type;
-};
-
-template <index_t I, typename F>
-struct sequence_gen_impl<I, 1, F>
-{
-    static constexpr index_t Is = F{}(Number<I>{});
-    using type                  = Sequence<Is>;
-};
-
-template <index_t I, typename F>
-struct sequence_gen_impl<I, 0, F>
-{
-    using type = Sequence<>;
-};
-
-template <index_t NSize, typename F>
-struct sequence_gen
-{
+    template <index_t IBegin, index_t NRemain, typename G>
+    struct sequence_gen_impl
+    {
+        static constexpr index_t NRemainLeft  = NRemain / 2;
+        static constexpr index_t NRemainRight = NRemain - NRemainLeft;
+        static constexpr index_t IMiddle      = IBegin + NRemainLeft;
+
+        using type = typename sequence_merge<
+            typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
+            typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
+    };
+
+    template <index_t I, typename G>
+    struct sequence_gen_impl<I, 1, G>
+    {
+        static constexpr index_t Is = G{}(Number<I>{});
+        using type                  = Sequence<Is>;
+    };
+
+    template <index_t I, typename G>
+    struct sequence_gen_impl<I, 0, G>
+    {
+        using type = Sequence<>;
+    };
+
     using type = typename sequence_gen_impl<0, NSize, F>::type;
 };

@@ -281,8 +283,8 @@ struct sequence_split
     using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
     using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;

-    using SeqType0 = decltype(Seq::Extract(range0{}));
-    using SeqType1 = decltype(Seq::Extract(range1{}));
+    using left_type  = decltype(Seq::Extract(range0{}));
+    using right_type = decltype(Seq::Extract(range1{}));
 };

 // reverse sequence

@@ -293,8 +295,8 @@ struct sequence_reverse
     using seq_split = sequence_split<Seq, NSize / 2>;

     using type = typename sequence_merge<
-        typename sequence_reverse<typename seq_split::SeqType1>::type,
-        typename sequence_reverse<typename seq_split::SeqType0>::type>::type;
+        typename sequence_reverse<typename seq_split::right_type>::type,
+        typename sequence_reverse<typename seq_split::left_type>::type>::type;
 };

 template <index_t I>

@@ -309,138 +311,264 @@ struct sequence_reverse<Sequence<I0, I1>>
     using type = Sequence<I1, I0>;
 };

-template <typename Seq, typename Compare>
-struct sequence_sort
+template <typename Values, typename Ids, typename Compare>
+struct sequence_sort_impl
 {
-    template <typename SeqLeft, typename SeqRight, typename MergedSeq, typename Comp>
+    template <typename LeftValues,
+              typename LeftIds,
+              typename RightValues,
+              typename RightIds,
+              typename MergedValues,
+              typename MergedIds,
+              typename Comp>
     struct sorted_sequence_merge_impl
     {
-        static constexpr bool pick_left = SeqLeft::Front() < SeqRight::Front();
-
-        static constexpr index_t next_value = pick_left ? SeqLeft::Front() : SeqRight::Front();
-
-        using new_merged_seq = decltype(MergedSeq::PushBack(Number<next_value>{}));
-
-        using new_left_seq =
-            typename conditional<pick_left, decltype(SeqLeft::PopFront()), SeqLeft>::type;
-        using new_right_seq =
-            typename conditional<pick_left, SeqRight, decltype(SeqRight::PopFront())>::type;
-
-        using type = typename sorted_sequence_merge_impl<new_left_seq,
-                                                         new_right_seq,
-                                                         new_merged_seq,
-                                                         Comp>::type;
+        static constexpr bool choose_left = LeftValues::Front() < RightValues::Front();
+
+        static constexpr index_t chosen_value =
+            choose_left ? LeftValues::Front() : RightValues::Front();
+        static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front();
+
+        using new_merged_values = decltype(MergedValues::PushBack(Number<chosen_value>{}));
+        using new_merged_ids    = decltype(MergedIds::PushBack(Number<chosen_id>{}));
+
+        using new_left_values =
+            typename conditional<choose_left, decltype(LeftValues::PopFront()), LeftValues>::type;
+        using new_left_ids =
+            typename conditional<choose_left, decltype(LeftIds::PopFront()), LeftIds>::type;
+
+        using new_right_values =
+            typename conditional<choose_left, RightValues, decltype(RightValues::PopFront())>::type;
+        using new_right_ids =
+            typename conditional<choose_left, RightIds, decltype(RightIds::PopFront())>::type;
+
+        using merge = sorted_sequence_merge_impl<new_left_values,
+                                                 new_left_ids,
+                                                 new_right_values,
+                                                 new_right_ids,
+                                                 new_merged_values,
+                                                 new_merged_ids,
+                                                 Comp>;
+        // this is output
+        using merged_values = typename merge::merged_values;
+        using merged_ids    = typename merge::merged_ids;
     };

-    template <typename SeqLeft, typename MergedSeq, typename Comp>
-    struct sorted_sequence_merge_impl<SeqLeft, Sequence<>, MergedSeq, Comp>
+    template <typename LeftValues,
+              typename LeftIds,
+              typename MergedValues,
+              typename MergedIds,
+              typename Comp>
+    struct sorted_sequence_merge_impl<LeftValues, LeftIds, Sequence<>, Sequence<>, MergedValues, MergedIds, Comp>
     {
-        using type = typename sequence_merge<MergedSeq, SeqLeft>::type;
+        using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
+        using merged_ids    = typename sequence_merge<MergedIds, LeftIds>::type;
     };

-    template <typename SeqRight, typename MergedSeq, typename Comp>
-    struct sorted_sequence_merge_impl<Sequence<>, SeqRight, MergedSeq, Comp>
+    template <typename RightValues,
+              typename RightIds,
+              typename MergedValues,
+              typename MergedIds,
+              typename Comp>
+    struct sorted_sequence_merge_impl<Sequence<>, Sequence<>, RightValues, RightIds, MergedValues, MergedIds, Comp>
    {
-        using type = typename sequence_merge<MergedSeq, SeqRight>::type;
+        using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
+        using merged_ids    = typename sequence_merge<MergedIds, RightIds>::type;
     };

-    template <typename Seq0, typename Seq1, typename Comp>
+    template <typename LeftValues, typename LeftIds, typename RightValues, typename RightIds, typename Comp>
     struct sorted_sequence_merge
     {
-        using type = typename sorted_sequence_merge_impl<Seq0, Seq1, Sequence<>, Comp>::type;
+        using merge = sorted_sequence_merge_impl<LeftValues,
+                                                 LeftIds,
+                                                 RightValues,
+                                                 RightIds,
+                                                 Sequence<>,
+                                                 Sequence<>,
+                                                 Comp>;
+
+        using merged_values = typename merge::merged_values;
+        using merged_ids    = typename merge::merged_ids;
     };

-    using split = sequence_split<Seq, Seq::Size() / 2>;
-
-    using unsorted_left  = typename split::SeqType0;
-    using unsorted_right = typename split::SeqType1;
-
-    using sorted_left  = typename sequence_sort<unsorted_left, Compare>::type;
-    using sorted_right = typename sequence_sort<unsorted_right, Compare>::type;
-
-    using type = typename sorted_sequence_merge<sorted_left, sorted_right, Compare>::type;
+    static constexpr index_t nsize = Values::Size();
+
+    using split_unsorted_values = sequence_split<Values, nsize / 2>;
+    using split_unsorted_ids    = sequence_split<Ids, nsize / 2>;
+
+    using left_unsorted_values = typename split_unsorted_values::left_type;
+    using left_unsorted_ids    = typename split_unsorted_ids::left_type;
+    using left_sort            = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
+    using left_sorted_values   = typename left_sort::sorted_values;
+    using left_sorted_ids      = typename left_sort::sorted_ids;
+
+    using right_unsorted_values = typename split_unsorted_values::right_type;
+    using right_unsorted_ids    = typename split_unsorted_ids::right_type;
+    using right_sort            = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
+    using right_sorted_values   = typename right_sort::sorted_values;
+    using right_sorted_ids      = typename right_sort::sorted_ids;
+
+    using merged_sorted = sorted_sequence_merge<left_sorted_values,
+                                                left_sorted_ids,
+                                                right_sorted_values,
+                                                right_sorted_ids,
+                                                Compare>;
+
+    using sorted_values = typename merged_sorted::merged_values;
+    using sorted_ids    = typename merged_sorted::merged_ids;
 };

-template <index_t X, index_t Y, typename Compare>
-struct sequence_sort<Sequence<X, Y>, Compare>
+template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
+struct sequence_sort_impl<Sequence<ValueX, ValueY>, Sequence<IdX, IdY>, Compare>
 {
-    static constexpr bool x_first = Compare{}(X, Y);
-
-    using type = typename conditional<x_first, Sequence<X, Y>, Sequence<Y, X>>::type;
+    static constexpr bool choose_x = Compare{}(ValueX, ValueY);
+
+    using sorted_values =
+        typename conditional<choose_x, Sequence<ValueX, ValueY>, Sequence<ValueY, ValueX>>::type;
+    using sorted_ids = typename conditional<choose_x, Sequence<IdX, IdY>, Sequence<IdY, IdX>>::type;
 };

-template <index_t X, typename Compare>
-struct sequence_sort<Sequence<X>, Compare>
+template <index_t Value, index_t Id, typename Compare>
+struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
 {
-    using type = Sequence<X>;
+    using sorted_values = Sequence<Value>;
+    using sorted_ids    = Sequence<Id>;
 };

+template <typename Values, typename Compare>
+struct sequence_sort
+{
+    using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type;
+    using sort         = sequence_sort_impl<Values, unsorted_ids, Compare>;
+
+    // this is output
+    using type                = typename sort::sorted_values;
+    using sorted2unsorted_map = typename sort::sorted_ids;
+};
+
-template <typename Seq, typename Less, typename Equal>
+template <typename Values, typename Less, typename Equal>
 struct sequence_unique_sort
 {
-    template <typename WorkInputSeq, typename WorkOutputSeq, typename Eq>
+    template <typename RemainValues,
+              typename RemainIds,
+              typename UniquifiedValues,
+              typename UniquifiedIds,
+              typename Eq>
     struct sorted_sequence_uniquify_impl
     {
-        static constexpr index_t new_value = WorkInputSeq::Front();
-
-        using new_work_input_seq = decltype(WorkInputSeq::PopFront());
-
-        using new_working_output_seq =
-            typename conditional<new_value == WorkOutputSeq::Back(),
-                                 WorkOutputSeq,
-                                 decltype(WorkOutputSeq::PopBack(Number<new_value>{}))>::type;
+        static constexpr index_t current_value = RemainValues::Front();
+        static constexpr index_t current_id    = RemainIds::Front();
+
+        static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back());
+
+        using new_remain_values = decltype(RemainValues::PopFront());
+        using new_remain_ids    = decltype(RemainIds::PopFront());
+
+        using new_uniquified_values =
+            typename conditional<is_unique_value,
+                                 decltype(UniquifiedValues::PushBack(Number<current_value>{})),
+                                 UniquifiedValues>::type;
+        using new_uniquified_ids =
+            typename conditional<is_unique_value,
+                                 decltype(UniquifiedIds::PushBack(Number<current_id>{})),
+                                 UniquifiedIds>::type;
+
+        using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
+                                                       new_remain_ids,
+                                                       new_uniquified_values,
+                                                       new_uniquified_ids,
+                                                       Eq>;
+        // this is output
+        using uniquified_values = typename uniquify::uniquified_values;
+        using uniquified_ids    = typename uniquify::uniquified_ids;
     };

-    template <typename WorkInputSeq, typename Eq>
-    struct sorted_sequence_uniquify_impl<WorkInputSeq, Sequence<>, Eq>
+    template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
+    struct sorted_sequence_uniquify_impl<Sequence<>, Sequence<>, UniquifiedValues, UniquifiedIds, Eq>
     {
-        using type = WorkInputSeq;
+        using uniquified_values = UniquifiedValues;
+        using uniquified_ids    = UniquifiedIds;
     };

-    template <typename SortedSeq, typename Eq>
+    template <typename SortedValues, typename SortedIds, typename Eq>
     struct sorted_sequence_uniquify
     {
-        using type = typename sorted_sequence_uniquify_impl<SortedSeq, Sequence<>, Eq>::type;
+        using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::PopFront()),
+                                                       decltype(SortedIds::PopFront()),
+                                                       Sequence<SortedValues::Front()>,
+                                                       Sequence<SortedIds::Front()>,
+                                                       Eq>;
+
+        using uniquified_values = typename uniquify::uniquified_values;
+        using uniquified_ids    = typename uniquify::uniquified_ids;
     };

-    using sorted_seq = typename sequence_sort<Seq, Less>::type;
-
-    using type = typename sorted_sequence_uniquify<sorted_seq, Equal>::type;
+    using sort          = sequence_sort<Values, Less>;
+    using sorted_values = typename sort::type;
+    using sorted_ids    = typename sort::sorted2unsorted_map;
+
+    using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;
+
+    // this is output
+    using type                = typename uniquify::uniquified_values;
+    using sorted2unsorted_map = typename uniquify::uniquified_ids;
 };

-template <typename Seq>
+template <typename SeqMap>
 struct is_valid_sequence_map
 {
-    // not implemented yet, always return true
-    static constexpr integral_constant<bool, true> value = integral_constant<bool, true>{};
-
-    // TODO: add proper check for is_valid, something like:
-    // static constexpr bool value =
-    //     is_same<typename arithmetic_sequence_gen<0, Seq::Size(), 1>::type,
-    //             typename sequence_sort<Seq>::SortedSeqType>{};
+    static constexpr bool value =
+        is_same<typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type,
+                typename sequence_sort<SeqMap, math::less<index_t>>::type>{};
 };

-template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
-struct sequence_map_inverse_impl
-{
-    static constexpr auto new_y2x =
-        WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
-
-    using type =
-        typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::type;
-};
-
-template <typename X2Y, typename WorkingY2X, index_t XBegin>
-struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
-{
-    using type = WorkingY2X;
-};
-
-template <typename X2Y>
+template <typename SeqMap>
 struct sequence_map_inverse
 {
+    private:
+    template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
+    struct sequence_map_inverse_impl
+    {
+        static constexpr auto new_y2x =
+            WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
+
+        using type =
+            typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::type;
+    };
+
+    template <typename X2Y, typename WorkingY2X, index_t XBegin>
+    struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
+    {
+        using type = WorkingY2X;
+    };
+
+    public:
     using type =
-        typename sequence_map_inverse_impl<X2Y,
-                                           typename uniform_sequence_gen<X2Y::Size(), 0>::type,
+        typename sequence_map_inverse_impl<SeqMap,
+                                           typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
                                            0,
-                                           X2Y::Size()>::type;
+                                           SeqMap::Size()>::type;
 };

 template <index_t... Xs, index_t... Ys>

@@ -601,6 +729,12 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I
     return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
 }

+template <typename Seq, index_t... Is>
+__host__ __device__ constexpr auto pick_sequence_elements(Seq, Sequence<Is...>)
+{
+    return Sequence<Seq::At(Number<Is>{})...>{};
+}
+
 template <typename Seq, typename Reduce>
 struct lambda_accumulate_on_sequence
 {
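The reworked sequence_sort carries the original positions through the merge sort alongside the values, so the sort also produces the sorted-to-unsorted permutation that sequence_unique_sort exposes as sorted2unsorted_map. A constexpr sketch of the same idea over plain std::array (not the Sequence<> templates):

// sketch only: runtime-style constexpr analogue of sorting values with their ids
#include <array>
#include <cstdio>
#include <utility>

template <std::size_t N>
constexpr std::pair<std::array<int, N>, std::array<int, N>> sort_with_ids(std::array<int, N> values)
{
    std::array<int, N> ids{};
    for(std::size_t i = 0; i < N; ++i)
        ids[i] = static_cast<int>(i);

    // simple insertion sort; values and ids are permuted together
    for(std::size_t i = 1; i < N; ++i)
        for(std::size_t j = i; j > 0 && values[j - 1] > values[j]; --j)
        {
            int tv = values[j - 1]; values[j - 1] = values[j]; values[j] = tv;
            int ti = ids[j - 1];    ids[j - 1]    = ids[j];    ids[j]    = ti;
        }
    return {values, ids};
}

int main()
{
    constexpr auto sorted = sort_with_ids(std::array<int, 4>{3, 0, 2, 1});
    // sorted.first  == {0, 1, 2, 3}
    // sorted.second == {1, 3, 2, 0}   (original position of each sorted element)
    for(int id : sorted.second)
        std::printf("%d ", id);
    std::printf("\n");
}

In TransformedTensorDescriptor::GetUpperLengths, that permutation is exactly what reorders the mingled per-transform upper lengths into the order of the upper dimension ids.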
composable_kernel/include/utility/sequence_helper.hpp

 #ifndef CK_SEQUENCE_HELPER_HPP
 #define CK_SEQUENCE_HELPER_HPP

-#include "Sequence.hpp"
+#include "sequence.hpp"

 namespace ck {

 template <index_t... Xs>
-__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
+__host__ __device__ void print_sequence(const char* s, Sequence<Xs...>)
 {
     constexpr index_t nsize = Sequence<Xs...>::Size();
composable_kernel/include/utility/tuple.hpp

@@ -3,7 +3,7 @@
 #include "integral_constant.hpp"
 #include "type.hpp"
-#include "Sequence.hpp"
+#include "sequence.hpp"

 namespace ck {

@@ -114,19 +114,19 @@ __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
 namespace detail {

-template <typename X, typename F, index_t... Is>
-__host__ __device__ constexpr auto transpose_tuple_impl(X& x, F f, Sequence<Is...>)
+template <typename F, typename X, index_t... Is>
+__host__ __device__ constexpr auto transform_tuple_impl(F f, const X& x, Sequence<Is...>)
 {
     return make_tuple(f(x.At(Number<Is>{}))...);
 }

 } // namespace detail

-template <typename X, typename F>
-__host__ __device__ constexpr auto transpose_tuple(X& x, F f)
+template <typename F, typename X>
+__host__ __device__ constexpr auto transform_tuple(F f, const X& x)
 {
-    return detail::transpose_tuple_impl(
-        x, f, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
+    return detail::transform_tuple_impl(
+        f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
 }

 } // namespace ck
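The renamed transform_tuple follows the usual "apply a callable to every element and collect the results" pattern. A sketch of the same shape using the standard library (the CK version works on ck::Tuple, Number<> and arithmetic_sequence_gen, but the mechanics are identical):

// sketch only: standard-library analogue of transform_tuple
#include <cstdio>
#include <tuple>
#include <utility>

template <typename F, typename Tuple, std::size_t... Is>
constexpr auto transform_tuple_impl(F f, const Tuple& t, std::index_sequence<Is...>)
{
    return std::make_tuple(f(std::get<Is>(t))...);
}

template <typename F, typename Tuple>
constexpr auto transform_tuple(F f, const Tuple& t)
{
    return transform_tuple_impl(f, t, std::make_index_sequence<std::tuple_size<Tuple>::value>{});
}

int main()
{
    const auto doubled = transform_tuple([](auto x) { return x * 2; }, std::make_tuple(1, 2.5, 3));
    std::printf("%d %f %d\n", std::get<0>(doubled), std::get<1>(doubled), std::get<2>(doubled));
}

Putting the callable first (transform_tuple(f, x) instead of the old transpose_tuple(x, f)) matches how the new TransformedTensorDescriptor::GetUpperLengths calls it.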
composable_kernel/include/utility/type.hpp

@@ -2,10 +2,12 @@
 #define CK_TYPE_HPP

 #include "integral_constant.hpp"
-#include "Sequence.hpp"

 namespace ck {

+template <index_t... Is>
+struct Sequence;
+
 template <typename X, typename Y>
 struct is_same : public integral_constant<bool, false>
 {
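Presumably the forward declaration is needed because sequence.hpp now includes type.hpp (see the sequence.hpp hunk above), so type.hpp can no longer include sequence.hpp without creating a cycle. A minimal illustration of why a forward declaration is sufficient here, in a throw-away namespace so it does not collide with the real ck types:

// sketch only: a forward declaration is enough wherever the template is merely named
namespace ck_sketch {

template <int... Is>
struct Sequence; // declared, never defined in this sketch

template <typename T>
struct depends_on_sequence; // primary template intentionally undefined

template <int... Is>
struct depends_on_sequence<Sequence<Is...>> // matching the template-id needs no definition
{
    static constexpr int size = sizeof...(Is);
};

} // namespace ck_sketch

static_assert(ck_sketch::depends_on_sequence<ck_sketch::Sequence<0, 1, 2>>::size == 3, "");

int main() {}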
driver/src/driver.cpp

@@ -84,8 +84,8 @@ int main(int argc, char* argv[])
     using ConvStrides   = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;

-    constexpr index_t HPad = 2;
-    constexpr index_t WPad = 2;
+    constexpr index_t HPad = 3;
+    constexpr index_t WPad = 3;
 #elif 1
     // 3x3, 34x34
     constexpr index_t N = 64;