Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
52423948
Commit
52423948
authored
Sep 27, 2019
by
Jehandad Khan
Browse files
Merge branch 'master' into jd_redux
parents
b97af4ec
98a2cfcc
Changes
67
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2066 additions
and
496 deletions
+2066
-496
composable_kernel/include/utility/array.hpp
composable_kernel/include/utility/array.hpp
+150
-124
composable_kernel/include/utility/array_helper.hpp
composable_kernel/include/utility/array_helper.hpp
+177
-0
composable_kernel/include/utility/common_header.hpp
composable_kernel/include/utility/common_header.hpp
+12
-2
composable_kernel/include/utility/config_amd.hpp.in
composable_kernel/include/utility/config_amd.hpp.in
+25
-7
composable_kernel/include/utility/config_nvidia.hpp.in
composable_kernel/include/utility/config_nvidia.hpp.in
+25
-3
composable_kernel/include/utility/functional.hpp
composable_kernel/include/utility/functional.hpp
+44
-6
composable_kernel/include/utility/functional2.hpp
composable_kernel/include/utility/functional2.hpp
+7
-31
composable_kernel/include/utility/functional3.hpp
composable_kernel/include/utility/functional3.hpp
+30
-40
composable_kernel/include/utility/functional4.hpp
composable_kernel/include/utility/functional4.hpp
+34
-0
composable_kernel/include/utility/integral_constant.hpp
composable_kernel/include/utility/integral_constant.hpp
+0
-46
composable_kernel/include/utility/math.hpp
composable_kernel/include/utility/math.hpp
+19
-0
composable_kernel/include/utility/number.hpp
composable_kernel/include/utility/number.hpp
+44
-0
composable_kernel/include/utility/sequence.hpp
composable_kernel/include/utility/sequence.hpp
+867
-0
composable_kernel/include/utility/sequence_helper.hpp
composable_kernel/include/utility/sequence_helper.hpp
+46
-0
composable_kernel/include/utility/tuple.hpp
composable_kernel/include/utility/tuple.hpp
+159
-0
composable_kernel/include/utility/type.hpp
composable_kernel/include/utility/type.hpp
+43
-0
composable_kernel/include/utility/vector_type.hpp
composable_kernel/include/utility/vector_type.hpp
+19
-1
driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
...de/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
+100
-154
driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp
...ce_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp
+188
-0
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+77
-82
No files found.
composable_kernel/include/utility/
A
rray.hpp
→
composable_kernel/include/utility/
a
rray.hpp
View file @
52423948
#ifndef CK_ARRAY_HPP
#ifndef CK_ARRAY_HPP
#define CK_ARRAY_HPP
#define CK_ARRAY_HPP
#include "
S
equence.hpp"
#include "
s
equence.hpp"
#include "functional2.hpp"
#include "functional2.hpp"
namespace
ck
{
namespace
ck
{
template
<
class
TData
,
index_t
NSize
>
template
<
typename
TData
,
index_t
NSize
>
struct
Array
struct
Array
{
{
using
T
ype
=
Array
<
TData
,
NSize
>
;
using
t
ype
=
Array
<
TData
,
NSize
>
;
using
data_type
=
TData
;
using
data_type
=
TData
;
static
constexpr
index_t
nSize
=
NSize
;
index_t
mData
[
NSize
]
;
index_t
mData
[
nSize
];
__host__
__device__
explicit
constexpr
Array
()
{}
template
<
class
...
Xs
>
template
<
typename
X
,
typename
...
Xs
>
__host__
__device__
constexpr
Array
(
Xs
...
xs
)
:
mData
{
static_cast
<
TData
>
(
xs
)...}
__host__
__device__
constexpr
Array
(
X
x
,
Xs
...
xs
)
:
mData
{
static_cast
<
TData
>
(
x
),
static_cast
<
TData
>
(
xs
)...}
{
{
static_assert
(
sizeof
...(
Xs
)
+
1
==
NSize
,
"wrong! size"
);
}
}
__host__
__device__
static
constexpr
index_t
GetSize
()
{
return
NSize
;
}
__host__
__device__
static
constexpr
index_t
Size
()
{
return
NSize
;
}
// TODO: remove
__host__
__device__
static
constexpr
index_t
GetSize
()
{
return
Size
();
}
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
TData
operator
[]
(
Number
<
I
>
)
const
__host__
__device__
constexpr
const
TData
&
At
(
Number
<
I
>
)
const
{
{
static_assert
(
I
<
NSize
,
"wrong!"
);
return
mData
[
I
];
return
mData
[
I
];
}
}
__host__
__device__
constexpr
TData
operator
[](
index_t
i
)
const
{
return
mData
[
i
];
}
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
TData
&
operator
()
(
Number
<
I
>
)
__host__
__device__
constexpr
TData
&
At
(
Number
<
I
>
)
{
{
static_assert
(
I
<
NSize
,
"wrong!"
);
return
mData
[
I
];
return
mData
[
I
];
}
}
__host__
__device__
TData
&
operator
()
(
index_t
i
)
{
return
mData
[
i
];
}
__host__
__device__
constexpr
const
TData
&
At
(
index_t
i
)
const
{
return
mData
[
i
];
}
template
<
index_t
I
>
__host__
__device__
constexpr
TData
&
At
(
index_t
i
)
{
return
mData
[
i
];
}
__host__
__device__
constexpr
void
Set
(
Number
<
I
>
,
TData
x
)
template
<
typename
I
>
__host__
__device__
constexpr
const
TData
&
operator
[](
I
i
)
const
{
{
static_assert
(
I
<
NSize
,
"wrong!"
);
return
At
(
i
);
}
mData
[
I
]
=
x
;
template
<
typename
I
>
__host__
__device__
constexpr
TData
&
operator
()(
I
i
)
{
return
At
(
i
);
}
}
__host__
__device__
constexpr
void
Set
(
index_t
I
,
TData
x
)
{
mData
[
I
]
=
x
;
}
template
<
typename
T
>
__host__
__device__
constexpr
type
&
operator
=
(
const
T
&
x
)
{
static_for
<
0
,
Size
(),
1
>
{}([
&
](
auto
i
)
{
operator
()(
i
)
=
x
[
i
];
});
return
*
this
;
}
struct
lambda_PushBack
// emulate constexpr lambda
struct
lambda_PushBack
// emulate constexpr lambda
{
{
...
@@ -63,7 +82,7 @@ struct Array
...
@@ -63,7 +82,7 @@ struct Array
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
void
operator
()(
Number
<
I
>
)
const
__host__
__device__
constexpr
void
operator
()(
Number
<
I
>
)
const
{
{
new_array
.
Set
(
Number
<
I
>
{}
,
old_array
[
I
]
)
;
new_array
(
Number
<
I
>
{}
)
=
old_array
[
I
];
}
}
};
};
...
@@ -73,19 +92,96 @@ struct Array
...
@@ -73,19 +92,96 @@ struct Array
static_for
<
0
,
NSize
,
1
>
{}(
lambda_PushBack
(
*
this
,
new_array
));
static_for
<
0
,
NSize
,
1
>
{}(
lambda_PushBack
(
*
this
,
new_array
));
new_array
.
Set
(
Number
<
NSize
>
{}
,
x
)
;
new_array
(
Number
<
NSize
>
{}
)
=
x
;
return
new_array
;
return
new_array
;
}
}
};
};
// Arr: Array
// Picks: Sequence<...>
template
<
typename
Arr
,
typename
Picks
>
struct
ArrayElementPicker
{
using
type
=
ArrayElementPicker
;
using
data_type
=
typename
Arr
::
data_type
;
__host__
__device__
constexpr
ArrayElementPicker
()
=
delete
;
__host__
__device__
explicit
constexpr
ArrayElementPicker
(
Arr
&
array
)
:
mArray
{
array
}
{
constexpr
index_t
imax
=
reduce_on_sequence
(
Picks
{},
math
::
maxer
<
index_t
>
{},
Number
<
0
>
{});
static_assert
(
imax
<
Arr
::
Size
(),
"wrong! exceeding # array element"
);
}
__host__
__device__
static
constexpr
auto
Size
()
{
return
Picks
::
Size
();
}
template
<
index_t
I
>
__host__
__device__
constexpr
const
data_type
&
At
(
Number
<
I
>
)
const
{
static_assert
(
I
<
Size
(),
"wrong!"
);
constexpr
auto
IP
=
Picks
{}[
I
];
return
mArray
[
IP
];
}
template
<
index_t
I
>
__host__
__device__
constexpr
data_type
&
At
(
Number
<
I
>
)
{
static_assert
(
I
<
Size
(),
"wrong!"
);
constexpr
auto
IP
=
Picks
{}[
I
];
return
mArray
(
IP
);
}
template
<
typename
I
>
__host__
__device__
constexpr
const
data_type
&
operator
[](
I
i
)
const
{
return
At
(
i
);
}
template
<
typename
I
>
__host__
__device__
constexpr
data_type
&
operator
()(
I
i
)
{
return
At
(
i
);
}
template
<
typename
T
>
__host__
__device__
constexpr
type
&
operator
=
(
const
T
&
a
)
{
static_for
<
0
,
Size
(),
1
>
{}([
&
](
auto
i
)
{
operator
()(
i
)
=
a
[
i
];
});
return
*
this
;
}
Arr
&
mArray
;
};
template
<
typename
Arr
,
typename
Picks
>
__host__
__device__
constexpr
auto
pick_array_element
(
Arr
&
a
,
Picks
)
{
return
ArrayElementPicker
<
Arr
,
Picks
>
(
a
);
}
template
<
typename
T
>
__host__
__device__
constexpr
auto
to_array
(
const
T
&
x
)
{
Array
<
typename
T
::
data_type
,
T
::
Size
()
>
y
;
static_for
<
0
,
T
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
y
.
At
(
i
)
=
x
.
At
(
i
);
});
return
y
;
}
// TODO: remove this
template
<
index_t
...
Is
>
template
<
index_t
...
Is
>
__host__
__device__
constexpr
auto
sequence2array
(
Sequence
<
Is
...
>
)
__host__
__device__
constexpr
auto
sequence2array
(
Sequence
<
Is
...
>
)
{
{
return
Array
<
index_t
,
sizeof
...(
Is
)
>
{
Is
...};
return
Array
<
index_t
,
sizeof
...(
Is
)
>
{
Is
...};
}
}
template
<
class
TData
,
index_t
NSize
>
template
<
typename
TData
,
index_t
NSize
>
__host__
__device__
constexpr
auto
make_zero_array
()
__host__
__device__
constexpr
auto
make_zero_array
()
{
{
constexpr
auto
zero_sequence
=
typename
uniform_sequence_gen
<
NSize
,
0
>::
type
{};
constexpr
auto
zero_sequence
=
typename
uniform_sequence_gen
<
NSize
,
0
>::
type
{};
...
@@ -93,7 +189,7 @@ __host__ __device__ constexpr auto make_zero_array()
...
@@ -93,7 +189,7 @@ __host__ __device__ constexpr auto make_zero_array()
return
zero_array
;
return
zero_array
;
}
}
template
<
class
TData
,
index_t
NSize
,
index_t
...
IRs
>
template
<
typename
TData
,
index_t
NSize
,
index_t
...
IRs
>
__host__
__device__
constexpr
auto
reorder_array_given_new2old
(
const
Array
<
TData
,
NSize
>&
old_array
,
__host__
__device__
constexpr
auto
reorder_array_given_new2old
(
const
Array
<
TData
,
NSize
>&
old_array
,
Sequence
<
IRs
...
>
/*new2old*/
)
Sequence
<
IRs
...
>
/*new2old*/
)
{
{
...
@@ -104,7 +200,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData
...
@@ -104,7 +200,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData
return
Array
<
TData
,
NSize
>
{
old_array
[
IRs
]...};
return
Array
<
TData
,
NSize
>
{
old_array
[
IRs
]...};
}
}
template
<
class
TData
,
index_t
NSize
,
class
MapOld2New
>
template
<
typename
TData
,
index_t
NSize
,
typename
MapOld2New
>
struct
lambda_reorder_array_given_old2new
struct
lambda_reorder_array_given_old2new
{
{
const
Array
<
TData
,
NSize
>&
old_array
;
const
Array
<
TData
,
NSize
>&
old_array
;
...
@@ -121,13 +217,13 @@ struct lambda_reorder_array_given_old2new
...
@@ -121,13 +217,13 @@ struct lambda_reorder_array_given_old2new
{
{
TData
old_data
=
old_array
[
IOldDim
];
TData
old_data
=
old_array
[
IOldDim
];
constexpr
index_t
INewDim
=
MapOld2New
::
Ge
t
(
Number
<
IOldDim
>
{});
constexpr
index_t
INewDim
=
MapOld2New
::
A
t
(
Number
<
IOldDim
>
{});
new_array
.
Set
(
Number
<
INewDim
>
{}
,
old_data
)
;
new_array
(
Number
<
INewDim
>
{}
)
=
old_data
;
}
}
};
};
template
<
class
TData
,
index_t
NSize
,
index_t
...
IRs
>
template
<
typename
TData
,
index_t
NSize
,
index_t
...
IRs
>
__host__
__device__
constexpr
auto
reorder_array_given_old2new
(
const
Array
<
TData
,
NSize
>&
old_array
,
__host__
__device__
constexpr
auto
reorder_array_given_old2new
(
const
Array
<
TData
,
NSize
>&
old_array
,
Sequence
<
IRs
...
>
/*old2new*/
)
Sequence
<
IRs
...
>
/*old2new*/
)
{
{
...
@@ -143,7 +239,7 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData
...
@@ -143,7 +239,7 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData
return
new_array
;
return
new_array
;
}
}
template
<
class
TData
,
index_t
NSize
,
class
ExtractSeq
>
template
<
typename
TData
,
index_t
NSize
,
typename
ExtractSeq
>
__host__
__device__
constexpr
auto
extract_array
(
const
Array
<
TData
,
NSize
>&
old_array
,
ExtractSeq
)
__host__
__device__
constexpr
auto
extract_array
(
const
Array
<
TData
,
NSize
>&
old_array
,
ExtractSeq
)
{
{
Array
<
TData
,
ExtractSeq
::
GetSize
()
>
new_array
;
Array
<
TData
,
ExtractSeq
::
GetSize
()
>
new_array
;
...
@@ -152,12 +248,13 @@ __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_
...
@@ -152,12 +248,13 @@ __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_
static_assert
(
new_size
<=
NSize
,
"wrong! too many extract"
);
static_assert
(
new_size
<=
NSize
,
"wrong! too many extract"
);
static_for
<
0
,
new_size
,
1
>
{}([
&
](
auto
I
)
{
new_array
(
I
)
=
old_array
[
ExtractSeq
::
Ge
t
(
I
)];
});
static_for
<
0
,
new_size
,
1
>
{}([
&
](
auto
I
)
{
new_array
(
I
)
=
old_array
[
ExtractSeq
::
A
t
(
I
)];
});
return
new_array
;
return
new_array
;
}
}
template
<
class
F
,
class
X
,
class
Y
,
class
Z
>
// emulate constepxr lambda for array math
// emulate constepxr lambda for array
template
<
typename
F
,
typename
X
,
typename
Y
,
typename
Z
>
struct
lambda_array_math
struct
lambda_array_math
{
{
const
F
&
f
;
const
F
&
f
;
...
@@ -174,13 +271,12 @@ struct lambda_array_math
...
@@ -174,13 +271,12 @@ struct lambda_array_math
__host__
__device__
constexpr
void
operator
()(
Number
<
IDim_
>
)
const
__host__
__device__
constexpr
void
operator
()(
Number
<
IDim_
>
)
const
{
{
constexpr
auto
IDim
=
Number
<
IDim_
>
{};
constexpr
auto
IDim
=
Number
<
IDim_
>
{};
z
(
IDim
)
=
f
(
x
[
IDim
],
y
[
IDim
]);
z
.
Set
(
IDim
,
f
(
x
[
IDim
],
y
[
IDim
]));
}
}
};
};
// Array = Array + Array
// Array = Array + Array
template
<
class
TData
,
index_t
NSize
>
template
<
typename
TData
,
index_t
NSize
>
__host__
__device__
constexpr
auto
operator
+
(
Array
<
TData
,
NSize
>
a
,
Array
<
TData
,
NSize
>
b
)
__host__
__device__
constexpr
auto
operator
+
(
Array
<
TData
,
NSize
>
a
,
Array
<
TData
,
NSize
>
b
)
{
{
Array
<
TData
,
NSize
>
result
;
Array
<
TData
,
NSize
>
result
;
...
@@ -195,7 +291,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData,
...
@@ -195,7 +291,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData,
}
}
// Array = Array - Array
// Array = Array - Array
template
<
class
TData
,
index_t
NSize
>
template
<
typename
TData
,
index_t
NSize
>
__host__
__device__
constexpr
auto
operator
-
(
Array
<
TData
,
NSize
>
a
,
Array
<
TData
,
NSize
>
b
)
__host__
__device__
constexpr
auto
operator
-
(
Array
<
TData
,
NSize
>
a
,
Array
<
TData
,
NSize
>
b
)
{
{
Array
<
TData
,
NSize
>
result
;
Array
<
TData
,
NSize
>
result
;
...
@@ -210,7 +306,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData,
...
@@ -210,7 +306,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData,
}
}
// Array += Array
// Array += Array
template
<
class
TData
,
index_t
NSize
>
template
<
typename
TData
,
index_t
NSize
>
__host__
__device__
constexpr
auto
operator
+=
(
Array
<
TData
,
NSize
>&
a
,
Array
<
TData
,
NSize
>
b
)
__host__
__device__
constexpr
auto
operator
+=
(
Array
<
TData
,
NSize
>&
a
,
Array
<
TData
,
NSize
>
b
)
{
{
a
=
a
+
b
;
a
=
a
+
b
;
...
@@ -218,14 +314,14 @@ __host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TDat
...
@@ -218,14 +314,14 @@ __host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TDat
}
}
// Array -= Array
// Array -= Array
template
<
class
TData
,
index_t
NSize
>
template
<
typename
TData
,
index_t
NSize
>
__host__
__device__
constexpr
auto
operator
-=
(
Array
<
TData
,
NSize
>&
a
,
Array
<
TData
,
NSize
>
b
)
__host__
__device__
constexpr
auto
operator
-=
(
Array
<
TData
,
NSize
>&
a
,
Array
<
TData
,
NSize
>
b
)
{
{
a
=
a
-
b
;
a
=
a
-
b
;
return
a
;
return
a
;
}
}
// Array = Array + Sequence
// Array = Array + Sequence
template
<
class
TData
,
index_t
NSize
,
index_t
...
Is
>
template
<
typename
TData
,
index_t
NSize
,
index_t
...
Is
>
__host__
__device__
constexpr
auto
operator
+
(
Array
<
TData
,
NSize
>
a
,
Sequence
<
Is
...
>
b
)
__host__
__device__
constexpr
auto
operator
+
(
Array
<
TData
,
NSize
>
a
,
Sequence
<
Is
...
>
b
)
{
{
static_assert
(
sizeof
...(
Is
)
==
NSize
,
"wrong! size not the same"
);
static_assert
(
sizeof
...(
Is
)
==
NSize
,
"wrong! size not the same"
);
...
@@ -242,7 +338,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is.
...
@@ -242,7 +338,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is.
}
}
// Array = Array - Sequence
// Array = Array - Sequence
template
<
class
TData
,
index_t
NSize
,
index_t
...
Is
>
template
<
typename
TData
,
index_t
NSize
,
index_t
...
Is
>
__host__
__device__
constexpr
auto
operator
-
(
Array
<
TData
,
NSize
>
a
,
Sequence
<
Is
...
>
b
)
__host__
__device__
constexpr
auto
operator
-
(
Array
<
TData
,
NSize
>
a
,
Sequence
<
Is
...
>
b
)
{
{
static_assert
(
sizeof
...(
Is
)
==
NSize
,
"wrong! size not the same"
);
static_assert
(
sizeof
...(
Is
)
==
NSize
,
"wrong! size not the same"
);
...
@@ -259,7 +355,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is.
...
@@ -259,7 +355,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is.
}
}
// Array = Array * Sequence
// Array = Array * Sequence
template
<
class
TData
,
index_t
NSize
,
index_t
...
Is
>
template
<
typename
TData
,
index_t
NSize
,
index_t
...
Is
>
__host__
__device__
constexpr
auto
operator
*
(
Array
<
TData
,
NSize
>
a
,
Sequence
<
Is
...
>
b
)
__host__
__device__
constexpr
auto
operator
*
(
Array
<
TData
,
NSize
>
a
,
Sequence
<
Is
...
>
b
)
{
{
static_assert
(
sizeof
...(
Is
)
==
NSize
,
"wrong! size not the same"
);
static_assert
(
sizeof
...(
Is
)
==
NSize
,
"wrong! size not the same"
);
...
@@ -276,7 +372,7 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is.
...
@@ -276,7 +372,7 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is.
}
}
// Array = Sequence - Array
// Array = Sequence - Array
template
<
class
TData
,
index_t
NSize
,
index_t
...
Is
>
template
<
typename
TData
,
index_t
NSize
,
index_t
...
Is
>
__host__
__device__
constexpr
auto
operator
-
(
Sequence
<
Is
...
>
a
,
Array
<
TData
,
NSize
>
b
)
__host__
__device__
constexpr
auto
operator
-
(
Sequence
<
Is
...
>
a
,
Array
<
TData
,
NSize
>
b
)
{
{
static_assert
(
sizeof
...(
Is
)
==
NSize
,
"wrong! size not the same"
);
static_assert
(
sizeof
...(
Is
)
==
NSize
,
"wrong! size not the same"
);
...
@@ -292,7 +388,21 @@ __host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSi
...
@@ -292,7 +388,21 @@ __host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSi
return
result
;
return
result
;
}
}
template
<
class
TData
,
index_t
NSize
,
class
Reduce
>
// Array = Array * TData
template
<
typename
TData
,
index_t
NSize
>
__host__
__device__
constexpr
auto
operator
*
(
TData
v
,
Array
<
TData
,
NSize
>
a
)
{
Array
<
TData
,
NSize
>
result
;
for
(
index_t
i
=
0
;
i
<
NSize
;
++
i
)
{
result
(
i
)
=
a
[
i
]
*
v
;
}
return
result
;
}
template
<
typename
TData
,
index_t
NSize
,
typename
Reduce
>
__host__
__device__
constexpr
TData
__host__
__device__
constexpr
TData
accumulate_on_array
(
const
Array
<
TData
,
NSize
>&
a
,
Reduce
f
,
TData
init
)
accumulate_on_array
(
const
Array
<
TData
,
NSize
>&
a
,
Reduce
f
,
TData
init
)
{
{
...
@@ -305,89 +415,5 @@ accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
...
@@ -305,89 +415,5 @@ accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
return
result
;
return
result
;
}
}
template
<
class
T
,
index_t
NSize
>
__host__
__device__
void
print_Array
(
const
char
*
s
,
Array
<
T
,
NSize
>
a
)
{
constexpr
index_t
nsize
=
a
.
GetSize
();
static_assert
(
nsize
>
0
&&
nsize
<=
10
,
"wrong!"
);
static_if
<
nsize
==
1
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u}
\n
"
,
s
,
nsize
,
a
[
0
]);
});
static_if
<
nsize
==
2
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
]);
});
static_if
<
nsize
==
3
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
]);
});
static_if
<
nsize
==
4
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
]);
});
static_if
<
nsize
==
5
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
]);
});
static_if
<
nsize
==
6
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
]);
});
static_if
<
nsize
==
7
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
]);
});
static_if
<
nsize
==
8
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
],
a
[
7
]);
});
static_if
<
nsize
==
9
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
],
a
[
7
],
a
[
8
]);
});
static_if
<
nsize
==
10
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
],
a
[
7
],
a
[
8
],
a
[
9
]);
});
}
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/utility/array_helper.hpp
0 → 100644
View file @
52423948
#ifndef CK_ARRAY_HELPER_HPP
#define CK_ARRAY_HELPER_HPP
#include "array.hpp"
namespace
ck
{
template
<
index_t
NSize
>
__host__
__device__
void
print_array
(
const
char
*
s
,
Array
<
uint32_t
,
NSize
>
a
)
{
constexpr
index_t
nsize
=
a
.
GetSize
();
static_assert
(
nsize
>
0
&&
nsize
<=
10
,
"wrong!"
);
static_if
<
nsize
==
1
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u}
\n
"
,
s
,
nsize
,
a
[
0
]);
});
static_if
<
nsize
==
2
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
]);
});
static_if
<
nsize
==
3
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
]);
});
static_if
<
nsize
==
4
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
]);
});
static_if
<
nsize
==
5
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
]);
});
static_if
<
nsize
==
6
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
]);
});
static_if
<
nsize
==
7
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
]);
});
static_if
<
nsize
==
8
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
],
a
[
7
]);
});
static_if
<
nsize
==
9
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
],
a
[
7
],
a
[
8
]);
});
static_if
<
nsize
==
10
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
],
a
[
7
],
a
[
8
],
a
[
9
]);
});
}
template
<
index_t
NSize
>
__host__
__device__
void
print_array
(
const
char
*
s
,
Array
<
int32_t
,
NSize
>
a
)
{
constexpr
index_t
nsize
=
a
.
GetSize
();
static_assert
(
nsize
>
0
&&
nsize
<=
10
,
"wrong!"
);
static_if
<
nsize
==
1
>
{}([
&
](
auto
)
{
printf
(
"%s size %d, {%d}
\n
"
,
s
,
nsize
,
a
[
0
]);
});
static_if
<
nsize
==
2
>
{}([
&
](
auto
)
{
printf
(
"%s size %d, {%d %d}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
]);
});
static_if
<
nsize
==
3
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %d, {%d %d %d}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
]);
});
static_if
<
nsize
==
4
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %d, {%d %d %d %d}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
]);
});
static_if
<
nsize
==
5
>
{}([
&
](
auto
)
{
printf
(
"%s size %d, {%d %d %d %d %d}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
]);
});
static_if
<
nsize
==
6
>
{}([
&
](
auto
)
{
printf
(
"%s size %d, {%d %d %d %d %d %d}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
]);
});
static_if
<
nsize
==
7
>
{}([
&
](
auto
)
{
printf
(
"%s size %d, {%d %d %d %d %d %d %d}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
]);
});
static_if
<
nsize
==
8
>
{}([
&
](
auto
)
{
printf
(
"%s size %d, {%d %d %d %d %d %d %d %d}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
],
a
[
7
]);
});
static_if
<
nsize
==
9
>
{}([
&
](
auto
)
{
printf
(
"%s size %d, {%d %d %d %d %d %d %d %d %d}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
],
a
[
7
],
a
[
8
]);
});
static_if
<
nsize
==
10
>
{}([
&
](
auto
)
{
printf
(
"%s size %d, {%d %d %d %d %d %d %d %d %d %d}
\n
"
,
s
,
nsize
,
a
[
0
],
a
[
1
],
a
[
2
],
a
[
3
],
a
[
4
],
a
[
5
],
a
[
6
],
a
[
7
],
a
[
8
],
a
[
9
]);
});
}
}
// namespace ck
#endif
composable_kernel/include/utility/common_header.hpp
View file @
52423948
...
@@ -4,16 +4,26 @@
...
@@ -4,16 +4,26 @@
#include "config.hpp"
#include "config.hpp"
#include "utility.hpp"
#include "utility.hpp"
#include "integral_constant.hpp"
#include "integral_constant.hpp"
#include "number.hpp"
#include "type.hpp"
#include "tuple.hpp"
#include "math.hpp"
#include "math.hpp"
#include "vector_type.hpp"
#include "vector_type.hpp"
#include "Sequence.hpp"
#include "sequence.hpp"
#include "Array.hpp"
#include "sequence_helper.hpp"
#include "array.hpp"
#include "array_helper.hpp"
#include "functional.hpp"
#include "functional.hpp"
#include "functional2.hpp"
#include "functional2.hpp"
#include "functional3.hpp"
#include "functional3.hpp"
#include "functional4.hpp"
#if CK_USE_AMD_INLINE_ASM
#if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp"
#include "amd_inline_asm.hpp"
#endif
#endif
#if CK_USE_AMD_INTRINSIC
#include "amd_intrinsic.hpp"
#endif
#endif
#endif
composable_kernel/include/utility/config_amd.hpp.in
View file @
52423948
...
@@ -4,29 +4,47 @@
...
@@ -4,29 +4,47 @@
#include "hip/hip_runtime.h"
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "hip/hip_fp16.h"
#define CK_UNSIGNED_INDEX_TYPE 0
#define CK_DEVICE_BACKEND_AMD 1
#define CK_DEVICE_BACKEND_AMD 1
#define CK_USE_AMD_INTRINSIC 1
#define CK_USE_AMD_INLINE_ASM 1
#define CK_USE_AMD_INLINE_ASM 1
#define CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE 1
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
namespace ck {
namespace ck {
enum address_space_t
{
generic = 0,
global = 3
};
#if CK_UNSIGNED_INDEX_TYPE
using index_t = uint32_t;
#else
using index_t = int32_t;
#endif
// For some reason, HIP compiler need this definition to generate optimal load and store
// For some reason, HIP compiler need this definition to generate optimal load and store
// instruction
// instruction
typedef float float2_t __attribute__((ext_vector_type(2)));
typedef float float2_t __attribute__((ext_vector_type(2)));
typedef float float4_t __attribute__((ext_vector_type(4)));
typedef float float4_t __attribute__((ext_vector_type(4)));
using index_t = uint32_t
;
typedef int32_t int32x4_t __attribute__((ext_vector_type(4)))
;
template <class T>
// data type conversion
__device__ void fused_multiply_accumulate(T& d, const T& s0, const T& s1)
template <typename T>
struct type_convert
{
{
d += s0 * s1;
template <typename X>
}
__device__ T operator()(const X& x) const
{
return static_cast<T>(x);
}
};
} // namespace ck
} // namespace ck
...
...
composable_kernel/include/utility/config_nvidia.hpp.in
View file @
52423948
...
@@ -6,17 +6,30 @@
...
@@ -6,17 +6,30 @@
#include "nvToolsExt.h"
#include "nvToolsExt.h"
#include "helper_cuda.h"
#include "helper_cuda.h"
#define CK_UNSIGNED_INDEX_TYPE 0
#define CK_DEVICE_BACKEND_NVIDIA 1
#define CK_DEVICE_BACKEND_NVIDIA 1
#define CK_USE_AMD_INTRINSIC 0
#define CK_USE_AMD_INLINE_ASM 0
#define CK_USE_AMD_INLINE_ASM 0
#define CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
namespace ck {
namespace ck {
enum address_space_t
{
generic = 0,
global = generic
};
#if CK_UNSIGNED_INDEX_TYPE
using index_t = uint32_t;
#else
using index_t = int32_t;
#endif
// For some reason, CUDA need this definition, otherwise
// For some reason, CUDA need this definition, otherwise
// compiler won't generate optimal load and store instruction, and
// compiler won't generate optimal load and store instruction, and
// kernel would produce wrong result, indicating the compiler fail to generate correct
// kernel would produce wrong result, indicating the compiler fail to generate correct
...
@@ -24,7 +37,16 @@ namespace ck {
...
@@ -24,7 +37,16 @@ namespace ck {
using float2_t = float2;
using float2_t = float2;
using float4_t = float4;
using float4_t = float4;
using index_t = uint32_t;
// data type conversion
template <typename T>
struct type_convert
{
template <typename X>
__device__ T operator()(const X& x) const
{
return static_cast<T>(x);
}
};
template <class T>
template <class T>
__device__ void fused_multiply_accumulate(T& d, const T& s0, const T& s1)
__device__ void fused_multiply_accumulate(T& d, const T& s0, const T& s1)
...
...
composable_kernel/include/utility/functional.hpp
View file @
52423948
...
@@ -2,10 +2,12 @@
...
@@ -2,10 +2,12 @@
#define CK_FUNCTIONAL_HPP
#define CK_FUNCTIONAL_HPP
#include "integral_constant.hpp"
#include "integral_constant.hpp"
#include "Sequence.hpp"
#include "sequence.hpp"
#include "type.hpp"
namespace
ck
{
namespace
ck
{
// TODO: right? wrong?
struct
forwarder
struct
forwarder
{
{
template
<
typename
T
>
template
<
typename
T
>
...
@@ -17,12 +19,30 @@ struct forwarder
...
@@ -17,12 +19,30 @@ struct forwarder
struct
swallow
struct
swallow
{
{
template
<
class
...
Ts
>
template
<
typename
...
Ts
>
__host__
__device__
constexpr
swallow
(
Ts
&&
...)
__host__
__device__
constexpr
swallow
(
Ts
&&
...)
{
{
}
}
};
};
template
<
typename
T
>
struct
logical_and
{
constexpr
bool
operator
()(
const
T
&
x
,
const
T
&
y
)
const
{
return
x
&&
y
;
}
};
template
<
typename
T
>
struct
logical_or
{
constexpr
bool
operator
()(
const
T
&
x
,
const
T
&
y
)
const
{
return
x
||
y
;
}
};
template
<
typename
T
>
struct
logical_not
{
constexpr
bool
operator
()(
const
T
&
x
)
const
{
return
!
x
;
}
};
// Emulate if constexpr
// Emulate if constexpr
template
<
bool
>
template
<
bool
>
struct
static_if
;
struct
static_if
;
...
@@ -32,7 +52,7 @@ struct static_if<true>
...
@@ -32,7 +52,7 @@ struct static_if<true>
{
{
using
Type
=
static_if
<
true
>
;
using
Type
=
static_if
<
true
>
;
template
<
class
F
>
template
<
typename
F
>
__host__
__device__
constexpr
auto
operator
()(
F
f
)
const
__host__
__device__
constexpr
auto
operator
()(
F
f
)
const
{
{
// This is a trick for compiler:
// This is a trick for compiler:
...
@@ -43,7 +63,7 @@ struct static_if<true>
...
@@ -43,7 +63,7 @@ struct static_if<true>
return
Type
{};
return
Type
{};
}
}
template
<
class
F
>
template
<
typename
F
>
__host__
__device__
static
constexpr
auto
Else
(
F
)
__host__
__device__
static
constexpr
auto
Else
(
F
)
{
{
return
Type
{};
return
Type
{};
...
@@ -55,13 +75,13 @@ struct static_if<false>
...
@@ -55,13 +75,13 @@ struct static_if<false>
{
{
using
Type
=
static_if
<
false
>
;
using
Type
=
static_if
<
false
>
;
template
<
class
F
>
template
<
typename
F
>
__host__
__device__
constexpr
auto
operator
()(
F
)
const
__host__
__device__
constexpr
auto
operator
()(
F
)
const
{
{
return
Type
{};
return
Type
{};
}
}
template
<
class
F
>
template
<
typename
F
>
__host__
__device__
static
constexpr
auto
Else
(
F
f
)
__host__
__device__
static
constexpr
auto
Else
(
F
f
)
{
{
// This is a trick for compiler:
// This is a trick for compiler:
...
@@ -73,5 +93,23 @@ struct static_if<false>
...
@@ -73,5 +93,23 @@ struct static_if<false>
}
}
};
};
template
<
bool
predicate
,
class
X
,
class
Y
>
struct
conditional
;
template
<
class
X
,
class
Y
>
struct
conditional
<
true
,
X
,
Y
>
{
using
type
=
X
;
};
template
<
class
X
,
class
Y
>
struct
conditional
<
false
,
X
,
Y
>
{
using
type
=
Y
;
};
template
<
bool
predicate
,
class
X
,
class
Y
>
using
conditional_t
=
typename
conditional
<
predicate
,
X
,
Y
>::
type
;
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/utility/functional2.hpp
View file @
52423948
...
@@ -2,10 +2,12 @@
...
@@ -2,10 +2,12 @@
#define CK_FUNCTIONAL2_HPP
#define CK_FUNCTIONAL2_HPP
#include "functional.hpp"
#include "functional.hpp"
#include "
S
equence.hpp"
#include "
s
equence.hpp"
namespace
ck
{
namespace
ck
{
namespace
detail
{
template
<
class
>
template
<
class
>
struct
static_for_impl
;
struct
static_for_impl
;
...
@@ -19,6 +21,8 @@ struct static_for_impl<Sequence<Is...>>
...
@@ -19,6 +21,8 @@ struct static_for_impl<Sequence<Is...>>
}
}
};
};
}
// namespace detail
// F signature: F(Number<Iter>)
// F signature: F(Number<Iter>)
template
<
index_t
NBegin
,
index_t
NEnd
,
index_t
Increment
>
template
<
index_t
NBegin
,
index_t
NEnd
,
index_t
Increment
>
struct
static_for
struct
static_for
...
@@ -33,38 +37,10 @@ struct static_for
...
@@ -33,38 +37,10 @@ struct static_for
template
<
class
F
>
template
<
class
F
>
__host__
__device__
constexpr
void
operator
()(
F
f
)
const
__host__
__device__
constexpr
void
operator
()(
F
f
)
const
{
{
static_for_impl
<
typename
arithmetic_sequence_gen
<
NBegin
,
NEnd
,
Increment
>::
type
>
{}(
f
);
detail
::
static_for_impl
<
typename
arithmetic_sequence_gen
<
NBegin
,
NEnd
,
Increment
>::
type
>
{}(
f
);
}
}
};
};
template
<
class
Seq
,
class
Reduce
>
struct
lambda_accumulate_on_sequence
{
const
Reduce
&
f
;
index_t
&
result
;
__host__
__device__
constexpr
lambda_accumulate_on_sequence
(
const
Reduce
&
f_
,
index_t
&
result_
)
:
f
(
f_
),
result
(
result_
)
{
}
template
<
class
IDim
>
__host__
__device__
constexpr
index_t
operator
()(
IDim
)
const
{
return
result
=
f
(
result
,
Seq
::
Get
(
IDim
{}));
}
};
template
<
class
Seq
,
class
Reduce
,
index_t
Init
>
__host__
__device__
constexpr
index_t
accumulate_on_sequence
(
Seq
,
Reduce
f
,
Number
<
Init
>
/*initial_value*/
)
{
index_t
result
=
Init
;
static_for
<
0
,
Seq
::
mSize
,
1
>
{}(
lambda_accumulate_on_sequence
<
Seq
,
Reduce
>
(
f
,
result
));
return
result
;
}
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/utility/functional3.hpp
View file @
52423948
...
@@ -3,25 +3,12 @@
...
@@ -3,25 +3,12 @@
#include "functional.hpp"
#include "functional.hpp"
#include "functional2.hpp"
#include "functional2.hpp"
#include "
S
equence.hpp"
#include "
s
equence.hpp"
#include "
A
rray.hpp"
#include "
a
rray.hpp"
namespace
ck
{
namespace
ck
{
template
<
class
>
namespace
detail
{
struct
is_static
:
integral_constant
<
bool
,
false
>
{
};
template
<
class
T
,
T
X
>
struct
is_static
<
integral_constant
<
T
,
X
>>
:
integral_constant
<
bool
,
true
>
{
};
template
<
index_t
...
Is
>
struct
is_static
<
Sequence
<
Is
...
>>
:
integral_constant
<
bool
,
true
>
{
};
// RemainLengths: Sequence<...>
// RemainLengths: Sequence<...>
// Orders: Sequence<...>
// Orders: Sequence<...>
...
@@ -58,29 +45,6 @@ struct static_ford_impl<Sequence<>, Orders>
...
@@ -58,29 +45,6 @@ struct static_ford_impl<Sequence<>, Orders>
}
}
};
};
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which static_ford will loop over each
// dimension
template
<
class
Lengths
,
class
Orders
=
typename
arithmetic_sequence_gen
<
0
,
Lengths
::
GetSize
(),
1
>
::
type
>
struct
static_ford
{
__host__
__device__
constexpr
static_ford
()
{
static_assert
(
Lengths
::
GetSize
()
>
0
,
"wrong! Lengths is empty"
);
static_assert
(
Lengths
::
GetSize
()
==
Orders
::
GetSize
(),
"wrong! inconsistent size"
);
}
// F signature: F(Sequence<...> multi_id)
// multi_id is the unordered multi-index
template
<
class
F
>
__host__
__device__
constexpr
void
operator
()(
F
f
)
const
{
constexpr
auto
ordered_lengths
=
Lengths
::
ReorderGivenNew2Old
(
Orders
{});
static_ford_impl
<
decltype
(
ordered_lengths
),
Orders
>
{}(
f
,
Sequence
<>
{});
}
};
// RemainLengths: Sequence<...>
// RemainLengths: Sequence<...>
// Orders: Sequence<...>
// Orders: Sequence<...>
template
<
class
RemainLengths
,
class
Orders
>
template
<
class
RemainLengths
,
class
Orders
>
...
@@ -117,6 +81,31 @@ struct ford_impl<Sequence<>, Orders>
...
@@ -117,6 +81,31 @@ struct ford_impl<Sequence<>, Orders>
}
}
};
};
}
// namespace detail
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which static_ford will loop over each
// dimension
template
<
class
Lengths
,
class
Orders
=
typename
arithmetic_sequence_gen
<
0
,
Lengths
::
GetSize
(),
1
>
::
type
>
struct
static_ford
{
__host__
__device__
constexpr
static_ford
()
{
static_assert
(
Lengths
::
GetSize
()
>
0
,
"wrong! Lengths is empty"
);
static_assert
(
Lengths
::
GetSize
()
==
Orders
::
GetSize
(),
"wrong! inconsistent size"
);
}
// F signature: F(Sequence<...> multi_id)
// multi_id is the unordered multi-index
template
<
class
F
>
__host__
__device__
constexpr
void
operator
()(
F
f
)
const
{
constexpr
auto
ordered_lengths
=
Lengths
::
ReorderGivenNew2Old
(
Orders
{});
detail
::
static_ford_impl
<
decltype
(
ordered_lengths
),
Orders
>
{}(
f
,
Sequence
<>
{});
}
};
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which ford will loop over each
// Orders is Sequence<...>, it is the order of dimension in which ford will loop over each
// dimension
// dimension
...
@@ -139,7 +128,8 @@ struct ford
...
@@ -139,7 +128,8 @@ struct ford
for
(
index_t
i
=
0
;
i
<
ordered_lengths
.
Front
();
++
i
)
for
(
index_t
i
=
0
;
i
<
ordered_lengths
.
Front
();
++
i
)
{
{
ford_impl
<
decltype
(
ordered_lengths
.
PopFront
()),
Orders
>
{}(
f
,
Array
<
index_t
,
1
>
{
i
});
detail
::
ford_impl
<
decltype
(
ordered_lengths
.
PopFront
()),
Orders
>
{}(
f
,
Array
<
index_t
,
1
>
{
i
});
}
}
}
}
};
};
...
...
composable_kernel/include/utility/functional4.hpp
0 → 100644
View file @
52423948
#ifndef CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP
#include "sequence.hpp"
#include "tuple.hpp"
#include "array.hpp"
namespace
ck
{
namespace
detail
{
template
<
typename
Indices
>
struct
unpack_impl
;
template
<
index_t
...
Is
>
struct
unpack_impl
<
Sequence
<
Is
...
>>
{
template
<
typename
F
,
typename
X
>
__host__
__device__
constexpr
auto
operator
()(
F
f
,
const
X
&
x
)
const
{
return
f
(
x
.
At
(
Number
<
Is
>
{})...);
}
};
}
// namespace detail
template
<
typename
F
,
typename
X
>
__host__
__device__
constexpr
auto
unpack
(
F
f
,
const
X
&
x
)
{
return
detail
::
unpack_impl
<
typename
arithmetic_sequence_gen
<
0
,
X
::
Size
(),
1
>::
type
>
{}(
f
,
x
);
}
}
// namespace ck
#endif
composable_kernel/include/utility/integral_constant.hpp
View file @
52423948
...
@@ -13,51 +13,5 @@ struct integral_constant
...
@@ -13,51 +13,5 @@ struct integral_constant
__host__
__device__
constexpr
value_type
operator
()()
const
noexcept
{
return
value
;
}
__host__
__device__
constexpr
value_type
operator
()()
const
noexcept
{
return
value
;
}
};
};
template
<
class
X
,
class
Y
>
struct
is_same
:
public
integral_constant
<
bool
,
false
>
{
};
template
<
class
X
>
struct
is_same
<
X
,
X
>
:
public
integral_constant
<
bool
,
true
>
{
};
template
<
index_t
N
>
using
Number
=
integral_constant
<
index_t
,
N
>
;
template
<
index_t
X
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
+
(
Number
<
X
>
,
Number
<
Y
>
)
{
return
Number
<
X
+
Y
>
{};
}
template
<
index_t
X
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
-
(
Number
<
X
>
,
Number
<
Y
>
)
{
static_assert
(
Y
<=
X
,
"wrong!"
);
return
Number
<
X
-
Y
>
{};
}
template
<
index_t
X
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
*
(
Number
<
X
>
,
Number
<
Y
>
)
{
return
Number
<
X
*
Y
>
{};
}
template
<
index_t
X
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
/
(
Number
<
X
>
,
Number
<
Y
>
)
{
static_assert
(
Y
>
0
,
"wrong!"
);
return
Number
<
X
/
Y
>
{};
}
template
<
index_t
X
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
%
(
Number
<
X
>
,
Number
<
Y
>
)
{
static_assert
(
Y
>
0
,
"wrong!"
);
return
Number
<
X
%
Y
>
{};
}
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/utility/math.hpp
View file @
52423948
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include "config.hpp"
#include "config.hpp"
#include "integral_constant.hpp"
#include "integral_constant.hpp"
#include "type.hpp"
namespace
ck
{
namespace
ck
{
namespace
math
{
namespace
math
{
...
@@ -31,6 +32,12 @@ struct multiplies
...
@@ -31,6 +32,12 @@ struct multiplies
__host__
__device__
constexpr
T
operator
()(
T
a
,
T
b
)
const
{
return
a
*
b
;
}
__host__
__device__
constexpr
T
operator
()(
T
a
,
T
b
)
const
{
return
a
*
b
;
}
};
};
template
<
class
T
>
struct
maxer
{
__host__
__device__
constexpr
T
operator
()(
T
a
,
T
b
)
const
{
return
a
>=
b
?
a
:
b
;
}
};
template
<
class
T
>
template
<
class
T
>
struct
integer_divide_ceiler
struct
integer_divide_ceiler
{
{
...
@@ -98,6 +105,18 @@ __host__ __device__ constexpr T lcm(T x, Ts... xs)
...
@@ -98,6 +105,18 @@ __host__ __device__ constexpr T lcm(T x, Ts... xs)
return
max
(
x
,
xs
...);
return
max
(
x
,
xs
...);
}
}
template
<
class
T
>
struct
equal
{
__host__
__device__
constexpr
bool
operator
()(
T
x
,
T
y
)
const
{
return
x
==
y
;
}
};
template
<
class
T
>
struct
less
{
__host__
__device__
constexpr
bool
operator
()(
T
x
,
T
y
)
const
{
return
x
<
y
;
}
};
}
// namespace math
}
// namespace math
}
// namspace ck
}
// namspace ck
...
...
composable_kernel/include/utility/number.hpp
0 → 100644
View file @
52423948
#ifndef CK_NUMBER_HPP
#define CK_NUMBER_HPP
#include "integral_constant.hpp"
namespace
ck
{
template
<
index_t
N
>
using
Number
=
integral_constant
<
index_t
,
N
>
;
template
<
index_t
X
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
+
(
Number
<
X
>
,
Number
<
Y
>
)
{
return
Number
<
X
+
Y
>
{};
}
template
<
index_t
X
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
-
(
Number
<
X
>
,
Number
<
Y
>
)
{
static_assert
(
Y
<=
X
,
"wrong!"
);
return
Number
<
X
-
Y
>
{};
}
template
<
index_t
X
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
*
(
Number
<
X
>
,
Number
<
Y
>
)
{
return
Number
<
X
*
Y
>
{};
}
template
<
index_t
X
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
/
(
Number
<
X
>
,
Number
<
Y
>
)
{
static_assert
(
Y
>
0
,
"wrong!"
);
return
Number
<
X
/
Y
>
{};
}
template
<
index_t
X
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
%
(
Number
<
X
>
,
Number
<
Y
>
)
{
static_assert
(
Y
>
0
,
"wrong!"
);
return
Number
<
X
%
Y
>
{};
}
}
// namespace ck
#endif
composable_kernel/include/utility/
S
equence.hpp
→
composable_kernel/include/utility/
s
equence.hpp
View file @
52423948
...
@@ -2,29 +2,34 @@
...
@@ -2,29 +2,34 @@
#define CK_SEQUENCE_HPP
#define CK_SEQUENCE_HPP
#include "integral_constant.hpp"
#include "integral_constant.hpp"
#include "type.hpp"
#include "functional.hpp"
#include "functional.hpp"
#include "math.hpp"
namespace
ck
{
namespace
ck
{
template
<
index_t
,
index_t
,
index_t
>
struct
static_for
;
template
<
index_t
...>
template
<
index_t
...>
struct
Sequence
;
struct
Sequence
;
template
<
class
Seq
,
index_t
I
>
template
<
typename
Seq
,
index_t
I
>
struct
sequence_split
;
struct
sequence_split
;
template
<
class
>
template
<
typename
>
struct
sequence_reverse
;
struct
sequence_reverse
;
template
<
class
>
template
<
typename
>
struct
sequence_map_inverse
;
struct
sequence_map_inverse
;
template
<
class
>
template
<
typename
>
struct
is_valid_sequence_map
;
struct
is_valid_sequence_map
;
template
<
index_t
I
,
index_t
...
Is
>
template
<
index_t
I
,
index_t
...
Is
>
__host__
__device__
constexpr
auto
sequence_pop_front
(
Sequence
<
I
,
Is
...
>
);
__host__
__device__
constexpr
auto
sequence_pop_front
(
Sequence
<
I
,
Is
...
>
);
template
<
class
Seq
>
template
<
typename
Seq
>
__host__
__device__
constexpr
auto
sequence_pop_back
(
Seq
);
__host__
__device__
constexpr
auto
sequence_pop_back
(
Seq
);
template
<
index_t
...
Is
>
template
<
index_t
...
Is
>
...
@@ -35,9 +40,11 @@ struct Sequence
...
@@ -35,9 +40,11 @@ struct Sequence
static
constexpr
index_t
mSize
=
sizeof
...(
Is
);
static
constexpr
index_t
mSize
=
sizeof
...(
Is
);
__host__
__device__
static
constexpr
auto
GetSize
()
{
return
Number
<
mSize
>
{};
}
__host__
__device__
static
constexpr
auto
Size
()
{
return
Number
<
mSize
>
{};
}
__host__
__device__
static
constexpr
auto
GetSize
()
{
return
Size
();
}
__host__
__device__
static
constexpr
index_t
GetImpl
(
index_t
I
)
__host__
__device__
static
constexpr
index_t
At
(
index_t
I
)
{
{
// the last dummy element is to prevent compiler complain about empty array, when mSize = 0
// the last dummy element is to prevent compiler complain about empty array, when mSize = 0
const
index_t
mData
[
mSize
+
1
]
=
{
Is
...,
0
};
const
index_t
mData
[
mSize
+
1
]
=
{
Is
...,
0
};
...
@@ -45,23 +52,24 @@ struct Sequence
...
@@ -45,23 +52,24 @@ struct Sequence
}
}
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
static
constexpr
auto
Ge
t
(
Number
<
I
>
)
__host__
__device__
static
constexpr
auto
A
t
(
Number
<
I
>
)
{
{
static_assert
(
I
<
mSize
,
"wrong! I too large"
);
static_assert
(
I
<
mSize
,
"wrong! I too large"
);
return
Number
<
GetImpl
(
Number
<
I
>
{}
)
>
{};
return
Number
<
At
(
I
)
>
{};
}
}
__host__
__device__
static
constexpr
auto
Get
(
index_t
I
)
{
return
GetImpl
(
I
);
}
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
auto
operator
[]
(
Number
<
I
>
)
const
__host__
__device__
static
constexpr
auto
Get
(
Number
<
I
>
)
{
{
return
Ge
t
(
Number
<
I
>
{});
return
A
t
(
Number
<
I
>
{});
}
}
// make sure I is constepxr if you want a constexpr return type
template
<
typename
I
>
__host__
__device__
constexpr
index_t
operator
[](
index_t
I
)
const
{
return
GetImpl
(
I
);
}
__host__
__device__
constexpr
auto
operator
[](
I
i
)
const
{
return
At
(
i
);
}
template
<
index_t
...
IRs
>
template
<
index_t
...
IRs
>
__host__
__device__
static
constexpr
auto
ReorderGivenNew2Old
(
Sequence
<
IRs
...
>
/*new2old*/
)
__host__
__device__
static
constexpr
auto
ReorderGivenNew2Old
(
Sequence
<
IRs
...
>
/*new2old*/
)
...
@@ -71,14 +79,14 @@ struct Sequence
...
@@ -71,14 +79,14 @@ struct Sequence
static_assert
(
is_valid_sequence_map
<
Sequence
<
IRs
...
>>::
value
,
"wrong! invalid reorder map"
);
static_assert
(
is_valid_sequence_map
<
Sequence
<
IRs
...
>>::
value
,
"wrong! invalid reorder map"
);
return
Sequence
<
Type
::
Ge
t
(
Number
<
IRs
>
{})...
>
{};
return
Sequence
<
Type
::
A
t
(
Number
<
IRs
>
{})...
>
{};
}
}
// MapOld2New is Sequence<...>
// MapOld2New is Sequence<...>
template
<
class
MapOld2New
>
template
<
typename
MapOld2New
>
__host__
__device__
static
constexpr
auto
ReorderGivenOld2New
(
MapOld2New
)
__host__
__device__
static
constexpr
auto
ReorderGivenOld2New
(
MapOld2New
)
{
{
static_assert
(
MapOld2New
::
Get
Size
()
==
Get
Size
(),
static_assert
(
MapOld2New
::
Size
()
==
Size
(),
"wrong! reorder map should have the same size as Sequence to be rerodered"
);
"wrong! reorder map should have the same size as Sequence to be rerodered"
);
static_assert
(
is_valid_sequence_map
<
MapOld2New
>::
value
,
"wrong! invalid reorder map"
);
static_assert
(
is_valid_sequence_map
<
MapOld2New
>::
value
,
"wrong! invalid reorder map"
);
...
@@ -94,13 +102,13 @@ struct Sequence
...
@@ -94,13 +102,13 @@ struct Sequence
__host__
__device__
static
constexpr
auto
Front
()
__host__
__device__
static
constexpr
auto
Front
()
{
{
static_assert
(
mSize
>
0
,
"wrong!"
);
static_assert
(
mSize
>
0
,
"wrong!"
);
return
Ge
t
(
Number
<
0
>
{});
return
A
t
(
Number
<
0
>
{});
}
}
__host__
__device__
static
constexpr
auto
Back
()
__host__
__device__
static
constexpr
auto
Back
()
{
{
static_assert
(
mSize
>
0
,
"wrong!"
);
static_assert
(
mSize
>
0
,
"wrong!"
);
return
Ge
t
(
Number
<
mSize
-
1
>
{});
return
A
t
(
Number
<
mSize
-
1
>
{});
}
}
__host__
__device__
static
constexpr
auto
PopFront
()
{
return
sequence_pop_front
(
Type
{});
}
__host__
__device__
static
constexpr
auto
PopFront
()
{
return
sequence_pop_front
(
Type
{});
}
...
@@ -134,28 +142,28 @@ struct Sequence
...
@@ -134,28 +142,28 @@ struct Sequence
template
<
index_t
...
Ns
>
template
<
index_t
...
Ns
>
__host__
__device__
static
constexpr
auto
Extract
(
Number
<
Ns
>
...)
__host__
__device__
static
constexpr
auto
Extract
(
Number
<
Ns
>
...)
{
{
return
Sequence
<
Type
::
Ge
t
(
Number
<
Ns
>
{})...
>
{};
return
Sequence
<
Type
::
A
t
(
Number
<
Ns
>
{})...
>
{};
}
}
template
<
index_t
...
Ns
>
template
<
index_t
...
Ns
>
__host__
__device__
static
constexpr
auto
Extract
(
Sequence
<
Ns
...
>
)
__host__
__device__
static
constexpr
auto
Extract
(
Sequence
<
Ns
...
>
)
{
{
return
Sequence
<
Type
::
Ge
t
(
Number
<
Ns
>
{})...
>
{};
return
Sequence
<
Type
::
A
t
(
Number
<
Ns
>
{})...
>
{};
}
}
template
<
index_t
I
,
index_t
X
>
template
<
index_t
I
,
index_t
X
>
__host__
__device__
static
constexpr
auto
Modify
(
Number
<
I
>
,
Number
<
X
>
)
__host__
__device__
static
constexpr
auto
Modify
(
Number
<
I
>
,
Number
<
X
>
)
{
{
static_assert
(
I
<
Get
Size
(),
"wrong!"
);
static_assert
(
I
<
Size
(),
"wrong!"
);
using
seq_split
=
sequence_split
<
Type
,
I
>
;
using
seq_split
=
sequence_split
<
Type
,
I
>
;
constexpr
auto
seq_left
=
typename
seq_split
::
SeqT
ype
0
{};
constexpr
auto
seq_left
=
typename
seq_split
::
left_t
ype
{};
constexpr
auto
seq_right
=
typename
seq_split
::
SeqT
ype
1
{}.
PopFront
();
constexpr
auto
seq_right
=
typename
seq_split
::
right_t
ype
{}.
PopFront
();
return
seq_left
.
PushBack
(
Number
<
X
>
{}).
PushBack
(
seq_right
);
return
seq_left
.
PushBack
(
Number
<
X
>
{}).
PushBack
(
seq_right
);
}
}
template
<
class
F
>
template
<
typename
F
>
__host__
__device__
static
constexpr
auto
Transform
(
F
f
)
__host__
__device__
static
constexpr
auto
Transform
(
F
f
)
{
{
return
Sequence
<
f
(
Is
)...
>
{};
return
Sequence
<
f
(
Is
)...
>
{};
...
@@ -163,8 +171,11 @@ struct Sequence
...
@@ -163,8 +171,11 @@ struct Sequence
};
};
// merge sequence
// merge sequence
template
<
class
,
class
>
template
<
typename
Seq
,
typename
...
Seqs
>
struct
sequence_merge
;
struct
sequence_merge
{
using
type
=
typename
sequence_merge
<
Seq
,
typename
sequence_merge
<
Seqs
...
>::
type
>::
type
;
};
template
<
index_t
...
Xs
,
index_t
...
Ys
>
template
<
index_t
...
Xs
,
index_t
...
Ys
>
struct
sequence_merge
<
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>>
struct
sequence_merge
<
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>>
...
@@ -172,35 +183,41 @@ struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
...
@@ -172,35 +183,41 @@ struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
using
type
=
Sequence
<
Xs
...,
Ys
...
>
;
using
type
=
Sequence
<
Xs
...,
Ys
...
>
;
};
};
// generate sequence
template
<
typename
Seq
>
template
<
index_t
IBegin
,
index_t
NRemain
,
class
F
>
struct
sequence_merge
<
Seq
>
struct
sequence_gen_impl
{
{
static
constexpr
index_t
NRemainLeft
=
NRemain
/
2
;
using
type
=
Seq
;
static
constexpr
index_t
NRemainRight
=
NRemain
-
NRemainLeft
;
static
constexpr
index_t
IMiddle
=
IBegin
+
NRemainLeft
;
using
type
=
typename
sequence_merge
<
typename
sequence_gen_impl
<
IBegin
,
NRemainLeft
,
F
>::
type
,
typename
sequence_gen_impl
<
IMiddle
,
NRemainRight
,
F
>::
type
>::
type
;
};
};
template
<
index_t
I
,
class
F
>
// generate sequence
struct
sequence_gen_impl
<
I
,
1
,
F
>
template
<
index_t
NSize
,
typename
F
>
struct
sequence_gen
{
{
static
constexpr
index_t
Is
=
F
{}(
Number
<
I
>
{});
template
<
index_t
IBegin
,
index_t
NRemain
,
typename
G
>
using
type
=
Sequence
<
Is
>
;
struct
sequence_gen_impl
};
{
static
constexpr
index_t
NRemainLeft
=
NRemain
/
2
;
static
constexpr
index_t
NRemainRight
=
NRemain
-
NRemainLeft
;
static
constexpr
index_t
IMiddle
=
IBegin
+
NRemainLeft
;
template
<
index_t
I
,
class
F
>
using
type
=
typename
sequence_merge
<
struct
sequence_gen_impl
<
I
,
0
,
F
>
typename
sequence_gen_impl
<
IBegin
,
NRemainLeft
,
G
>::
type
,
{
typename
sequence_gen_impl
<
IMiddle
,
NRemainRight
,
G
>::
type
>::
type
;
using
type
=
Sequence
<>
;
};
};
template
<
index_t
I
,
typename
G
>
struct
sequence_gen_impl
<
I
,
1
,
G
>
{
static
constexpr
index_t
Is
=
G
{}(
Number
<
I
>
{});
using
type
=
Sequence
<
Is
>
;
};
template
<
index_t
I
,
typename
G
>
struct
sequence_gen_impl
<
I
,
0
,
G
>
{
using
type
=
Sequence
<>
;
};
template
<
index_t
NSize
,
class
F
>
struct
sequence_gen
{
using
type
=
typename
sequence_gen_impl
<
0
,
NSize
,
F
>::
type
;
using
type
=
typename
sequence_gen_impl
<
0
,
NSize
,
F
>::
type
;
};
};
...
@@ -232,10 +249,10 @@ struct uniform_sequence_gen
...
@@ -232,10 +249,10 @@ struct uniform_sequence_gen
};
};
// reverse inclusive scan (with init) sequence
// reverse inclusive scan (with init) sequence
template
<
class
,
class
,
index_t
>
template
<
typename
,
typename
,
index_t
>
struct
sequence_reverse_inclusive_scan
;
struct
sequence_reverse_inclusive_scan
;
template
<
index_t
I
,
index_t
...
Is
,
class
Reduce
,
index_t
Init
>
template
<
index_t
I
,
index_t
...
Is
,
typename
Reduce
,
index_t
Init
>
struct
sequence_reverse_inclusive_scan
<
Sequence
<
I
,
Is
...
>
,
Reduce
,
Init
>
struct
sequence_reverse_inclusive_scan
<
Sequence
<
I
,
Is
...
>
,
Reduce
,
Init
>
{
{
using
old_scan
=
typename
sequence_reverse_inclusive_scan
<
Sequence
<
Is
...
>
,
Reduce
,
Init
>::
type
;
using
old_scan
=
typename
sequence_reverse_inclusive_scan
<
Sequence
<
Is
...
>
,
Reduce
,
Init
>::
type
;
...
@@ -245,41 +262,41 @@ struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
...
@@ -245,41 +262,41 @@ struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
using
type
=
typename
sequence_merge
<
Sequence
<
new_reduce
>
,
old_scan
>::
type
;
using
type
=
typename
sequence_merge
<
Sequence
<
new_reduce
>
,
old_scan
>::
type
;
};
};
template
<
index_t
I
,
class
Reduce
,
index_t
Init
>
template
<
index_t
I
,
typename
Reduce
,
index_t
Init
>
struct
sequence_reverse_inclusive_scan
<
Sequence
<
I
>
,
Reduce
,
Init
>
struct
sequence_reverse_inclusive_scan
<
Sequence
<
I
>
,
Reduce
,
Init
>
{
{
using
type
=
Sequence
<
Reduce
{}(
I
,
Init
)
>
;
using
type
=
Sequence
<
Reduce
{}(
I
,
Init
)
>
;
};
};
template
<
class
Reduce
,
index_t
Init
>
template
<
typename
Reduce
,
index_t
Init
>
struct
sequence_reverse_inclusive_scan
<
Sequence
<>
,
Reduce
,
Init
>
struct
sequence_reverse_inclusive_scan
<
Sequence
<>
,
Reduce
,
Init
>
{
{
using
type
=
Sequence
<>
;
using
type
=
Sequence
<>
;
};
};
// split sequence
// split sequence
template
<
class
Seq
,
index_t
I
>
template
<
typename
Seq
,
index_t
I
>
struct
sequence_split
struct
sequence_split
{
{
static
constexpr
index_t
NSize
=
Seq
{}.
Get
Size
();
static
constexpr
index_t
NSize
=
Seq
{}.
Size
();
using
range0
=
typename
arithmetic_sequence_gen
<
0
,
I
,
1
>::
type
;
using
range0
=
typename
arithmetic_sequence_gen
<
0
,
I
,
1
>::
type
;
using
range1
=
typename
arithmetic_sequence_gen
<
I
,
NSize
,
1
>::
type
;
using
range1
=
typename
arithmetic_sequence_gen
<
I
,
NSize
,
1
>::
type
;
using
SeqT
ype
0
=
decltype
(
Seq
::
Extract
(
range0
{}));
using
left_t
ype
=
decltype
(
Seq
::
Extract
(
range0
{}));
using
SeqT
ype
1
=
decltype
(
Seq
::
Extract
(
range1
{}));
using
right_t
ype
=
decltype
(
Seq
::
Extract
(
range1
{}));
};
};
// reverse sequence
// reverse sequence
template
<
class
Seq
>
template
<
typename
Seq
>
struct
sequence_reverse
struct
sequence_reverse
{
{
static
constexpr
index_t
NSize
=
Seq
{}.
Get
Size
();
static
constexpr
index_t
NSize
=
Seq
{}.
Size
();
using
seq_split
=
sequence_split
<
Seq
,
NSize
/
2
>
;
using
seq_split
=
sequence_split
<
Seq
,
NSize
/
2
>
;
using
type
=
typename
sequence_merge
<
using
type
=
typename
sequence_merge
<
typename
sequence_reverse
<
typename
seq_split
::
SeqT
ype
1
>::
type
,
typename
sequence_reverse
<
typename
seq_split
::
right_t
ype
>::
type
,
typename
sequence_reverse
<
typename
seq_split
::
SeqT
ype
0
>::
type
>::
type
;
typename
sequence_reverse
<
typename
seq_split
::
left_t
ype
>::
type
>::
type
;
};
};
template
<
index_t
I
>
template
<
index_t
I
>
...
@@ -294,44 +311,291 @@ struct sequence_reverse<Sequence<I0, I1>>
...
@@ -294,44 +311,291 @@ struct sequence_reverse<Sequence<I0, I1>>
using
type
=
Sequence
<
I1
,
I0
>
;
using
type
=
Sequence
<
I1
,
I0
>
;
};
};
template
<
class
Seq
>
#if 1
struct
is_valid_sequence_map
template
<
typename
Reduce
,
typename
Seq
,
typename
...
Seqs
>
struct
sequence_reduce
{
{
// not implemented yet, always return true
using
type
=
typename
sequence_reduce
<
Reduce
,
static
constexpr
integral_constant
<
bool
,
true
>
value
=
integral_constant
<
bool
,
true
>
{};
Seq
,
typename
sequence_reduce
<
Reduce
,
Seqs
...
>::
type
>::
type
;
};
// TODO: add proper check for is_valid, something like:
template
<
typename
Reduce
,
index_t
...
Xs
,
index_t
...
Ys
>
// static constexpr bool value =
struct
sequence_reduce
<
Reduce
,
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>>
// is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::type,
{
// typename sequence_sort<Seq>::SortedSeqType>{}
;
using
type
=
Sequence
<
Reduce
{}(
Xs
,
Ys
)...
>
;
};
};
template
<
class
X2Y
,
class
WorkingY2X
,
index_t
XBegin
,
index_t
XRemain
>
template
<
typename
Reduce
,
typename
Seq
>
struct
sequence_
map_inverse_impl
struct
sequence_
reduce
<
Reduce
,
Seq
>
{
{
private:
using
type
=
Seq
;
static
constexpr
auto
new_y2x
=
};
WorkingY2X
::
Modify
(
X2Y
::
Get
(
Number
<
XBegin
>
{}),
Number
<
XBegin
>
{});
#endif
public:
template
<
typename
Values
,
typename
Ids
,
typename
Compare
>
using
type
=
struct
sequence_sort_impl
typename
sequence_map_inverse_impl
<
X2Y
,
decltype
(
new_y2x
),
XBegin
+
1
,
XRemain
-
1
>::
type
;
{
template
<
typename
LeftValues
,
typename
LeftIds
,
typename
RightValues
,
typename
RightIds
,
typename
MergedValues
,
typename
MergedIds
,
typename
Comp
>
struct
sorted_sequence_merge_impl
{
static
constexpr
bool
choose_left
=
LeftValues
::
Front
()
<
RightValues
::
Front
();
static
constexpr
index_t
chosen_value
=
choose_left
?
LeftValues
::
Front
()
:
RightValues
::
Front
();
static
constexpr
index_t
chosen_id
=
choose_left
?
LeftIds
::
Front
()
:
RightIds
::
Front
();
using
new_merged_values
=
decltype
(
MergedValues
::
PushBack
(
Number
<
chosen_value
>
{}));
using
new_merged_ids
=
decltype
(
MergedIds
::
PushBack
(
Number
<
chosen_id
>
{}));
using
new_left_values
=
typename
conditional
<
choose_left
,
decltype
(
LeftValues
::
PopFront
()),
LeftValues
>::
type
;
using
new_left_ids
=
typename
conditional
<
choose_left
,
decltype
(
LeftIds
::
PopFront
()),
LeftIds
>::
type
;
using
new_right_values
=
typename
conditional
<
choose_left
,
RightValues
,
decltype
(
RightValues
::
PopFront
())
>::
type
;
using
new_right_ids
=
typename
conditional
<
choose_left
,
RightIds
,
decltype
(
RightIds
::
PopFront
())
>::
type
;
using
merge
=
sorted_sequence_merge_impl
<
new_left_values
,
new_left_ids
,
new_right_values
,
new_right_ids
,
new_merged_values
,
new_merged_ids
,
Comp
>
;
// this is output
using
merged_values
=
typename
merge
::
merged_values
;
using
merged_ids
=
typename
merge
::
merged_ids
;
};
template
<
typename
LeftValues
,
typename
LeftIds
,
typename
MergedValues
,
typename
MergedIds
,
typename
Comp
>
struct
sorted_sequence_merge_impl
<
LeftValues
,
LeftIds
,
Sequence
<>
,
Sequence
<>
,
MergedValues
,
MergedIds
,
Comp
>
{
using
merged_values
=
typename
sequence_merge
<
MergedValues
,
LeftValues
>::
type
;
using
merged_ids
=
typename
sequence_merge
<
MergedIds
,
LeftIds
>::
type
;
};
template
<
typename
RightValues
,
typename
RightIds
,
typename
MergedValues
,
typename
MergedIds
,
typename
Comp
>
struct
sorted_sequence_merge_impl
<
Sequence
<>
,
Sequence
<>
,
RightValues
,
RightIds
,
MergedValues
,
MergedIds
,
Comp
>
{
using
merged_values
=
typename
sequence_merge
<
MergedValues
,
RightValues
>::
type
;
using
merged_ids
=
typename
sequence_merge
<
MergedIds
,
RightIds
>::
type
;
};
template
<
typename
LeftValues
,
typename
LeftIds
,
typename
RightValues
,
typename
RightIds
,
typename
Comp
>
struct
sorted_sequence_merge
{
using
merge
=
sorted_sequence_merge_impl
<
LeftValues
,
LeftIds
,
RightValues
,
RightIds
,
Sequence
<>
,
Sequence
<>
,
Comp
>
;
using
merged_values
=
typename
merge
::
merged_values
;
using
merged_ids
=
typename
merge
::
merged_ids
;
};
static
constexpr
index_t
nsize
=
Values
::
Size
();
using
split_unsorted_values
=
sequence_split
<
Values
,
nsize
/
2
>
;
using
split_unsorted_ids
=
sequence_split
<
Ids
,
nsize
/
2
>
;
using
left_unsorted_values
=
typename
split_unsorted_values
::
left_type
;
using
left_unsorted_ids
=
typename
split_unsorted_ids
::
left_type
;
using
left_sort
=
sequence_sort_impl
<
left_unsorted_values
,
left_unsorted_ids
,
Compare
>
;
using
left_sorted_values
=
typename
left_sort
::
sorted_values
;
using
left_sorted_ids
=
typename
left_sort
::
sorted_ids
;
using
right_unsorted_values
=
typename
split_unsorted_values
::
right_type
;
using
right_unsorted_ids
=
typename
split_unsorted_ids
::
right_type
;
using
right_sort
=
sequence_sort_impl
<
right_unsorted_values
,
right_unsorted_ids
,
Compare
>
;
using
right_sorted_values
=
typename
right_sort
::
sorted_values
;
using
right_sorted_ids
=
typename
right_sort
::
sorted_ids
;
using
merged_sorted
=
sorted_sequence_merge
<
left_sorted_values
,
left_sorted_ids
,
right_sorted_values
,
right_sorted_ids
,
Compare
>
;
using
sorted_values
=
typename
merged_sorted
::
merged_values
;
using
sorted_ids
=
typename
merged_sorted
::
merged_ids
;
};
template
<
index_t
ValueX
,
index_t
ValueY
,
index_t
IdX
,
index_t
IdY
,
typename
Compare
>
struct
sequence_sort_impl
<
Sequence
<
ValueX
,
ValueY
>
,
Sequence
<
IdX
,
IdY
>
,
Compare
>
{
static
constexpr
bool
choose_x
=
Compare
{}(
ValueX
,
ValueY
);
using
sorted_values
=
typename
conditional
<
choose_x
,
Sequence
<
ValueX
,
ValueY
>
,
Sequence
<
ValueY
,
ValueX
>>::
type
;
using
sorted_ids
=
typename
conditional
<
choose_x
,
Sequence
<
IdX
,
IdY
>
,
Sequence
<
IdY
,
IdX
>>::
type
;
};
template
<
index_t
Value
,
index_t
Id
,
typename
Compare
>
struct
sequence_sort_impl
<
Sequence
<
Value
>
,
Sequence
<
Id
>
,
Compare
>
{
using
sorted_values
=
Sequence
<
Value
>
;
using
sorted_ids
=
Sequence
<
Id
>
;
};
};
template
<
class
X2Y
,
class
WorkingY2X
,
index_t
XBegin
>
template
<
typename
Compare
>
struct
sequence_
map_inverse_impl
<
X2Y
,
WorkingY2X
,
XBegin
,
0
>
struct
sequence_
sort_impl
<
Sequence
<>
,
Sequence
<>
,
Compare
>
{
{
using
type
=
WorkingY2X
;
using
sorted_values
=
Sequence
<>
;
using
sorted_ids
=
Sequence
<>
;
};
};
template
<
class
X2Y
>
template
<
typename
Values
,
typename
Compare
>
struct
sequence_sort
{
using
unsorted_ids
=
typename
arithmetic_sequence_gen
<
0
,
Values
::
Size
(),
1
>::
type
;
using
sort
=
sequence_sort_impl
<
Values
,
unsorted_ids
,
Compare
>
;
// this is output
using
type
=
typename
sort
::
sorted_values
;
using
sorted2unsorted_map
=
typename
sort
::
sorted_ids
;
};
template
<
typename
Values
,
typename
Less
,
typename
Equal
>
struct
sequence_unique_sort
{
template
<
typename
RemainValues
,
typename
RemainIds
,
typename
UniquifiedValues
,
typename
UniquifiedIds
,
typename
Eq
>
struct
sorted_sequence_uniquify_impl
{
static
constexpr
index_t
current_value
=
RemainValues
::
Front
();
static
constexpr
index_t
current_id
=
RemainIds
::
Front
();
static
constexpr
bool
is_unique_value
=
(
current_value
!=
UniquifiedValues
::
Back
());
using
new_remain_values
=
decltype
(
RemainValues
::
PopFront
());
using
new_remain_ids
=
decltype
(
RemainIds
::
PopFront
());
using
new_uniquified_values
=
typename
conditional
<
is_unique_value
,
decltype
(
UniquifiedValues
::
PushBack
(
Number
<
current_value
>
{})),
UniquifiedValues
>::
type
;
using
new_uniquified_ids
=
typename
conditional
<
is_unique_value
,
decltype
(
UniquifiedIds
::
PushBack
(
Number
<
current_id
>
{})),
UniquifiedIds
>::
type
;
using
uniquify
=
sorted_sequence_uniquify_impl
<
new_remain_values
,
new_remain_ids
,
new_uniquified_values
,
new_uniquified_ids
,
Eq
>
;
// this is output
using
uniquified_values
=
typename
uniquify
::
uniquified_values
;
using
uniquified_ids
=
typename
uniquify
::
uniquified_ids
;
};
template
<
typename
UniquifiedValues
,
typename
UniquifiedIds
,
typename
Eq
>
struct
sorted_sequence_uniquify_impl
<
Sequence
<>
,
Sequence
<>
,
UniquifiedValues
,
UniquifiedIds
,
Eq
>
{
using
uniquified_values
=
UniquifiedValues
;
using
uniquified_ids
=
UniquifiedIds
;
};
template
<
typename
SortedValues
,
typename
SortedIds
,
typename
Eq
>
struct
sorted_sequence_uniquify
{
using
uniquify
=
sorted_sequence_uniquify_impl
<
decltype
(
SortedValues
::
PopFront
()),
decltype
(
SortedIds
::
PopFront
()),
Sequence
<
SortedValues
::
Front
()
>
,
Sequence
<
SortedIds
::
Front
()
>
,
Eq
>
;
using
uniquified_values
=
typename
uniquify
::
uniquified_values
;
using
uniquified_ids
=
typename
uniquify
::
uniquified_ids
;
};
using
sort
=
sequence_sort
<
Values
,
Less
>
;
using
sorted_values
=
typename
sort
::
type
;
using
sorted_ids
=
typename
sort
::
sorted2unsorted_map
;
using
uniquify
=
sorted_sequence_uniquify
<
sorted_values
,
sorted_ids
,
Equal
>
;
// this is output
using
type
=
typename
uniquify
::
uniquified_values
;
using
sorted2unsorted_map
=
typename
uniquify
::
uniquified_ids
;
};
template
<
typename
SeqMap
>
struct
is_valid_sequence_map
:
is_same
<
typename
arithmetic_sequence_gen
<
0
,
SeqMap
::
Size
(),
1
>::
type
,
typename
sequence_sort
<
SeqMap
,
math
::
less
<
index_t
>>::
type
>
{
};
template
<
typename
SeqMap
>
struct
sequence_map_inverse
struct
sequence_map_inverse
{
{
template
<
typename
X2Y
,
typename
WorkingY2X
,
index_t
XBegin
,
index_t
XRemain
>
struct
sequence_map_inverse_impl
{
static
constexpr
auto
new_y2x
=
WorkingY2X
::
Modify
(
X2Y
::
At
(
Number
<
XBegin
>
{}),
Number
<
XBegin
>
{});
using
type
=
typename
sequence_map_inverse_impl
<
X2Y
,
decltype
(
new_y2x
),
XBegin
+
1
,
XRemain
-
1
>::
type
;
};
template
<
typename
X2Y
,
typename
WorkingY2X
,
index_t
XBegin
>
struct
sequence_map_inverse_impl
<
X2Y
,
WorkingY2X
,
XBegin
,
0
>
{
using
type
=
WorkingY2X
;
};
using
type
=
using
type
=
typename
sequence_map_inverse_impl
<
X2Y
,
typename
sequence_map_inverse_impl
<
SeqMap
,
typename
uniform_sequence_gen
<
X2Y
::
Get
Size
(),
0
>::
type
,
typename
uniform_sequence_gen
<
SeqMap
::
Size
(),
0
>::
type
,
0
,
0
,
X2Y
::
Get
Size
()
>::
type
;
SeqMap
::
Size
()
>::
type
;
};
};
template
<
index_t
...
Xs
,
index_t
...
Ys
>
template
<
index_t
...
Xs
,
index_t
...
Ys
>
...
@@ -442,20 +706,26 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
...
@@ -442,20 +706,26 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
return
Sequence
<
Is
...
>
{};
return
Sequence
<
Is
...
>
{};
}
}
template
<
class
Seq
>
template
<
typename
Seq
>
__host__
__device__
constexpr
auto
sequence_pop_back
(
Seq
)
__host__
__device__
constexpr
auto
sequence_pop_back
(
Seq
)
{
{
static_assert
(
Seq
::
Get
Size
()
>
0
,
"wrong! cannot pop an empty Sequence!"
);
static_assert
(
Seq
::
Size
()
>
0
,
"wrong! cannot pop an empty Sequence!"
);
return
sequence_pop_front
(
Seq
::
Reverse
()).
Reverse
();
return
sequence_pop_front
(
Seq
::
Reverse
()).
Reverse
();
}
}
template
<
class
F
,
index_t
...
Xs
>
template
<
typename
...
Seqs
>
__host__
__device__
constexpr
auto
merge_sequences
(
Seqs
...)
{
return
typename
sequence_merge
<
Seqs
...
>::
type
{};
}
template
<
typename
F
,
index_t
...
Xs
>
__host__
__device__
constexpr
auto
transform_sequences
(
F
f
,
Sequence
<
Xs
...
>
)
__host__
__device__
constexpr
auto
transform_sequences
(
F
f
,
Sequence
<
Xs
...
>
)
{
{
return
Sequence
<
f
(
Xs
)...
>
{};
return
Sequence
<
f
(
Xs
)...
>
{};
}
}
template
<
class
F
,
index_t
...
Xs
,
index_t
...
Ys
>
template
<
typename
F
,
index_t
...
Xs
,
index_t
...
Ys
>
__host__
__device__
constexpr
auto
transform_sequences
(
F
f
,
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
)
__host__
__device__
constexpr
auto
transform_sequences
(
F
f
,
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
)
{
{
static_assert
(
Sequence
<
Xs
...
>::
mSize
==
Sequence
<
Ys
...
>::
mSize
,
"Dim not the same"
);
static_assert
(
Sequence
<
Xs
...
>::
mSize
==
Sequence
<
Ys
...
>::
mSize
,
"Dim not the same"
);
...
@@ -463,7 +733,7 @@ __host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Seq
...
@@ -463,7 +733,7 @@ __host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Seq
return
Sequence
<
f
(
Xs
,
Ys
)...
>
{};
return
Sequence
<
f
(
Xs
,
Ys
)...
>
{};
}
}
template
<
class
F
,
index_t
...
Xs
,
index_t
...
Ys
,
index_t
...
Zs
>
template
<
typename
F
,
index_t
...
Xs
,
index_t
...
Ys
,
index_t
...
Zs
>
__host__
__device__
constexpr
auto
__host__
__device__
constexpr
auto
transform_sequences
(
F
f
,
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
,
Sequence
<
Zs
...
>
)
transform_sequences
(
F
f
,
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
,
Sequence
<
Zs
...
>
)
{
{
...
@@ -474,52 +744,123 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
...
@@ -474,52 +744,123 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
return
Sequence
<
f
(
Xs
,
Ys
,
Zs
)...
>
{};
return
Sequence
<
f
(
Xs
,
Ys
,
Zs
)...
>
{};
}
}
template
<
class
Seq
,
class
Reduce
,
index_t
Init
>
template
<
typename
Seq
,
typename
Reduce
,
index_t
Init
>
__host__
__device__
constexpr
auto
reverse_inclusive_scan_sequence
(
Seq
,
Reduce
,
Number
<
Init
>
)
__host__
__device__
constexpr
auto
reverse_inclusive_scan_sequence
(
Seq
,
Reduce
,
Number
<
Init
>
)
{
{
return
typename
sequence_reverse_inclusive_scan
<
Seq
,
Reduce
,
Init
>::
type
{};
return
typename
sequence_reverse_inclusive_scan
<
Seq
,
Reduce
,
Init
>::
type
{};
}
}
template
<
class
Seq
,
class
Reduce
,
index_t
Init
>
template
<
typename
Seq
,
typename
Reduce
,
index_t
Init
>
__host__
__device__
constexpr
auto
inclusive_scan_sequence
(
Seq
,
Reduce
,
Number
<
Init
>
)
__host__
__device__
constexpr
auto
inclusive_scan_sequence
(
Seq
,
Reduce
,
Number
<
Init
>
)
{
{
return
reverse_inclusive_scan_sequence
(
Seq
{}.
Reverse
(),
Reduce
{},
Number
<
Init
>
{}).
Reverse
();
return
reverse_inclusive_scan_sequence
(
Seq
{}.
Reverse
(),
Reduce
{},
Number
<
Init
>
{}).
Reverse
();
}
}
template
<
index_t
...
Xs
>
template
<
typename
Seq
,
index_t
...
Is
>
__host__
__device__
void
print_Sequence
(
const
char
*
s
,
Sequence
<
Xs
...
>
)
__host__
__device__
constexpr
auto
pick_sequence_elements_by_ids
(
Seq
,
Sequence
<
Is
...
>
/* ids */
)
{
return
Sequence
<
Seq
::
At
(
Number
<
Is
>
{})...
>
{};
}
#if 1
namespace
detail
{
template
<
typename
WorkSeq
,
typename
RemainSeq
,
typename
RemainMask
>
struct
pick_sequence_elements_by_mask_impl
{
{
constexpr
index_t
nsize
=
Sequence
<
Xs
...
>::
GetSize
();
using
new_work_seq
=
typename
conditional
<
RemainMask
::
Front
(),
decltype
(
WorkSeq
::
PushBack
(
RemainSeq
::
Front
())),
WorkSeq
>::
type
;
static_assert
(
nsize
<=
10
,
"wrong!"
);
using
type
=
typename
pick_sequence_elements_by_mask_impl
<
new_work_seq
,
decltype
(
RemainSeq
::
PopFront
()),
decltype
(
RemainMask
::
PopFront
())
>::
type
;
};
template
<
typename
WorkSeq
>
struct
pick_sequence_elements_by_mask_impl
<
WorkSeq
,
Sequence
<>
,
Sequence
<>>
{
using
type
=
WorkSeq
;
};
}
// namespace detail
template
<
typename
Seq
,
typename
Mask
>
__host__
__device__
constexpr
auto
pick_sequence_elements_by_mask
(
Seq
,
Mask
)
{
static_assert
(
Seq
::
Size
()
==
Mask
::
Size
(),
"wrong!"
);
return
typename
detail
::
pick_sequence_elements_by_mask_impl
<
Sequence
<>
,
Seq
,
Mask
>::
type
{};
}
namespace
detail
{
template
<
typename
WorkSeq
,
typename
RemainValues
,
typename
RemainIds
>
struct
modify_sequence_elements_by_ids_impl
{
using
new_work_seq
=
decltype
(
WorkSeq
::
Modify
(
RemainIds
::
Front
(),
RemainValues
::
Front
()));
using
type
=
typename
modify_sequence_elements_by_ids_impl
<
new_work_seq
,
decltype
(
RemainValues
::
PopFront
()),
decltype
(
RemainIds
::
PopFront
())
>::
type
;
};
template
<
typename
WorkSeq
>
struct
modify_sequence_elements_by_ids_impl
<
WorkSeq
,
Sequence
<>
,
Sequence
<>>
{
using
type
=
WorkSeq
;
};
}
// namespace detail
static_if
<
nsize
==
0
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {}
\n
"
,
s
,
nsize
,
Xs
...);
});
template
<
typename
Seq
,
typename
Values
,
typename
Ids
>
__host__
__device__
constexpr
auto
modify_sequence_elements_by_ids
(
Seq
,
Values
,
Ids
)
{
static_assert
(
Values
::
Size
()
==
Ids
::
Size
()
&&
Seq
::
Size
()
>=
Values
::
Size
(),
"wrong!"
);
static_if
<
nsize
==
1
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u}
\n
"
,
s
,
nsize
,
Xs
...);
});
return
typename
detail
::
modify_sequence_elements_by_ids_impl
<
Seq
,
Values
,
Ids
>::
type
{};
}
#endif
static_if
<
nsize
==
2
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
template
<
typename
Seq
,
typename
Reduce
,
index_t
Init
>
__host__
__device__
constexpr
index_t
reduce_on_sequence
(
Seq
,
Reduce
f
,
Number
<
Init
>
/*initial_value*/
)
{
index_t
result
=
Init
;
static_if
<
nsize
==
3
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
for
(
index_t
i
=
0
;
i
<
Seq
::
Size
();
++
i
)
{
result
=
f
(
result
,
Seq
::
At
(
i
));
}
static_if
<
nsize
==
4
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
return
result
;
}
static_if
<
nsize
==
5
>
{}(
// TODO: a generic any_of for any container
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
template
<
typename
Seq
,
typename
F
>
__host__
__device__
constexpr
bool
sequence_any_of
(
Seq
,
F
f
)
{
bool
flag
=
false
;
static_if
<
nsize
==
6
>
{}(
for
(
index_t
i
=
0
;
i
<
Seq
::
Size
();
++
i
)
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
{
flag
=
flag
||
f
(
Seq
::
At
(
i
));
}
static_if
<
nsize
==
7
>
{}(
return
flag
;
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
}
static_if
<
nsize
==
8
>
{}(
// TODO: a generic all_of for any container
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
template
<
typename
Seq
,
typename
F
>
__host__
__device__
constexpr
bool
sequence_all_of
(
Seq
,
F
f
)
{
bool
flag
=
true
;
static_if
<
nsize
==
9
>
{}(
for
(
index_t
i
=
0
;
i
<
Seq
::
Size
();
++
i
)
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
{
flag
=
flag
&&
f
(
Seq
::
At
(
i
));
}
static_if
<
nsize
==
10
>
{}(
return
flag
;
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
}
}
}
// namespace ck
}
// namespace ck
...
...
composable_kernel/include/utility/sequence_helper.hpp
0 → 100644
View file @
52423948
#ifndef CK_SEQUENCE_HELPER_HPP
#define CK_SEQUENCE_HELPER_HPP
#include "sequence.hpp"
namespace
ck
{
template
<
index_t
...
Xs
>
__host__
__device__
void
print_sequence
(
const
char
*
s
,
Sequence
<
Xs
...
>
)
{
constexpr
index_t
nsize
=
Sequence
<
Xs
...
>::
Size
();
static_assert
(
nsize
<=
10
,
"wrong!"
);
static_if
<
nsize
==
0
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
1
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
2
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
3
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
4
>
{}([
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
5
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
6
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
7
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
8
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
9
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
static_if
<
nsize
==
10
>
{}(
[
&
](
auto
)
{
printf
(
"%s size %u, {%u %u %u %u %u %u %u %u %u %u}
\n
"
,
s
,
nsize
,
Xs
...);
});
}
}
// namespace ck
#endif
composable_kernel/include/utility/tuple.hpp
0 → 100644
View file @
52423948
#ifndef CK_TUPLE_HPP
#define CK_TUPLE_HPP
#include "integral_constant.hpp"
#include "type.hpp"
#include "sequence.hpp"
namespace
ck
{
namespace
detail
{
template
<
index_t
>
struct
TupleElementKey
{
};
template
<
typename
Key
,
typename
Data
>
struct
TupleElement
{
__host__
__device__
explicit
constexpr
TupleElement
()
:
mData
()
{}
template
<
typename
T
>
__host__
__device__
explicit
constexpr
TupleElement
(
T
&&
v
)
:
mData
(
static_cast
<
T
&&>
(
v
))
{
}
Data
mData
;
};
template
<
typename
Key
,
typename
Data
>
__host__
__device__
constexpr
const
Data
&
get_tuple_element
(
const
TupleElement
<
Key
,
Data
>&
x
)
{
return
x
.
mData
;
}
template
<
typename
Key
,
typename
Data
>
__host__
__device__
constexpr
Data
&
get_tuple_element
(
TupleElement
<
Key
,
Data
>&
x
)
{
return
x
.
mData
;
}
template
<
typename
Key
,
typename
Data
>
__host__
__device__
constexpr
Data
&&
get_tuple_element
(
TupleElement
<
Key
,
Data
>&&
x
)
{
return
static_cast
<
Data
&&>
(
x
.
mData
);
}
template
<
typename
Indices
,
typename
...
Xs
>
struct
TupleImpl
;
template
<
index_t
...
Is
,
typename
...
Xs
>
struct
TupleImpl
<
Sequence
<
Is
...
>
,
Xs
...
>
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
...
{
__host__
__device__
explicit
constexpr
TupleImpl
()
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
()...
{
}
template
<
typename
...
Ys
>
__host__
__device__
explicit
constexpr
TupleImpl
(
Ys
&&
...
ys
)
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
(
static_cast
<
Ys
&&>
(
ys
))...
{
}
__host__
__device__
static
constexpr
index_t
Size
()
{
return
sizeof
...(
Xs
);
}
template
<
index_t
I
>
__host__
__device__
constexpr
const
auto
&
GetElementByKey
(
TupleElementKey
<
I
>
)
const
{
return
get_tuple_element
<
TupleElementKey
<
I
>>
(
*
this
);
}
template
<
index_t
I
>
__host__
__device__
constexpr
auto
&
GetElementByKey
(
TupleElementKey
<
I
>
)
{
return
get_tuple_element
<
TupleElementKey
<
I
>>
(
*
this
);
}
};
}
// namespace detail
template
<
typename
...
Xs
>
struct
Tuple
:
detail
::
TupleImpl
<
typename
arithmetic_sequence_gen
<
0
,
sizeof
...(
Xs
),
1
>::
type
,
Xs
...
>
{
using
base
=
detail
::
TupleImpl
<
typename
arithmetic_sequence_gen
<
0
,
sizeof
...(
Xs
),
1
>::
type
,
Xs
...
>
;
template
<
typename
...
Ys
>
__host__
__device__
explicit
constexpr
Tuple
(
Ys
&&
...
ys
)
:
base
(
static_cast
<
Ys
&&>
(
ys
)...)
{
}
template
<
index_t
I
>
__host__
__device__
constexpr
const
auto
&
At
(
Number
<
I
>
)
const
{
static_assert
(
I
<
base
::
Size
(),
"wrong! out of range"
);
return
base
::
GetElementByKey
(
detail
::
TupleElementKey
<
I
>
{});
}
template
<
index_t
I
>
__host__
__device__
constexpr
auto
&
At
(
Number
<
I
>
)
{
static_assert
(
I
<
base
::
Size
(),
"wrong! out of range"
);
return
base
::
GetElementByKey
(
detail
::
TupleElementKey
<
I
>
{});
}
};
template
<
typename
...
Xs
>
__host__
__device__
constexpr
auto
make_tuple
(
Xs
&&
...
xs
)
{
return
Tuple
<
remove_cv_t
<
remove_reference_t
<
Xs
>>
...
>
(
std
::
forward
<
Xs
>
(
xs
)...);
}
namespace
detail
{
template
<
typename
F
,
typename
X
,
index_t
...
Is
>
__host__
__device__
constexpr
auto
transform_tuples_impl
(
F
f
,
const
X
&
x
,
Sequence
<
Is
...
>
)
{
return
make_tuple
(
f
(
x
.
At
(
Number
<
Is
>
{}))...);
}
template
<
typename
F
,
typename
X
,
typename
Y
,
index_t
...
Is
>
__host__
__device__
constexpr
auto
transform_tuples_impl
(
F
f
,
const
X
&
x
,
const
Y
&
y
,
Sequence
<
Is
...
>
)
{
return
make_tuple
(
f
(
x
.
At
(
Number
<
Is
>
{}),
y
.
At
(
Number
<
Is
>
{}))...);
}
template
<
typename
F
,
typename
X
,
typename
Y
,
typename
Z
,
index_t
...
Is
>
__host__
__device__
constexpr
auto
transform_tuples_impl
(
F
f
,
const
X
&
x
,
const
Y
&
y
,
const
Z
&
z
,
Sequence
<
Is
...
>
)
{
return
make_tuple
(
f
(
x
.
At
(
Number
<
Is
>
{}),
y
.
At
(
Number
<
Is
>
{}),
z
.
At
(
Number
<
Is
>
{}))...);
}
}
// namespace detail
template
<
typename
F
,
typename
X
>
__host__
__device__
constexpr
auto
transform_tuples
(
F
f
,
const
X
&
x
)
{
return
detail
::
transform_tuples_impl
(
f
,
x
,
typename
arithmetic_sequence_gen
<
0
,
X
::
Size
(),
1
>::
type
{});
}
template
<
typename
F
,
typename
X
,
typename
Y
>
__host__
__device__
constexpr
auto
transform_tuples
(
F
f
,
const
X
&
x
,
const
Y
&
y
)
{
return
detail
::
transform_tuples_impl
(
f
,
x
,
y
,
typename
arithmetic_sequence_gen
<
0
,
X
::
Size
(),
1
>::
type
{});
}
template
<
typename
F
,
typename
X
,
typename
Y
,
typename
Z
>
__host__
__device__
constexpr
auto
transform_tuples
(
F
f
,
const
X
&
x
,
const
Y
&
y
,
const
Z
&
z
)
{
return
detail
::
transform_tuples_impl
(
f
,
x
,
y
,
z
,
typename
arithmetic_sequence_gen
<
0
,
X
::
Size
(),
1
>::
type
{});
}
}
// namespace ck
#endif
composable_kernel/include/utility/type.hpp
0 → 100644
View file @
52423948
#ifndef CK_TYPE_HPP
#define CK_TYPE_HPP
#include "integral_constant.hpp"
namespace
ck
{
template
<
index_t
...
Is
>
struct
Sequence
;
template
<
typename
X
,
typename
Y
>
struct
is_same
:
public
integral_constant
<
bool
,
false
>
{
};
template
<
typename
X
>
struct
is_same
<
X
,
X
>
:
public
integral_constant
<
bool
,
true
>
{
};
template
<
typename
>
struct
is_static
:
integral_constant
<
bool
,
false
>
{
};
template
<
typename
T
,
T
X
>
struct
is_static
<
integral_constant
<
T
,
X
>>
:
integral_constant
<
bool
,
true
>
{
};
template
<
index_t
...
Is
>
struct
is_static
<
Sequence
<
Is
...
>>
:
integral_constant
<
bool
,
true
>
{
};
template
<
typename
T
>
using
remove_reference_t
=
typename
std
::
remove_reference
<
T
>::
type
;
template
<
typename
T
>
using
remove_cv_t
=
typename
std
::
remove_cv
<
T
>::
type
;
}
// namespace ck
#endif
composable_kernel/include/utility/vector_type.hpp
View file @
52423948
...
@@ -14,7 +14,7 @@ struct vector_type
...
@@ -14,7 +14,7 @@ struct vector_type
template
<
>
template
<
>
struct
vector_type
<
float
,
1
>
struct
vector_type
<
float
,
1
>
{
{
typedef
float
MemoryType
;
using
MemoryType
=
float
;
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
static
void
SetScalar
(
MemoryType
&
v
,
float
s
,
Number
<
I
>
)
__host__
__device__
static
void
SetScalar
(
MemoryType
&
v
,
float
s
,
Number
<
I
>
)
...
@@ -64,6 +64,24 @@ struct vector_type<float, 4>
...
@@ -64,6 +64,24 @@ struct vector_type<float, 4>
}
}
};
};
template
<
>
struct
vector_type
<
const
float
,
1
>
{
using
MemoryType
=
const
float
;
};
template
<
>
struct
vector_type
<
const
float
,
2
>
{
using
MemoryType
=
const
float2_t
;
};
template
<
>
struct
vector_type
<
const
float
,
4
>
{
using
MemoryType
=
const
float4_t
;
};
}
// namespace ck
}
// namespace ck
#endif
#endif
driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
View file @
52423948
...
@@ -107,42 +107,11 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -107,42 +107,11 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 4>;
using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 4>;
constexpr index_t InBlockCopyDataPer
Read
_N = 4;
constexpr index_t InBlockCopyDataPer
Access
_N = 4;
constexpr index_t WeiBlockCopyDataPer
Read
_K = 4;
constexpr index_t WeiBlockCopyDataPer
Access
_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t OutThreadCopyDataPerAccess_N = 2;
#elif
0
// for 3x3, 34x34, v1r2, Pascal, in-block-copy1
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
NPerBlock
=
4
;
constexpr
index_t
KPerBlock
=
64
;
constexpr
index_t
CPerBlock
=
8
;
constexpr
index_t
HoPerBlock
=
4
;
constexpr
index_t
WoPerBlock
=
8
;
constexpr
index_t
NPerThread
=
4
;
constexpr
index_t
KPerThread
=
8
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
2
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
0
,
0
,
0
,
0
>
;
// not used
constexpr
index_t
InBlockCopyDataPerRead_N
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead_K
=
4
;
constexpr
index_t
OutThreadCopyDataPerWrite_N
=
2
;
#elif
1
#elif
1
// for 3x3, 34x34, v1r3, Pascal
// for 3x3, 34x34, v1r3, Pascal
// for 3x3, 28x28, v1r3, Pascal
// for 3x3, 28x28, v1r3, Pascal
...
@@ -170,43 +139,15 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -170,43 +139,15 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
8
,
2
,
2
,
4
>
;
using
InBlockCopySubLengths_CHWN
=
Sequence
<
1
,
1
,
1
,
4
>
;
constexpr
index_t
InBlockCopyDataPerRead_N
=
4
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
8
,
2
,
2
,
4
>
;
constexpr
index_t
InBlockCopyDataPerAccess_N
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead_K
=
4
;
using
WeiBlockCopySubLengths_CK
=
Sequence
<
2
,
4
>
;
using
WeiBlockCopyClusterLengths_CK
=
Sequence
<
4
,
32
>
;
constexpr
index_t
WeiBlockCopyDataPerAccess_K
=
4
;
constexpr
index_t
OutThreadCopyDataPerWrite_N
=
2
;
constexpr
index_t
OutThreadCopyDataPerAccess_N
=
2
;
#elif 0
// for 3x3, 34x34, v1r3, Pascal, bad
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
NPerBlock
=
1
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
CPerBlock
=
8
;
constexpr
index_t
HoPerBlock
=
2
;
constexpr
index_t
WoPerBlock
=
32
;
constexpr
index_t
NPerThread
=
1
;
constexpr
index_t
KPerThread
=
8
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
2
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
2
,
2
,
32
,
1
>
;
constexpr
index_t
InBlockCopyDataPerRead_N
=
1
;
constexpr
index_t
WeiBlockCopyDataPerRead_K
=
2
;
constexpr
index_t
OutThreadCopyDataPerWrite_N
=
1
;
#elif 0
#elif 0
// for 3x3, 34x34, v1r1, Vega 20
// for 3x3, 34x34, v1r1, Vega 20
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -232,12 +173,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -232,12 +173,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
4
,
4
,
2
,
8
>
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
4
,
4
,
2
,
8
>
;
constexpr
index_t
InBlockCopyDataPer
Read
_N
=
2
;
constexpr
index_t
InBlockCopyDataPer
Access
_N
=
2
;
constexpr
index_t
WeiBlockCopyDataPer
Read
_K
=
2
;
constexpr
index_t
WeiBlockCopyDataPer
Access
_K
=
2
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
4
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
4
;
#elif 1
#elif 1
// for 3x3, 34x34, v1r3, Vega 20
// for 3x3, 34x34, v1r3, Vega 20
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -263,12 +204,15 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -263,12 +204,15 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
8
,
2
,
4
,
4
>
;
using
InBlockCopySubLengths_CHWN
=
Sequence
<
1
,
1
,
1
,
4
>
;
constexpr
index_t
InBlockCopyDataPerRead_N
=
4
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
8
,
2
,
4
,
4
>
;
constexpr
index_t
InBlockCopyDataPerAccess_N
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead_K
=
4
;
using
WeiBlockCopySubLengths_CK
=
Sequence
<
1
,
4
>
;
using
WeiBlockCopyClusterLengths_CK
=
Sequence
<
8
,
32
>
;
constexpr
index_t
WeiBlockCopyDataPerAccess_K
=
4
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
4
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
4
;
#elif 0
#elif 0
// for 3x3, 56x56, v1r1, Pascal
// for 3x3, 56x56, v1r1, Pascal
constexpr
index_t
NPerBlock
=
32
;
constexpr
index_t
NPerBlock
=
32
;
...
@@ -282,13 +226,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -282,13 +226,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
8
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
8
;
constexpr
index_t
InBlockCopyDataPer
Read
_N
=
4
;
constexpr
index_t
InBlockCopyDataPer
Access
_N
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Read
_K
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Access
_K
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
...
@@ -298,7 +242,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -298,7 +242,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
2
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
2
;
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
...
@@ -324,14 +268,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -324,14 +268,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmDataPerReadA
=
1
;
constexpr
index_t
GemmDataPerReadA
=
1
;
constexpr
index_t
GemmDataPerReadB
=
1
;
constexpr
index_t
GemmDataPerReadB
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
index_t
InBlockCopyDataPer
Read
_N
=
4
;
constexpr
index_t
InBlockCopyDataPer
Access
_N
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Read
_K
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Access
_K
=
4
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
4
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
4
;
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
...
@@ -347,13 +291,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -347,13 +291,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
8
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
8
;
constexpr
index_t
InBlockCopyDataPer
Read
_N
=
4
;
constexpr
index_t
InBlockCopyDataPer
Access
_N
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Read
_K
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Access
_K
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
...
@@ -365,7 +309,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -365,7 +309,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
2
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
2
;
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
...
@@ -393,12 +337,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -393,12 +337,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
4
,
2
,
4
,
4
>
;
using
InBlockCopyClusterLengths_CHWN
=
Sequence
<
4
,
2
,
4
,
4
>
;
constexpr
index_t
InBlockCopyDataPer
Read
_N
=
4
;
constexpr
index_t
InBlockCopyDataPer
Access
_N
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Read
_K
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Access
_K
=
4
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
2
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
2
;
#elif 0
#elif 0
// for 1x1, 28x28, v1r1, Pascal
// for 1x1, 28x28, v1r1, Pascal
constexpr
index_t
NPerBlock
=
16
;
constexpr
index_t
NPerBlock
=
16
;
...
@@ -413,13 +357,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -413,13 +357,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
8
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
8
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
index_t
InBlockCopyDataPer
Read
_N
=
4
;
constexpr
index_t
InBlockCopyDataPer
Access
_N
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Read
_K
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Access
_K
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
...
@@ -429,7 +373,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -429,7 +373,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
2
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
2
;
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
...
@@ -453,65 +397,67 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -453,65 +397,67 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmNLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
2
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
8
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
8
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
index_t
InBlockCopyDataPer
Read
_N
=
4
;
constexpr
index_t
InBlockCopyDataPer
Access
_N
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Read
_K
=
4
;
constexpr
index_t
WeiBlockCopyDataPer
Access
_K
=
4
;
constexpr
index_t
OutThreadCopyDataPer
Write
_N
=
2
;
constexpr
index_t
OutThreadCopyDataPer
Access
_N
=
2
;
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#endif
#endif
constexpr
index_t
GridSize
=
constexpr
index_t
GridSize
=
((
N
+
NPerBlock
-
1
)
/
NPerBlock
)
*
((
K
+
KPerBlock
-
1
)
/
KPerBlock
)
*
(
N
/
NPerBlock
)
*
(
K
/
KPerBlock
)
*
(
Ho
/
HoPerBlock
)
*
(
Wo
/
WoPerBlock
);
((
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
)
*
((
Wo
+
WoPerBlock
-
1
)
/
WoPerBlock
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
constexpr
auto
gridwise_conv
=
{
constexpr
auto
gridwise_conv
=
#if 0
#if 0
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
#elif
0
#elif
0
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
#elif 0
#elif 0
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
#endif
#endif
<
GridSize
,
<
GridSize
,
BlockSize
,
BlockSize
,
T
,
T
,
decltype
(
in_chwn_desc
),
decltype
(
in_chwn_desc
),
decltype
(
wei_cyxk_desc
),
decltype
(
wei_cyxk_desc
),
decltype
(
out_khwn_desc
),
decltype
(
out_khwn_desc
),
NPerBlock
,
NPerBlock
,
KPerBlock
,
KPerBlock
,
CPerBlock
,
CPerBlock
,
HoPerBlock
,
HoPerBlock
,
WoPerBlock
,
WoPerBlock
,
NPerThread
,
NPerThread
,
KPerThread
,
KPerThread
,
HoPerThread
,
HoPerThread
,
WoPerThread
,
WoPerThread
,
GemmMPerThreadSubC
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmKPerThreadLoop
,
GemmDataPerReadA
,
GemmDataPerReadA
,
GemmDataPerReadB
,
GemmDataPerReadB
,
InBlockCopyClusterLengths_CHWN
,
InBlockCopySubLengths_CHWN
,
InBlockCopyDataPerRead_N
,
InBlockCopyClusterLengths_CHWN
,
WeiBlockCopyDataPerRead_K
,
InBlockCopyDataPerAccess_N
,
OutThreadCopyDataPerWrite_N
>
{};
WeiBlockCopySubLengths_CK
,
WeiBlockCopyClusterLengths_CK
,
WeiBlockCopyDataPerAccess_K
,
OutThreadCopyDataPerAccess_N
>
{};
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
float
time
=
launch_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
float
time
=
launch_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
dim3
(
GridSize
),
dim3
(
GridSize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
...
...
driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp
0 → 100644
View file @
52423948
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp"
using
namespace
ck
;
// Host-side driver for the padded implicit-GEMM forward convolution
// (v1r3, CHWN/CYXK/KHWN layouts).
//
// What it does (all visible below):
//   1. Reorders the weight tensor KCYX -> CYXK and the input tensor
//      NCHW -> CHWN on the host, in parallel.
//   2. Copies input/weight/output buffers to the device.
//   3. Instantiates GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
//      with compile-time tuning parameters and launches it `nrepeat` times,
//      printing the elapsed time of each run.
//   4. Copies the result back and reorders it KHWN -> NKHW into `out_nkhw`.
//
// Descriptor arguments (InDesc/WeiDesc/OutDesc) and the pad arguments
// (LeftPads/RightPads) are empty tag types carrying compile-time shape info;
// only their types are used.
template <typename T,
          class InDesc,
          class WeiDesc,
          class OutDesc,
          class LeftPads,
          class RightPads>
void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc,
                                                               const Tensor<T>& in_nchw,
                                                               WeiDesc,
                                                               const Tensor<T>& wei_kcyx,
                                                               OutDesc,
                                                               Tensor<T>& out_nkhw,
                                                               LeftPads,
                                                               RightPads,
                                                               index_t nrepeat)
{
    // Compile-time dimension indices.
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc  = InDesc{};
    constexpr auto wei_kcyx_desc = WeiDesc{};
    constexpr auto out_nkhw_desc = OutDesc{};

    // Input spatial size (N, C are taken from the other descriptors below).
    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
    constexpr index_t Wi = in_nchw_desc.GetLength(I3);

    // Output: batch and spatial size.
    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    // Weight: output channels, input channels, filter height/width.
    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
    constexpr index_t X = wei_kcyx_desc.GetLength(I3);

    // reorder weight: KCYX -> CYXK (layout expected by the gridwise kernel)
    auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Y, X, K>{});
    ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");

    Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));

    auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
        wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
    };

    // Run the element-wise reorder across all host hardware threads.
    make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
        std::thread::hardware_concurrency());

    // reorder input: NCHW -> CHWN
    auto in_chwn_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Hi, Wi, N>{});
    ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");

    Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));

    auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) {
        in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi);
    };

    make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)(
        std::thread::hardware_concurrency());

    // output: the kernel produces KHWN; reordered back to NKHW at the end
    auto out_khwn_desc = make_ConstantTensorDescriptor_packed(Sequence<K, Ho, Wo, N>{});
    ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");

    Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));

    // Allocate device buffers and upload all three tensors (output is
    // uploaded too, so untouched elements keep their host values).
    std::size_t data_sz = sizeof(T);
    DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace());
    DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
    DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());

    in_chwn_device_buf.ToDevice(in_chwn.mData.data());
    wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
    out_khwn_device_buf.ToDevice(out_khwn.mData.data());

#if 1
    // v1r3, 3x3, 32x32, 1x1 pad
    // Tuning configuration: 256 threads/block, each block computes a
    // 32(N) x 128(K) x 2(Ho) x 2(Wo) output tile over C-slices of 8.
    constexpr index_t BlockSize = 256;

    constexpr index_t NPerBlock  = 32;
    constexpr index_t KPerBlock  = 128;
    constexpr index_t CPerBlock  = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;

    constexpr index_t NPerThread  = 4;
    constexpr index_t KPerThread  = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;

    // Blockwise-GEMM thread decomposition parameters.
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA   = 4;
    constexpr index_t GemmDataPerReadB   = 4;

    // Per-thread slice / thread-cluster shapes for the input (CHWN) copy.
    using InBlockCopySubLengths_CHWN     = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 8>;

    constexpr index_t InBlockCopyDataPerAccess_N = 4;

    // Per-thread slice / thread-cluster shapes for the weight (CK) copy.
    using WeiBlockCopySubLengths_CK     = Sequence<1, 4>;
    using WeiBlockCopyClusterLengths_CK = Sequence<8, 32>;

    constexpr index_t WeiBlockCopyDataPerAccess_K = 4;

    constexpr index_t OutThreadCopyDataPerAccess_N = 4;
#endif

#if 1 // debug
    // One block per output tile.
    // NOTE(review): truncating division — drops remainder tiles when N, K,
    // Ho or Wo is not a multiple of its per-block size. The non-padded
    // driver in this commit switched to round-up division
    // ((X + XPerBlock - 1) / XPerBlock); confirm whether divisibility is an
    // intended precondition here.
    constexpr index_t GridSize =
        (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
#else
    constexpr index_t GridSize = 1;
#endif

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    // All tuning parameters above are baked into the kernel type.
    constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded<
        GridSize,
        BlockSize,
        T,
        decltype(in_chwn_desc),
        decltype(wei_cyxk_desc),
        decltype(out_khwn_desc),
        LeftPads,
        RightPads,
        NPerBlock,
        KPerBlock,
        CPerBlock,
        HoPerBlock,
        WoPerBlock,
        NPerThread,
        KPerThread,
        HoPerThread,
        WoPerThread,
        GemmMPerThreadSubC,
        GemmNPerThreadSubC,
        GemmMLevel0Cluster,
        GemmNLevel0Cluster,
        GemmMLevel1Cluster,
        GemmNLevel1Cluster,
        GemmKPerThreadLoop,
        GemmDataPerReadA,
        GemmDataPerReadB,
        InBlockCopySubLengths_CHWN,
        InBlockCopyClusterLengths_CHWN,
        InBlockCopyDataPerAccess_N,
        WeiBlockCopySubLengths_CK,
        WeiBlockCopyClusterLengths_CK,
        WeiBlockCopyDataPerAccess_K,
        OutThreadCopyDataPerAccess_N>{};

    // Launch the kernel nrepeat times and report timing for each run.
    for(index_t i = 0; i < nrepeat; ++i)
    {
        float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                   dim3(GridSize),
                                   dim3(BlockSize),
                                   0,
                                   static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));

        // flops / 1e9 = GFlop; GFlop per millisecond == TFlop per second.
        printf("Elapsed time : %f ms, %f TFlop/s\n",
               time,
               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                   (std::size_t(1000) * 1000 * 1000) / time);

        // Throttle between runs: sleep roughly one kernel duration,
        // capped at 10 ms.
        usleep(std::min(time * 1000, float(10000)));
    }

    out_khwn_device_buf.FromDevice(out_khwn.mData.data());

    // reorder output: KHWN -> NKHW into the caller-provided tensor
    auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
        out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
    };

    make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
        std::thread::hardware_concurrency());
}
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
52423948
...
@@ -33,18 +33,11 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -33,18 +33,11 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr
auto
wei_kcyx_desc
=
WeiDesc
{};
constexpr
auto
wei_kcyx_desc
=
WeiDesc
{};
constexpr
auto
out_nkhw_desc
=
OutDesc
{};
constexpr
auto
out_nkhw_desc
=
OutDesc
{};
constexpr
index_t
Hi
=
in_nchw_desc
.
GetLength
(
I2
);
constexpr
index_t
Wi
=
in_nchw_desc
.
GetLength
(
I3
);
constexpr
index_t
N
=
out_nkhw_desc
.
GetLength
(
I0
);
constexpr
index_t
N
=
out_nkhw_desc
.
GetLength
(
I0
);
constexpr
index_t
K
=
out_nkhw_desc
.
GetLength
(
I1
);
constexpr
index_t
Ho
=
out_nkhw_desc
.
GetLength
(
I2
);
constexpr
index_t
Ho
=
out_nkhw_desc
.
GetLength
(
I2
);
constexpr
index_t
Wo
=
out_nkhw_desc
.
GetLength
(
I3
);
constexpr
index_t
Wo
=
out_nkhw_desc
.
GetLength
(
I3
);
constexpr
index_t
K
=
wei_kcyx_desc
.
GetLength
(
I0
);
constexpr
index_t
C
=
wei_kcyx_desc
.
GetLength
(
I1
);
constexpr
index_t
Y
=
wei_kcyx_desc
.
GetLength
(
I2
);
constexpr
index_t
X
=
wei_kcyx_desc
.
GetLength
(
I3
);
std
::
size_t
data_sz
=
sizeof
(
T
);
std
::
size_t
data_sz
=
sizeof
(
T
);
DeviceMem
in_nchw_device_buf
(
data_sz
*
in_nchw
.
mDesc
.
GetElementSpace
());
DeviceMem
in_nchw_device_buf
(
data_sz
*
in_nchw
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_kcyx_device_buf
(
data_sz
*
wei_kcyx
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_kcyx_device_buf
(
data_sz
*
wei_kcyx
.
mDesc
.
GetElementSpace
());
...
@@ -54,19 +47,16 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -54,19 +47,16 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
constexpr
index_t
N1
=
2
;
constexpr
index_t
N2
=
4
;
constexpr
index_t
B
=
(
N
*
Ho
*
Wo
)
/
(
N1
*
N2
);
#if 1
#if 1
// each thread hold 64 data
//
BlockSize = 256, blockwise-GEMM 128x128,
each thread hold 64 data
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BPerBlock
=
16
;
constexpr
index_t
BPerBlock
=
16
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
EPerBlock
=
8
;
constexpr
index_t
EPerBlock
=
8
;
constexpr
index_t
GemmNRepeat
=
2
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
...
@@ -80,65 +70,67 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -80,65 +70,67 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
using
InBlockCopySubLengths_E_N1_B_N2
=
Sequence
<
1
,
1
,
1
,
4
>
;
using
InBlockCopySubLengths_E_N1_B_N2
=
Sequence
<
1
,
1
,
1
,
4
>
;
using
InBlockCopyClusterLengths_E_N1_B_N2
=
Sequence
<
8
,
2
,
16
,
1
>
;
using
InBlockCopyClusterLengths_E_N1_B_N2
=
Sequence
<
8
,
2
,
16
,
1
>
;
using
InBlockCopyThreadClusterArrangeOrder
=
Sequence
<
0
,
1
,
3
,
2
>
;
// [E, N1, N2, B]
using
InBlockCopyThreadClusterArrangeOrder
=
Sequence
<
0
,
1
,
3
,
2
>
;
// [E, N1, N2, B]
using
InBlockCopySrcAccessOrder
=
Sequence
<
0
,
1
,
3
,
2
>
;
// [E, N1, N2
, B
]
using
InBlockCopySrcAccessOrder
=
Sequence
<
0
,
2
,
1
,
3
>
;
// [E,
B,
N1, N2]
using
InBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
,
2
,
3
>
;
// [E, N1, B, N2]
using
InBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
,
2
,
3
>
;
// [E, N1, B, N2]
constexpr
index_t
InBlockCopySrcDataPerRead_B
=
1
;
constexpr
index_t
InBlockCopySrcDataPerRead_B
=
1
;
constexpr
index_t
InBlockCopyDstDataPerWrite_N2
=
4
;
constexpr
index_t
InBlockCopyDstDataPerWrite_N2
=
4
;
using
WeiBlockCopySubLengths_E_K
=
Sequence
<
1
,
4
>
;
using
WeiBlockCopySubLengths_E_K
=
Sequence
<
4
,
1
>
;
using
WeiBlockCopyClusterLengths_E_K
=
Sequence
<
8
,
32
>
;
using
WeiBlockCopyClusterLengths_E_K
=
Sequence
<
2
,
128
>
;
using
WeiBlockCopyThreadClusterArrangeOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopyThreadClusterArrangeOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopySrcAccessOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopySrcAccessOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, K]
using
WeiBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, K]
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
1
;
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
4
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
4
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
#elif 0
constexpr
index_t
OutThreadCopyDataPerAccess_W
=
1
;
// BlockSize = 64, blockwise-GEMM 64x64, each thread hold 64 data
#elif 1
constexpr
index_t
BlockSize
=
64
;
// each thread hold 64 data
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BPerBlock
=
16
;
constexpr
index_t
BPerBlock
=
8
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
64
;
constexpr
index_t
EPerBlock
=
8
;
constexpr
index_t
EPerBlock
=
8
;
constexpr
index_t
GemmNRepeat
=
2
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
2
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
InBlockCopySubLengths_E_N1_B_N2
=
Sequence
<
1
,
1
,
2
,
2
>
;
using
InBlockCopySubLengths_E_N1_B_N2
=
Sequence
<
1
,
2
,
1
,
4
>
;
using
InBlockCopyClusterLengths_E_N1_B_N2
=
Sequence
<
8
,
2
,
8
,
2
>
;
using
InBlockCopyClusterLengths_E_N1_B_N2
=
Sequence
<
8
,
1
,
8
,
1
>
;
using
InBlockCopyThreadClusterArrangeOrder
=
Sequence
<
0
,
1
,
3
,
2
>
;
// [E, N1, N2, B]
using
InBlockCopyThreadClusterArrangeOrder
=
Sequence
<
0
,
1
,
3
,
2
>
;
// [E, N1, N2, B]
using
InBlockCopySrcAccessOrder
=
Sequence
<
0
,
1
,
3
,
2
>
;
// [E, N1, N2
, B
]
using
InBlockCopySrcAccessOrder
=
Sequence
<
0
,
2
,
1
,
3
>
;
// [E,
B,
N1, N2]
using
InBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
,
2
,
3
>
;
// [E, N1, B, N2]
using
InBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
,
2
,
3
>
;
// [E, N1, B, N2]
constexpr
index_t
InBlockCopySrcDataPerRead_B
=
2
;
constexpr
index_t
InBlockCopySrcDataPerRead_B
=
1
;
constexpr
index_t
InBlockCopyDstDataPerWrite_N2
=
2
;
constexpr
index_t
InBlockCopyDstDataPerWrite_N2
=
4
;
using
WeiBlockCopySubLengths_E_K
=
Sequence
<
2
,
2
>
;
using
WeiBlockCopySubLengths_E_K
=
Sequence
<
4
,
2
>
;
using
WeiBlockCopyClusterLengths_E_K
=
Sequence
<
4
,
64
>
;
using
WeiBlockCopyClusterLengths_E_K
=
Sequence
<
2
,
32
>
;
using
WeiBlockCopyThreadClusterArrangeOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopyThreadClusterArrangeOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopySrcAccessOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopySrcAccessOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, K]
using
WeiBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, K]
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
2
;
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
4
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
2
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
#elif
0
#elif
1
// each thread hold 32 data
//
BlockSize = 256, blockwise-GEMM 64x128,
each thread hold 32 data
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BPerBlock
=
16
;
constexpr
index_t
BPerBlock
=
16
;
constexpr
index_t
KPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
64
;
constexpr
index_t
EPerBlock
=
8
;
constexpr
index_t
EPerBlock
=
8
;
constexpr
index_t
GemmNRepeat
=
2
;
constexpr
index_t
GemmMPerThreadSubC
=
2
;
constexpr
index_t
GemmMPerThreadSubC
=
2
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
...
@@ -152,7 +144,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -152,7 +144,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
using
InBlockCopySubLengths_E_N1_B_N2
=
Sequence
<
1
,
1
,
1
,
4
>
;
using
InBlockCopySubLengths_E_N1_B_N2
=
Sequence
<
1
,
1
,
1
,
4
>
;
using
InBlockCopyClusterLengths_E_N1_B_N2
=
Sequence
<
8
,
2
,
16
,
1
>
;
using
InBlockCopyClusterLengths_E_N1_B_N2
=
Sequence
<
8
,
2
,
16
,
1
>
;
using
InBlockCopyThreadClusterArrangeOrder
=
Sequence
<
0
,
1
,
3
,
2
>
;
// [E, N1, N2, B]
using
InBlockCopyThreadClusterArrangeOrder
=
Sequence
<
0
,
1
,
3
,
2
>
;
// [E, N1, N2, B]
using
InBlockCopySrcAccessOrder
=
Sequence
<
0
,
1
,
3
,
2
>
;
// [E, N1, N2
, B
]
using
InBlockCopySrcAccessOrder
=
Sequence
<
0
,
2
,
1
,
3
>
;
// [E,
B,
N1, N2]
using
InBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
,
2
,
3
>
;
// [E, N1, B, N2]
using
InBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
,
2
,
3
>
;
// [E, N1, B, N2]
constexpr
index_t
InBlockCopySrcDataPerRead_B
=
1
;
constexpr
index_t
InBlockCopySrcDataPerRead_B
=
1
;
...
@@ -168,57 +160,60 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -168,57 +160,60 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
#endif
#endif
constexpr
index_t
N1
=
GemmNRepeat
;
constexpr
index_t
N2
=
GemmNPerThreadSubC
;
constexpr
index_t
B
=
(
N
*
Ho
*
Wo
)
/
(
N1
*
N2
);
constexpr
index_t
GridSize
=
constexpr
index_t
GridSize
=
((
B
+
BPerBlock
-
1
)
/
BPerBlock
)
*
((
K
+
KPerBlock
-
1
)
/
KPerBlock
);
((
B
+
BPerBlock
-
1
)
/
BPerBlock
)
*
((
K
+
KPerBlock
-
1
)
/
KPerBlock
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
constexpr
auto
gridwise_conv
=
{
constexpr
auto
gridwise_conv
=
#if 0
#if 0
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
#else
#else
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
#endif
#endif
<
GridSize
,
<
GridSize
,
BlockSize
,
BlockSize
,
T
,
T
,
decltype
(
in_nchw_desc
),
decltype
(
in_nchw_desc
),
decltype
(
wei_kcyx_desc
),
decltype
(
wei_kcyx_desc
),
decltype
(
out_nkhw_desc
),
decltype
(
out_nkhw_desc
),
ConvStrides
,
ConvStrides
,
ConvDilations
,
ConvDilations
,
BPerBlock
,
BPerBlock
,
KPerBlock
,
KPerBlock
,
EPerBlock
,
EPerBlock
,
N1
,
GemmNRepeat
,
N2
,
GemmMPerThreadSubC
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmKPerThreadLoop
,
GemmDataPerReadA
,
GemmDataPerReadA
,
GemmDataPerReadB
,
GemmDataPerReadB
,
InBlockCopySubLengths_E_N1_B_N2
,
InBlockCopySubLengths_E_N1_B_N2
,
InBlockCopyClusterLengths_E_N1_B_N2
,
InBlockCopyClusterLengths_E_N1_B_N2
,
InBlockCopyThreadClusterArrangeOrder
,
InBlockCopyThreadClusterArrangeOrder
,
InBlockCopySrcAccessOrder
,
InBlockCopySrcAccessOrder
,
InBlockCopyDstAccessOrder
,
InBlockCopyDstAccessOrder
,
InBlockCopySrcDataPerRead_B
,
InBlockCopySrcDataPerRead_B
,
InBlockCopyDstDataPerWrite_N2
,
InBlockCopyDstDataPerWrite_N2
,
WeiBlockCopySubLengths_E_K
,
WeiBlockCopySubLengths_E_K
,
WeiBlockCopyClusterLengths_E_K
,
WeiBlockCopyClusterLengths_E_K
,
WeiBlockCopyThreadClusterArrangeOrder
,
WeiBlockCopyThreadClusterArrangeOrder
,
WeiBlockCopySrcAccessOrder
,
WeiBlockCopySrcAccessOrder
,
WeiBlockCopyDstAccessOrder
,
WeiBlockCopyDstAccessOrder
,
WeiBlockCopySrcDataPerRead_E
,
WeiBlockCopySrcDataPerRead_E
,
WeiBlockCopyDstDataPerWrite_K
>
{};
WeiBlockCopyDstDataPerWrite_K
,
OutThreadCopyDataPerAccess_W
,
ConvolutionDirection
::
BackwardWeights
>
{};
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
float
time
=
launch_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
float
time
=
launch_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
dim3
(
GridSize
),
dim3
(
GridSize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment