Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9d59a39a
Commit
9d59a39a
authored
Jun 17, 2019
by
Chao Liu
Browse files
refactoring
parent
33d1e0e2
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
76 additions
and
89 deletions
+76
-89
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp
...ridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp
+1
-1
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp
...n_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp
+1
-1
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp
.../gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp
+1
-1
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp
...ion_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp
+1
-1
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
.../gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
+1
-1
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp
...ion_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp
+1
-1
composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp
...l/include/tensor_description/ConstantTensorDescriptor.hpp
+13
-12
composable_kernel/include/utility/Array.hpp
composable_kernel/include/utility/Array.hpp
+1
-1
composable_kernel/include/utility/Sequence.hpp
composable_kernel/include/utility/Sequence.hpp
+51
-65
composable_kernel/include/utility/functional2.hpp
composable_kernel/include/utility/functional2.hpp
+1
-1
composable_kernel/include/utility/utility.hpp
composable_kernel/include/utility/utility.hpp
+4
-4
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp
View file @
9d59a39a
...
@@ -440,7 +440,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -440,7 +440,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
wo_block_data_begin
+
wo_thread_data_begin
),
wo_block_data_begin
+
wo_thread_data_begin
),
make_zero_array
<
index_t
,
10
>
(),
make_zero_array
<
index_t
,
10
>
(),
out_10d_thread_desc
.
GetLengths
().
ReorderGivenNew2Old
(
map_out_global2thread
),
out_10d_thread_desc
.
GetLengths
().
ReorderGivenNew2Old
(
map_out_global2thread
),
arithmetic_sequence_gen
<
0
,
10
,
1
>::
SeqT
ype
{},
arithmetic_sequence_gen
<
0
,
10
,
1
>::
t
ype
{},
Number
<
1
>
{});
Number
<
1
>
{});
#endif
#endif
});
});
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp
View file @
9d59a39a
...
@@ -491,7 +491,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
...
@@ -491,7 +491,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
wo_block_data_begin
+
wo_thread_data_begin
),
wo_block_data_begin
+
wo_thread_data_begin
),
make_zero_array
<
index_t
,
10
>
(),
make_zero_array
<
index_t
,
10
>
(),
out_10d_thread_desc
.
GetLengths
().
ReorderGivenNew2Old
(
map_out_global2thread
),
out_10d_thread_desc
.
GetLengths
().
ReorderGivenNew2Old
(
map_out_global2thread
),
arithmetic_sequence_gen
<
0
,
10
,
1
>::
SeqT
ype
{},
arithmetic_sequence_gen
<
0
,
10
,
1
>::
t
ype
{},
Number
<
1
>
{});
Number
<
1
>
{});
#endif
#endif
});
});
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp
View file @
9d59a39a
...
@@ -367,7 +367,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
...
@@ -367,7 +367,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
p_out_thread_on_global
,
p_out_thread_on_global
,
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
},
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc
.
GetLengths
(),
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc
.
GetLengths
(),
arithmetic_sequence_gen
<
0
,
8
,
1
>::
SeqT
ype
{},
arithmetic_sequence_gen
<
0
,
8
,
1
>::
t
ype
{},
Number
<
1
>
{});
Number
<
1
>
{});
}
}
}
}
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp
View file @
9d59a39a
...
@@ -394,7 +394,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
...
@@ -394,7 +394,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
p_out_thread_on_global
,
p_out_thread_on_global
,
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
},
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc
.
GetLengths
(),
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc
.
GetLengths
(),
arithmetic_sequence_gen
<
0
,
8
,
1
>::
SeqT
ype
{},
arithmetic_sequence_gen
<
0
,
8
,
1
>::
t
ype
{},
Number
<
1
>
{});
Number
<
1
>
{});
}
}
}
}
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
View file @
9d59a39a
...
@@ -344,7 +344,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
...
@@ -344,7 +344,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
p_out_thread_on_global
,
p_out_thread_on_global
,
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
},
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc
.
GetLengths
(),
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc
.
GetLengths
(),
arithmetic_sequence_gen
<
0
,
8
,
1
>::
SeqT
ype
{},
arithmetic_sequence_gen
<
0
,
8
,
1
>::
t
ype
{},
Number
<
1
>
{});
Number
<
1
>
{});
}
}
}
}
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp
View file @
9d59a39a
...
@@ -398,7 +398,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -398,7 +398,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
p_out_thread_on_global
,
p_out_thread_on_global
,
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
},
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc
.
GetLengths
(),
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc
.
GetLengths
(),
arithmetic_sequence_gen
<
0
,
8
,
1
>::
SeqT
ype
{},
arithmetic_sequence_gen
<
0
,
8
,
1
>::
t
ype
{},
Number
<
1
>
{});
Number
<
1
>
{});
}
}
}
}
...
...
composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp
View file @
9d59a39a
...
@@ -305,8 +305,9 @@ struct ConstantTensorDescriptor
...
@@ -305,8 +305,9 @@ struct ConstantTensorDescriptor
{
{
using
leaf_tensor
=
ConstantTensorDescriptor
<
Ts
...
>
;
using
leaf_tensor
=
ConstantTensorDescriptor
<
Ts
...
>
;
return
ConstantTensorDescriptor
<
decltype
(
GetLengths
().
Append
(
leaf_tensor
::
GetLengths
())),
return
ConstantTensorDescriptor
<
decltype
(
GetLengths
().
PushBack
(
leaf_tensor
::
GetLengths
())),
decltype
(
GetStrides
().
Append
(
leaf_tensor
::
GetStrides
()))
>
{};
decltype
(
GetStrides
().
PushBack
(
leaf_tensor
::
GetStrides
()))
>
{};
}
}
template
<
index_t
IDim
,
index_t
SliceLen
>
template
<
index_t
IDim
,
index_t
SliceLen
>
...
@@ -347,7 +348,7 @@ struct ConstantTensorDescriptor
...
@@ -347,7 +348,7 @@ struct ConstantTensorDescriptor
// folded lengths
// folded lengths
constexpr
auto
fold_lengths
=
constexpr
auto
fold_lengths
=
Sequence
<
unfold_length
/
fold_intervals_product
>
{}.
Append
(
fold_intervals
);
Sequence
<
unfold_length
/
fold_intervals_product
>
{}.
PushBack
(
fold_intervals
);
// folded strides
// folded strides
constexpr
auto
fold_strides
=
constexpr
auto
fold_strides
=
...
@@ -356,14 +357,14 @@ struct ConstantTensorDescriptor
...
@@ -356,14 +357,14 @@ struct ConstantTensorDescriptor
fold_intervals
.
PushBack
(
Number
<
1
>
{}),
math
::
multiplies
<
index_t
>
{},
Number
<
1
>
{});
fold_intervals
.
PushBack
(
Number
<
1
>
{}),
math
::
multiplies
<
index_t
>
{},
Number
<
1
>
{});
// left and right
// left and right
constexpr
auto
left
=
typename
arithmetic_sequence_gen
<
0
,
IDim
,
1
>::
SeqT
ype
{};
constexpr
auto
left
=
typename
arithmetic_sequence_gen
<
0
,
IDim
,
1
>::
t
ype
{};
constexpr
auto
right
=
constexpr
auto
right
=
typename
arithmetic_sequence_gen
<
IDim
+
1
,
GetNumOfDimension
(),
1
>::
SeqT
ype
{};
typename
arithmetic_sequence_gen
<
IDim
+
1
,
GetNumOfDimension
(),
1
>::
t
ype
{};
constexpr
auto
new_lengths
=
constexpr
auto
new_lengths
=
GetLengths
().
Extract
(
left
).
Append
(
fold_lengths
).
Append
(
GetLengths
().
Extract
(
right
));
GetLengths
().
Extract
(
left
).
PushBack
(
fold_lengths
).
PushBack
(
GetLengths
().
Extract
(
right
));
constexpr
auto
new_strides
=
constexpr
auto
new_strides
=
GetStrides
().
Extract
(
left
).
Append
(
fold_strides
).
Append
(
GetStrides
().
Extract
(
right
));
GetStrides
().
Extract
(
left
).
PushBack
(
fold_strides
).
PushBack
(
GetStrides
().
Extract
(
right
));
return
ConstantTensorDescriptor
<
decltype
(
new_lengths
),
decltype
(
new_strides
)
>
{};
return
ConstantTensorDescriptor
<
decltype
(
new_lengths
),
decltype
(
new_strides
)
>
{};
}
}
...
@@ -377,11 +378,11 @@ struct ConstantTensorDescriptor
...
@@ -377,11 +378,11 @@ struct ConstantTensorDescriptor
"wrong! should have FirstUnfoldDim <= LastUnfoldDim!"
);
"wrong! should have FirstUnfoldDim <= LastUnfoldDim!"
);
// left and right
// left and right
constexpr
auto
left
=
typename
arithmetic_sequence_gen
<
0
,
FirstUnfoldDim
,
1
>::
SeqT
ype
{};
constexpr
auto
left
=
typename
arithmetic_sequence_gen
<
0
,
FirstUnfoldDim
,
1
>::
t
ype
{};
constexpr
auto
middle
=
constexpr
auto
middle
=
typename
arithmetic_sequence_gen
<
FirstUnfoldDim
,
LastUnfoldDim
+
1
,
1
>::
SeqT
ype
{};
typename
arithmetic_sequence_gen
<
FirstUnfoldDim
,
LastUnfoldDim
+
1
,
1
>::
t
ype
{};
constexpr
auto
right
=
constexpr
auto
right
=
typename
arithmetic_sequence_gen
<
LastUnfoldDim
+
1
,
GetNumOfDimension
(),
1
>::
SeqT
ype
{};
typename
arithmetic_sequence_gen
<
LastUnfoldDim
+
1
,
GetNumOfDimension
(),
1
>::
t
ype
{};
// dimensions to be unfolded need to be continuous
// dimensions to be unfolded need to be continuous
static_assert
(
Type
::
Extract
(
middle
).
AreDimensionsContinuous
(),
"wrong! not unfoldable"
);
static_assert
(
Type
::
Extract
(
middle
).
AreDimensionsContinuous
(),
"wrong! not unfoldable"
);
...
@@ -396,12 +397,12 @@ struct ConstantTensorDescriptor
...
@@ -396,12 +397,12 @@ struct ConstantTensorDescriptor
constexpr
auto
new_lengths
=
GetLengths
()
constexpr
auto
new_lengths
=
GetLengths
()
.
Extract
(
left
)
.
Extract
(
left
)
.
PushBack
(
Number
<
unfold_length
>
{})
.
PushBack
(
Number
<
unfold_length
>
{})
.
Append
(
GetLengths
().
Extract
(
right
));
.
PushBack
(
GetLengths
().
Extract
(
right
));
constexpr
auto
new_strides
=
GetStrides
()
constexpr
auto
new_strides
=
GetStrides
()
.
Extract
(
left
)
.
Extract
(
left
)
.
PushBack
(
Number
<
unfold_stride
>
{})
.
PushBack
(
Number
<
unfold_stride
>
{})
.
Append
(
GetStrides
().
Extract
(
right
));
.
PushBack
(
GetStrides
().
Extract
(
right
));
return
ConstantTensorDescriptor
<
decltype
(
new_lengths
),
decltype
(
new_strides
)
>
{};
return
ConstantTensorDescriptor
<
decltype
(
new_lengths
),
decltype
(
new_strides
)
>
{};
}
}
...
...
composable_kernel/include/utility/Array.hpp
View file @
9d59a39a
...
@@ -87,7 +87,7 @@ __host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
...
@@ -87,7 +87,7 @@ __host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
template
<
class
TData
,
index_t
NSize
>
template
<
class
TData
,
index_t
NSize
>
__host__
__device__
constexpr
auto
make_zero_array
()
__host__
__device__
constexpr
auto
make_zero_array
()
{
{
constexpr
auto
zero_sequence
=
typename
uniform_sequence_gen
<
NSize
,
0
>::
SeqT
ype
{};
constexpr
auto
zero_sequence
=
typename
uniform_sequence_gen
<
NSize
,
0
>::
t
ype
{};
constexpr
auto
zero_array
=
sequence2array
(
zero_sequence
);
constexpr
auto
zero_array
=
sequence2array
(
zero_sequence
);
return
zero_array
;
return
zero_array
;
}
}
...
...
composable_kernel/include/utility/Sequence.hpp
View file @
9d59a39a
...
@@ -29,12 +29,9 @@ struct Sequence
...
@@ -29,12 +29,9 @@ struct Sequence
}
}
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
index_t
operator
[](
Number
<
I
>
)
const
__host__
__device__
constexpr
auto
operator
[](
Number
<
I
>
)
const
{
{
static_assert
(
I
<
mSize
,
"wrong! I too large"
);
return
Number
<
Get
(
Number
<
I
>
{})
>
{};
const
index_t
mData
[
mSize
+
1
]
=
{
Is
...,
0
};
return
mData
[
I
];
}
}
// make sure I is constepxr
// make sure I is constepxr
...
@@ -69,24 +66,30 @@ struct Sequence
...
@@ -69,24 +66,30 @@ struct Sequence
return
mData
[
mSize
-
1
];
return
mData
[
mSize
-
1
];
}
}
template
<
index_t
I
>
__host__
__device__
static
constexpr
auto
PopFront
();
__host__
__device__
static
constexpr
auto
PushFront
(
Number
<
I
>
)
__host__
__device__
static
constexpr
auto
PopBack
();
template
<
index_t
...
Xs
>
__host__
__device__
static
constexpr
auto
PushFront
(
Sequence
<
Xs
...
>
)
{
{
return
Sequence
<
I
,
Is
...
>
{};
return
Sequence
<
Xs
...
,
Is
...
>
{};
}
}
template
<
index_t
I
>
template
<
index_t
...
Xs
>
__host__
__device__
static
constexpr
auto
Push
Back
(
Number
<
I
>
)
__host__
__device__
static
constexpr
auto
Push
Front
(
Number
<
Xs
>
...
)
{
{
return
Sequence
<
I
s
...,
I
>
{};
return
Sequence
<
X
s
...,
I
s
...
>
{};
}
}
__host__
__device__
static
constexpr
auto
PopFront
();
template
<
index_t
...
Xs
>
__host__
__device__
static
constexpr
auto
PushBack
(
Sequence
<
Xs
...
>
)
__host__
__device__
static
constexpr
auto
PopBack
();
{
return
Sequence
<
Is
...,
Xs
...
>
{};
}
template
<
index_t
...
Xs
>
template
<
index_t
...
Xs
>
__host__
__device__
static
constexpr
auto
Append
(
Sequence
<
Xs
...
>
)
__host__
__device__
static
constexpr
auto
PushBack
(
Number
<
Xs
>
...)
{
{
return
Sequence
<
Is
...,
Xs
...
>
{};
return
Sequence
<
Is
...,
Xs
...
>
{};
}
}
...
@@ -105,6 +108,12 @@ struct Sequence
...
@@ -105,6 +108,12 @@ struct Sequence
template
<
index_t
I
,
index_t
X
>
template
<
index_t
I
,
index_t
X
>
__host__
__device__
static
constexpr
auto
Modify
(
Number
<
I
>
,
Number
<
X
>
);
__host__
__device__
static
constexpr
auto
Modify
(
Number
<
I
>
,
Number
<
X
>
);
template
<
class
F
>
__host__
__device__
static
constexpr
auto
Transform
(
F
f
)
{
return
Sequence
<
f
(
Is
)...
>
{};
}
};
};
// merge sequence
// merge sequence
...
@@ -114,7 +123,7 @@ struct sequence_merge;
...
@@ -114,7 +123,7 @@ struct sequence_merge;
template
<
index_t
...
Xs
,
index_t
...
Ys
>
template
<
index_t
...
Xs
,
index_t
...
Ys
>
struct
sequence_merge
<
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>>
struct
sequence_merge
<
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>>
{
{
using
SeqT
ype
=
Sequence
<
Xs
...,
Ys
...
>
;
using
t
ype
=
Sequence
<
Xs
...,
Ys
...
>
;
};
};
// arithmetic sqeuence
// arithmetic sqeuence
...
@@ -123,40 +132,29 @@ struct arithmetic_sequence_gen_impl
...
@@ -123,40 +132,29 @@ struct arithmetic_sequence_gen_impl
{
{
static
constexpr
index_t
NSizeLeft
=
NSize
/
2
;
static
constexpr
index_t
NSizeLeft
=
NSize
/
2
;
using
SeqT
ype
=
typename
sequence_merge
<
using
t
ype
=
typename
sequence_merge
<
typename
arithmetic_sequence_gen_impl
<
IBegin
,
NSizeLeft
,
Increment
>::
SeqT
ype
,
typename
arithmetic_sequence_gen_impl
<
IBegin
,
NSizeLeft
,
Increment
>::
t
ype
,
typename
arithmetic_sequence_gen_impl
<
IBegin
+
NSizeLeft
*
Increment
,
typename
arithmetic_sequence_gen_impl
<
IBegin
+
NSizeLeft
*
Increment
,
NSize
-
NSizeLeft
,
NSize
-
NSizeLeft
,
Increment
>::
SeqT
ype
>::
SeqT
ype
;
Increment
>::
t
ype
>::
t
ype
;
};
};
template
<
index_t
IBegin
,
index_t
Increment
>
template
<
index_t
IBegin
,
index_t
Increment
>
struct
arithmetic_sequence_gen_impl
<
IBegin
,
1
,
Increment
>
struct
arithmetic_sequence_gen_impl
<
IBegin
,
1
,
Increment
>
{
{
using
SeqT
ype
=
Sequence
<
IBegin
>
;
using
t
ype
=
Sequence
<
IBegin
>
;
};
};
template
<
index_t
IBegin
,
index_t
Increment
>
template
<
index_t
IBegin
,
index_t
Increment
>
struct
arithmetic_sequence_gen_impl
<
IBegin
,
0
,
Increment
>
struct
arithmetic_sequence_gen_impl
<
IBegin
,
0
,
Increment
>
{
{
using
SeqT
ype
=
Sequence
<>
;
using
t
ype
=
Sequence
<>
;
};
};
template
<
index_t
IBegin
,
index_t
IEnd
,
index_t
Increment
>
template
<
index_t
IBegin
,
index_t
IEnd
,
index_t
Increment
>
struct
arithmetic_sequence_gen
struct
arithmetic_sequence_gen
{
{
using
SeqType
=
using
type
=
typename
arithmetic_sequence_gen_impl
<
IBegin
,
IEnd
-
IBegin
,
Increment
>::
type
;
typename
arithmetic_sequence_gen_impl
<
IBegin
,
IEnd
-
IBegin
,
Increment
>::
SeqType
;
};
// transform sequence
template
<
class
,
class
>
struct
sequence_transform
;
template
<
class
F
,
index_t
...
Is
>
struct
sequence_transform
<
F
,
Sequence
<
Is
...
>>
{
using
SeqType
=
Sequence
<
F
{}(
Is
)...
>
;
};
};
// uniform sequence
// uniform sequence
...
@@ -168,9 +166,8 @@ struct uniform_sequence_gen
...
@@ -168,9 +166,8 @@ struct uniform_sequence_gen
__host__
__device__
constexpr
index_t
operator
()(
index_t
)
const
{
return
I
;
}
__host__
__device__
constexpr
index_t
operator
()(
index_t
)
const
{
return
I
;
}
};
};
using
SeqType
=
typename
sequence_transform
<
using
type
=
decltype
(
return_constant
,
typename
arithmetic_sequence_gen
<
0
,
NSize
,
1
>::
type
{}.
Transform
(
return_constant
{}));
typename
arithmetic_sequence_gen
<
0
,
NSize
,
1
>::
SeqType
>::
SeqType
;
};
};
// reverse inclusive scan (with init) sequence
// reverse inclusive scan (with init) sequence
...
@@ -180,34 +177,23 @@ struct sequence_reverse_inclusive_scan;
...
@@ -180,34 +177,23 @@ struct sequence_reverse_inclusive_scan;
template
<
index_t
I
,
index_t
...
Is
,
class
Reduce
,
index_t
Init
>
template
<
index_t
I
,
index_t
...
Is
,
class
Reduce
,
index_t
Init
>
struct
sequence_reverse_inclusive_scan
<
Sequence
<
I
,
Is
...
>
,
Reduce
,
Init
>
struct
sequence_reverse_inclusive_scan
<
Sequence
<
I
,
Is
...
>
,
Reduce
,
Init
>
{
{
using
old_scan
=
using
old_scan
=
typename
sequence_reverse_inclusive_scan
<
Sequence
<
Is
...
>
,
Reduce
,
Init
>::
type
;
typename
sequence_reverse_inclusive_scan
<
Sequence
<
Is
...
>
,
Reduce
,
Init
>::
SeqType
;
static
constexpr
index_t
new_reduce
=
Reduce
{}(
I
,
old_scan
{}.
Front
());
static
constexpr
index_t
new_reduce
=
Reduce
{}(
I
,
old_scan
{}.
Front
());
using
SeqT
ype
=
typename
sequence_merge
<
Sequence
<
new_reduce
>
,
old_scan
>::
SeqT
ype
;
using
t
ype
=
typename
sequence_merge
<
Sequence
<
new_reduce
>
,
old_scan
>::
t
ype
;
};
};
template
<
index_t
I
,
class
Reduce
,
index_t
Init
>
template
<
index_t
I
,
class
Reduce
,
index_t
Init
>
struct
sequence_reverse_inclusive_scan
<
Sequence
<
I
>
,
Reduce
,
Init
>
struct
sequence_reverse_inclusive_scan
<
Sequence
<
I
>
,
Reduce
,
Init
>
{
{
using
SeqT
ype
=
Sequence
<
Reduce
{}(
I
,
Init
)
>
;
using
t
ype
=
Sequence
<
Reduce
{}(
I
,
Init
)
>
;
};
};
template
<
class
Reduce
,
index_t
Init
>
template
<
class
Reduce
,
index_t
Init
>
struct
sequence_reverse_inclusive_scan
<
Sequence
<>
,
Reduce
,
Init
>
struct
sequence_reverse_inclusive_scan
<
Sequence
<>
,
Reduce
,
Init
>
{
{
using
SeqType
=
Sequence
<>
;
using
type
=
Sequence
<>
;
};
// extract sequence
template
<
class
,
class
>
struct
sequence_extract
;
template
<
class
Seq
,
index_t
...
Is
>
struct
sequence_extract
<
Seq
,
Sequence
<
Is
...
>>
{
using
SeqType
=
Sequence
<
Seq
{}.
Get
(
Number
<
Is
>
{})...
>
;
};
};
// split sequence
// split sequence
...
@@ -216,11 +202,11 @@ struct sequence_split
...
@@ -216,11 +202,11 @@ struct sequence_split
{
{
static
constexpr
index_t
NSize
=
Seq
{}.
GetSize
();
static
constexpr
index_t
NSize
=
Seq
{}.
GetSize
();
using
range0
=
typename
arithmetic_sequence_gen
<
0
,
I
,
1
>::
SeqT
ype
;
using
range0
=
typename
arithmetic_sequence_gen
<
0
,
I
,
1
>::
t
ype
;
using
range1
=
typename
arithmetic_sequence_gen
<
I
,
NSize
,
1
>::
SeqT
ype
;
using
range1
=
typename
arithmetic_sequence_gen
<
I
,
NSize
,
1
>::
t
ype
;
using
SeqType0
=
typename
sequence_extract
<
Seq
,
range0
>::
SeqType
;
using
SeqType0
=
decltype
(
Seq
::
Extract
(
range0
{}))
;
using
SeqType1
=
typename
sequence_extract
<
Seq
,
range1
>::
SeqType
;
using
SeqType1
=
decltype
(
Seq
::
Extract
(
range1
{}))
;
};
};
// reverse sequence
// reverse sequence
...
@@ -230,31 +216,31 @@ struct sequence_reverse
...
@@ -230,31 +216,31 @@ struct sequence_reverse
static
constexpr
index_t
NSize
=
Seq
{}.
GetSize
();
static
constexpr
index_t
NSize
=
Seq
{}.
GetSize
();
using
seq_split
=
sequence_split
<
Seq
,
NSize
/
2
>
;
using
seq_split
=
sequence_split
<
Seq
,
NSize
/
2
>
;
using
SeqType
=
typename
sequence_merge
<
using
type
=
typename
sequence_merge
<
typename
sequence_reverse
<
typename
seq_split
::
SeqType1
>::
SeqT
ype
,
typename
sequence_reverse
<
typename
seq_split
::
SeqType1
>::
t
ype
,
typename
sequence_reverse
<
typename
seq_split
::
SeqType0
>::
SeqT
ype
>::
SeqT
ype
;
typename
sequence_reverse
<
typename
seq_split
::
SeqType0
>::
t
ype
>::
t
ype
;
};
};
template
<
index_t
I
>
template
<
index_t
I
>
struct
sequence_reverse
<
Sequence
<
I
>>
struct
sequence_reverse
<
Sequence
<
I
>>
{
{
using
SeqT
ype
=
Sequence
<
I
>
;
using
t
ype
=
Sequence
<
I
>
;
};
};
template
<
index_t
I0
,
index_t
I1
>
template
<
index_t
I0
,
index_t
I1
>
struct
sequence_reverse
<
Sequence
<
I0
,
I1
>>
struct
sequence_reverse
<
Sequence
<
I0
,
I1
>>
{
{
using
SeqT
ype
=
Sequence
<
I1
,
I0
>
;
using
t
ype
=
Sequence
<
I1
,
I0
>
;
};
};
template
<
class
Seq
>
template
<
class
Seq
>
struct
is_valid_sequence_map
struct
is_valid_sequence_map
{
{
static
constexpr
bool
value
=
true
;
static
constexpr
integral_constant
<
bool
,
true
>
value
=
integral_constant
<
bool
,
true
>
{}
;
// TODO: add proper check for is_valid, something like:
// TODO: add proper check for is_valid, something like:
// static constexpr bool value =
// static constexpr bool value =
// is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::
SeqT
ype,
// is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::
t
ype,
// typename sequence_sort<Seq>::SortedSeqType>{};
// typename sequence_sort<Seq>::SortedSeqType>{};
};
};
...
@@ -401,7 +387,7 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
...
@@ -401,7 +387,7 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
template
<
class
Seq
,
class
Reduce
,
index_t
Init
>
template
<
class
Seq
,
class
Reduce
,
index_t
Init
>
__host__
__device__
constexpr
auto
reverse_inclusive_scan_sequence
(
Seq
,
Reduce
,
Number
<
Init
>
)
__host__
__device__
constexpr
auto
reverse_inclusive_scan_sequence
(
Seq
,
Reduce
,
Number
<
Init
>
)
{
{
return
typename
sequence_reverse_inclusive_scan
<
Seq
,
Reduce
,
Init
>::
SeqT
ype
{};
return
typename
sequence_reverse_inclusive_scan
<
Seq
,
Reduce
,
Init
>::
t
ype
{};
}
}
template
<
class
Seq
,
class
Reduce
,
index_t
Init
>
template
<
class
Seq
,
class
Reduce
,
index_t
Init
>
...
@@ -425,7 +411,7 @@ __host__ __device__ constexpr auto Sequence<Is...>::PopBack()
...
@@ -425,7 +411,7 @@ __host__ __device__ constexpr auto Sequence<Is...>::PopBack()
template
<
index_t
...
Is
>
template
<
index_t
...
Is
>
__host__
__device__
constexpr
auto
Sequence
<
Is
...
>::
Reverse
()
__host__
__device__
constexpr
auto
Sequence
<
Is
...
>::
Reverse
()
{
{
return
typename
sequence_reverse
<
Sequence
<
Is
...
>>::
SeqT
ype
{};
return
typename
sequence_reverse
<
Sequence
<
Is
...
>>::
t
ype
{};
}
}
template
<
index_t
...
Is
>
template
<
index_t
...
Is
>
...
@@ -438,7 +424,7 @@ __host__ __device__ constexpr auto Sequence<Is...>::Modify(Number<I>, Number<X>)
...
@@ -438,7 +424,7 @@ __host__ __device__ constexpr auto Sequence<Is...>::Modify(Number<I>, Number<X>)
constexpr
auto
seq_left
=
typename
seq_split
::
SeqType0
{};
constexpr
auto
seq_left
=
typename
seq_split
::
SeqType0
{};
constexpr
auto
seq_right
=
typename
seq_split
::
SeqType1
{}.
PopFront
();
constexpr
auto
seq_right
=
typename
seq_split
::
SeqType1
{}.
PopFront
();
return
seq_left
.
PushBack
(
Number
<
X
>
{}).
Append
(
seq_right
);
return
seq_left
.
PushBack
(
Number
<
X
>
{}).
PushBack
(
seq_right
);
}
}
template
<
index_t
...
Xs
>
template
<
index_t
...
Xs
>
...
...
composable_kernel/include/utility/functional2.hpp
View file @
9d59a39a
...
@@ -31,7 +31,7 @@ struct static_for
...
@@ -31,7 +31,7 @@ struct static_for
static_assert
((
NEnd
-
NBegin
)
%
Increment
==
0
,
static_assert
((
NEnd
-
NBegin
)
%
Increment
==
0
,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0"
);
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0"
);
static_for_impl
<
typename
arithmetic_sequence_gen
<
NBegin
,
NEnd
,
Increment
>::
SeqT
ype
>
{}(
f
);
static_for_impl
<
typename
arithmetic_sequence_gen
<
NBegin
,
NEnd
,
Increment
>::
t
ype
>
{}(
f
);
}
}
};
};
...
...
composable_kernel/include/utility/utility.hpp
View file @
9d59a39a
...
@@ -59,9 +59,9 @@ __host__ __device__ constexpr T integer_divide_ceil(T a, T b)
...
@@ -59,9 +59,9 @@ __host__ __device__ constexpr T integer_divide_ceil(T a, T b)
}
}
template
<
class
T
>
template
<
class
T
>
__host__
__device__
constexpr
T
max
(
T
x
,
T
y
)
__host__
__device__
constexpr
T
max
(
T
x
)
{
{
return
x
>
y
?
x
:
y
;
return
x
;
}
}
template
<
class
T
,
class
...
Ts
>
template
<
class
T
,
class
...
Ts
>
...
@@ -77,9 +77,9 @@ __host__ __device__ constexpr T max(T x, Ts... xs)
...
@@ -77,9 +77,9 @@ __host__ __device__ constexpr T max(T x, Ts... xs)
}
}
template
<
class
T
>
template
<
class
T
>
__host__
__device__
constexpr
T
min
(
T
x
,
T
y
)
__host__
__device__
constexpr
T
min
(
T
x
)
{
{
return
x
<
y
?
x
:
y
;
return
x
;
}
}
template
<
class
T
,
class
...
Ts
>
template
<
class
T
,
class
...
Ts
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment