Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a0584426
Commit
a0584426
authored
Mar 17, 2019
by
Chao Liu
Browse files
refactoring ConstantTensorDescriptor
parent
fd8de384
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
448 additions
and
336 deletions
+448
-336
driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp
driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp
+28
-21
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
+1
-1
src/include/Array.hip.hpp
src/include/Array.hip.hpp
+18
-0
src/include/ConstantTensorDescriptor.hip.hpp
src/include/ConstantTensorDescriptor.hip.hpp
+49
-266
src/include/Sequence.hip.hpp
src/include/Sequence.hip.hpp
+92
-0
src/include/common.hip.hpp
src/include/common.hip.hpp
+4
-48
src/include/constant_integral.hip.hpp
src/include/constant_integral.hip.hpp
+12
-0
src/include/functional.hip.hpp
src/include/functional.hip.hpp
+49
-0
src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp
...se_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp
+195
-0
No files found.
driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp
View file @
a0584426
...
...
@@ -2,6 +2,7 @@
#include <unistd.h>
#include "device.hpp"
#include "gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp"
#include "gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp"
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
void
device_direct_convolution_2_nchw_kcyx_nkhw
(
InDesc
,
...
...
@@ -57,27 +58,33 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
for
(
unsigned
i
=
0
;
i
<
nrepeat
;
++
i
)
{
float
time
=
launch_kernel
(
gridwise_direct_convolution_2_nchw_kcyx_nkhw
<
T
,
InDesc
,
WeiDesc
,
OutDesc
,
NPerBlock
,
KPerBlock
,
CPerBlock
,
HoPerBlock
,
WoPerBlock
,
NPerThread
,
KPerThread
,
CPerThread
,
HoPerThread
,
WoPerThread
,
BlockSize
,
GridSize
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
static_cast
<
T
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_device_buf
.
GetDeviceBuffer
()));
float
time
=
launch_kernel
(
#if 0
gridwise_direct_convolution_2_nchw_kcyx_nkhw
#else
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#endif
<
T
,
InDesc
,
WeiDesc
,
OutDesc
,
NPerBlock
,
KPerBlock
,
CPerBlock
,
HoPerBlock
,
WoPerBlock
,
NPerThread
,
KPerThread
,
CPerThread
,
HoPerThread
,
WoPerThread
,
BlockSize
,
GridSize
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
static_cast
<
T
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_device_buf
.
GetDeviceBuffer
()));
printf
(
"Elapsed time : %f ms
\n
"
,
time
);
usleep
(
std
::
min
(
time
*
1000
,
float
(
10000
)));
...
...
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
View file @
a0584426
...
...
@@ -211,7 +211,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
for
(
unsigned
i
=
0
;
i
<
nrepeat
;
++
i
)
{
float
time
=
launch_kernel
(
#if
1
#if
0
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
#else
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
...
...
src/include/Array.hip.hpp
0 → 100644
View file @
a0584426
#pragma once
template
<
class
TData
,
unsigned
NSize
>
struct
Array
{
using
Type
=
Array
<
TData
,
NSize
>
;
static
constexpr
unsigned
nSize
=
NSize
;
unsigned
mData
[
nSize
];
template
<
class
...
Xs
>
__host__
__device__
Array
(
Xs
...
xs
)
:
mData
({
static_cast
<
TData
>
(
xs
)...})
{
}
__host__
__device__
TData
operator
[](
unsigned
i
)
const
{
return
mData
[
i
];
}
};
src/include/ConstantTensorDescriptor.hip.hpp
View file @
a0584426
...
...
@@ -65,8 +65,8 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0
template
<
class
Lengths
,
class
Strides
>
struct
ConstantTensorDescriptor
{
using
Type
=
ConstantTensorDescriptor
<
Lengths
,
Strides
>
;
static
constexpr
unsigned
nDim
=
Lengths
::
nDim
;
using
NDimConstant
=
Number
<
nDim
>
;
__host__
__device__
constexpr
ConstantTensorDescriptor
()
{
...
...
@@ -91,293 +91,70 @@ struct ConstantTensorDescriptor
return
Strides
{}.
Get
(
Number
<
I
>
{});
}
__host__
__device__
constexpr
unsigned
GetElementSize
()
const
// c++14 doesn't support constexpr lambdas, has to use this trick instead
struct
GetElementSize_f
{
static_assert
(
nDim
>=
2
&&
nDim
<=
8
,
"nDim"
);
if
(
nDim
==
2
)
template
<
class
IDim
>
__host__
__device__
constexpr
unsigned
operator
()(
IDim
idim
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
return
GetLength
(
I0
)
*
GetLength
(
I1
);
return
Type
{}.
GetLength
(
idim
);
}
else
if
(
nDim
==
3
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
};
return
GetLength
(
I0
)
*
GetLength
(
I1
)
*
GetLength
(
I2
);
}
else
if
(
nDim
==
4
)
__host__
__device__
constexpr
unsigned
GetElementSize
()
const
{
// c++14 doesn't support constexpr lambdas, has to use this trick instead
struct
multiply
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
__host__
__device__
constexpr
unsigned
operator
()(
unsigned
a
,
unsigned
b
)
const
{
return
a
*
b
;
}
};
return
GetLength
(
I0
)
*
GetLength
(
I1
)
*
GetLength
(
I2
)
*
GetLength
(
I3
);
}
else
if
(
nDim
==
5
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
return
static_const_reduce_n
<
nDim
>
{}(
GetElementSize_f
{},
multiply
{});
}
return
GetLength
(
I0
)
*
GetLength
(
I1
)
*
GetLength
(
I2
)
*
GetLength
(
I3
)
*
GetLength
(
I4
);
}
else
if
(
nDim
==
6
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
return
GetLength
(
I0
)
*
GetLength
(
I1
)
*
GetLength
(
I2
)
*
GetLength
(
I3
)
*
GetLength
(
I4
)
*
GetLength
(
I5
);
}
else
if
(
nDim
==
7
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
return
GetLength
(
I0
)
*
GetLength
(
I1
)
*
GetLength
(
I2
)
*
GetLength
(
I3
)
*
GetLength
(
I4
)
*
GetLength
(
I5
)
*
GetLength
(
I6
);
}
else
if
(
nDim
==
8
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
return
GetLength
(
I0
)
*
GetLength
(
I1
)
*
GetLength
(
I2
)
*
GetLength
(
I3
)
*
GetLength
(
I4
)
*
GetLength
(
I5
)
*
GetLength
(
I6
)
*
GetLength
(
I7
);
}
else
// c++14 doesn't support constexpr lambdas, has to use this trick instead
struct
GetElementSpace_f
{
template
<
class
IDim
>
__host__
__device__
constexpr
unsigned
operator
()(
IDim
idim
)
const
{
assert
(
false
);
return
(
Type
{}.
GetLength
(
idim
)
-
1
)
*
Type
{}.
GetStride
(
idim
);
}
}
}
;
template
<
class
Align
=
Number
<
1
>
>
__host__
__device__
constexpr
unsigned
GetElementSpace
(
Align
align
=
Align
{})
const
{
static_assert
(
nDim
>=
2
&&
nDim
<=
8
,
"nDim"
);
constexpr
unsigned
align_size
=
align
.
Get
();
if
(
nDim
==
2
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
return
(
GetLength
(
I0
)
-
1
)
*
GetStride
(
I0
)
+
(
GetLength
(
I1
)
-
1
)
*
GetStride
(
I1
)
+
align_size
;
}
else
if
(
nDim
==
3
)
// c++14 doesn't support constexpr lambdas, has to use this trick instead
struct
add
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
return
(
GetLength
(
I0
)
-
1
)
*
GetStride
(
I0
)
+
(
GetLength
(
I1
)
-
1
)
*
GetStride
(
I1
)
+
(
GetLength
(
I2
)
-
1
)
*
GetStride
(
I2
)
+
align_size
;
}
else
if
(
nDim
==
4
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
return
(
GetLength
(
I0
)
-
1
)
*
GetStride
(
I0
)
+
(
GetLength
(
I1
)
-
1
)
*
GetStride
(
I1
)
+
(
GetLength
(
I2
)
-
1
)
*
GetStride
(
I2
)
+
(
GetLength
(
I3
)
-
1
)
*
GetStride
(
I3
)
+
align_size
;
}
else
if
(
nDim
==
5
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
return
(
GetLength
(
I0
)
-
1
)
*
GetStride
(
I0
)
+
(
GetLength
(
I1
)
-
1
)
*
GetStride
(
I1
)
+
(
GetLength
(
I2
)
-
1
)
*
GetStride
(
I2
)
+
(
GetLength
(
I3
)
-
1
)
*
GetStride
(
I3
)
+
(
GetLength
(
I4
)
-
1
)
*
GetStride
(
I4
)
+
align_size
;
}
else
if
(
nDim
==
6
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
return
(
GetLength
(
I0
)
-
1
)
*
GetStride
(
I0
)
+
(
GetLength
(
I1
)
-
1
)
*
GetStride
(
I1
)
+
(
GetLength
(
I2
)
-
1
)
*
GetStride
(
I2
)
+
(
GetLength
(
I3
)
-
1
)
*
GetStride
(
I3
)
+
(
GetLength
(
I4
)
-
1
)
*
GetStride
(
I4
)
+
(
GetLength
(
I5
)
-
1
)
*
GetStride
(
I5
)
+
align_size
;
}
else
if
(
nDim
==
7
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
return
(
GetLength
(
I0
)
-
1
)
*
GetStride
(
I0
)
+
(
GetLength
(
I1
)
-
1
)
*
GetStride
(
I1
)
+
(
GetLength
(
I2
)
-
1
)
*
GetStride
(
I2
)
+
(
GetLength
(
I3
)
-
1
)
*
GetStride
(
I3
)
+
(
GetLength
(
I4
)
-
1
)
*
GetStride
(
I4
)
+
(
GetLength
(
I5
)
-
1
)
*
GetStride
(
I5
)
+
(
GetLength
(
I6
)
-
1
)
*
GetStride
(
I6
)
+
align_size
;
}
else
if
(
nDim
==
8
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
return
(
GetLength
(
I0
)
-
1
)
*
GetStride
(
I0
)
+
(
GetLength
(
I1
)
-
1
)
*
GetStride
(
I1
)
+
(
GetLength
(
I2
)
-
1
)
*
GetStride
(
I2
)
+
(
GetLength
(
I3
)
-
1
)
*
GetStride
(
I3
)
+
(
GetLength
(
I4
)
-
1
)
*
GetStride
(
I4
)
+
(
GetLength
(
I5
)
-
1
)
*
GetStride
(
I5
)
+
(
GetLength
(
I6
)
-
1
)
*
GetStride
(
I6
)
+
(
GetLength
(
I7
)
-
1
)
*
GetStride
(
I7
)
+
align_size
;
}
}
// this is ugly, only for 2d
__host__
__device__
unsigned
Get1dIndex
(
unsigned
i0
,
unsigned
i1
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
static_assert
(
nDim
==
2
,
"nDim is not 2"
);
return
i0
*
GetStride
(
I0
)
+
i1
*
GetStride
(
I1
);
}
// this is ugly, only for 3d
__host__
__device__
unsigned
Get1dIndex
(
unsigned
i0
,
unsigned
i1
,
unsigned
i2
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
__host__
__device__
constexpr
unsigned
operator
()(
unsigned
a
,
unsigned
b
)
const
{
return
a
+
b
;
}
};
static_assert
(
nDim
==
3
,
"nDim is not 3"
);
return
i0
*
GetStride
(
I0
)
+
i1
*
GetStride
(
I1
)
+
i2
*
GetStride
(
I2
);
return
static_const_reduce_n
<
nDim
>
{}(
GetElementSpace_f
{},
add
{})
+
align
.
Get
();
}
// this is ugly, only for 4d
__host__
__device__
unsigned
Get1dIndex
(
unsigned
i0
,
unsigned
i1
,
unsigned
i2
,
unsigned
i3
)
const
template
<
class
...
Is
>
__host__
__device__
unsigned
Get1dIndex
(
Is
...
is
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
static_assert
(
nDim
==
4
,
"nDim is not 4"
);
return
i0
*
GetStride
(
I0
)
+
i1
*
GetStride
(
I1
)
+
i2
*
GetStride
(
I2
)
+
i3
*
GetStride
(
I3
);
}
static_assert
(
sizeof
...(
Is
)
==
nDim
,
"number of multi-index is wrong"
);
// this is ugly, only for 5d
__host__
__device__
unsigned
Get1dIndex
(
unsigned
i0
,
unsigned
i1
,
unsigned
i2
,
unsigned
i3
,
unsigned
i4
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
const
auto
multi_id
=
Array
<
unsigned
,
nDim
>
(
is
...);
static_assert
(
nDim
==
5
,
"nDim is not 5"
);
return
i0
*
GetStride
(
I0
)
+
i1
*
GetStride
(
I1
)
+
i2
*
GetStride
(
I2
)
+
i3
*
GetStride
(
I3
)
+
i4
*
GetStride
(
I4
);
}
unsigned
id
=
0
;
// this is ugly, only for 6d
__host__
__device__
unsigned
Get1dIndex
(
unsigned
i0
,
unsigned
i1
,
unsigned
i2
,
unsigned
i3
,
unsigned
i4
,
unsigned
i5
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
static_loop_n
<
nDim
>
{}([
&
](
auto
IDim
)
{
constexpr
unsigned
idim
=
IDim
.
Get
();
id
+=
multi_id
[
idim
]
*
GetStride
(
IDim
);
});
static_assert
(
nDim
==
6
,
"nDim is not 6"
);
return
i0
*
GetStride
(
I0
)
+
i1
*
GetStride
(
I1
)
+
i2
*
GetStride
(
I2
)
+
i3
*
GetStride
(
I3
)
+
i4
*
GetStride
(
I4
)
+
i5
*
GetStride
(
I5
);
}
// this is ugly, only for 7d
__host__
__device__
unsigned
Get1dIndex
(
unsigned
i0
,
unsigned
i1
,
unsigned
i2
,
unsigned
i3
,
unsigned
i4
,
unsigned
i5
,
unsigned
i6
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
static_assert
(
nDim
==
7
,
"nDim is not 7"
);
return
i0
*
GetStride
(
I0
)
+
i1
*
GetStride
(
I1
)
+
i2
*
GetStride
(
I2
)
+
i3
*
GetStride
(
I3
)
+
i4
*
GetStride
(
I4
)
+
i5
*
GetStride
(
I5
)
+
i6
*
GetStride
(
I6
);
}
// this is ugly, only for 8d
__host__
__device__
unsigned
Get1dIndex
(
unsigned
i0
,
unsigned
i1
,
unsigned
i2
,
unsigned
i3
,
unsigned
i4
,
unsigned
i5
,
unsigned
i6
,
unsigned
i7
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
static_assert
(
nDim
==
8
,
"nDim is not 8"
);
return
i0
*
GetStride
(
I0
)
+
i1
*
GetStride
(
I1
)
+
i2
*
GetStride
(
I2
)
+
i3
*
GetStride
(
I3
)
+
i4
*
GetStride
(
I4
)
+
i5
*
GetStride
(
I5
)
+
i6
*
GetStride
(
I6
)
+
i7
*
GetStride
(
I7
);
return
id
;
}
__host__
__device__
constexpr
auto
Condense
()
const
...
...
@@ -385,6 +162,12 @@ struct ConstantTensorDescriptor
constexpr
auto
default_strides
=
calculate_default_strides
(
Lengths
{});
return
ConstantTensorDescriptor
<
Lengths
,
decltype
(
default_strides
)
>
{};
}
template
<
unsigned
IDim
,
unsigned
NVector
>
__host__
__device__
constexpr
auto
Vectorize
(
Number
<
IDim
>
,
Number
<
NVector
>
)
const
{
assert
(
false
);
// not implemented
}
};
template
<
class
Lengths
>
...
...
src/include/Sequence.hip.hpp
0 → 100644
View file @
a0584426
#pragma once
#include "constant_integral.hip.hpp"
#include "functional.hip.hpp"
template
<
unsigned
...
Is
>
struct
Sequence
{
using
Type
=
Sequence
<
Is
...
>
;
static
constexpr
unsigned
nDim
=
sizeof
...(
Is
);
const
unsigned
mData
[
nDim
]
=
{
Is
...};
template
<
unsigned
I
>
__host__
__device__
constexpr
unsigned
Get
(
Number
<
I
>
)
const
{
return
mData
[
I
];
}
// this is ugly, only for nDIm = 4
template
<
unsigned
I0
,
unsigned
I1
,
unsigned
I2
,
unsigned
I3
>
__host__
__device__
constexpr
auto
ReorderByGetNewFromOld
(
Sequence
<
I0
,
I1
,
I2
,
I3
>
)
const
{
static_assert
(
nDim
==
4
,
"nDim != 4"
);
constexpr
auto
old_sequence
=
Type
{};
constexpr
unsigned
NR0
=
old_sequence
.
mData
[
I0
];
constexpr
unsigned
NR1
=
old_sequence
.
mData
[
I1
];
constexpr
unsigned
NR2
=
old_sequence
.
mData
[
I2
];
constexpr
unsigned
NR3
=
old_sequence
.
mData
[
I3
];
return
Sequence
<
NR0
,
NR1
,
NR2
,
NR3
>
{};
}
template
<
unsigned
I0
,
unsigned
I1
,
unsigned
I2
,
unsigned
I3
>
__host__
__device__
constexpr
auto
ReorderByPutOldToNew
(
Sequence
<
I0
,
I1
,
I2
,
I3
>
)
const
{
// don't know how to implement this
printf
(
"Sequence::ReorderByPutOldToNew not implemented"
);
assert
(
false
);
}
template
<
unsigned
I
>
__host__
__device__
constexpr
auto
PushBack
(
Number
<
I
>
)
const
{
return
Sequence
<
Is
...,
I
>
{};
}
__host__
__device__
constexpr
auto
PopBack
()
const
;
template
<
class
F
>
__host__
__device__
constexpr
auto
Transform
(
F
f
)
const
{
return
Sequence
<
f
(
Is
)...
>
{};
}
};
template
<
unsigned
...
Is
,
unsigned
I
>
__host__
__device__
constexpr
auto
sequence_pop_back
(
Sequence
<
Is
...,
I
>
)
{
static_assert
(
sizeof
...(
Is
)
>=
1
,
"empty Sequence!"
);
return
Sequence
<
Is
...
>
{};
}
template
<
class
F
,
unsigned
...
Xs
,
unsigned
...
Ys
>
__host__
__device__
constexpr
auto
sequence_sequence_op
(
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
,
F
f
)
{
static_assert
(
Sequence
<
Xs
...
>::
nDim
==
Sequence
<
Ys
...
>::
nDim
,
"Dim not the same"
);
return
Sequence
<
f
(
Xs
,
Ys
)...
>
{};
}
template
<
unsigned
...
Xs
,
unsigned
...
Ys
>
__host__
__device__
constexpr
auto
sequence_sequence_add
(
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
)
{
struct
add
{
__host__
__device__
constexpr
unsigned
operator
()(
unsigned
x
,
unsigned
y
)
const
{
return
x
+
y
;
}
};
return
sequence_sequence_op
(
Sequence
<
Xs
...
>
{},
Sequence
<
Ys
...
>
{},
add
{});
}
template
<
unsigned
...
Is
>
__host__
__device__
constexpr
auto
Sequence
<
Is
...
>::
PopBack
()
const
{
return
sequence_pop_back
(
Type
{});
}
src/include/common.hip.hpp
View file @
a0584426
#pragma once
#include "constant_integral.hip.hpp"
#include "Sequence.hip.hpp"
#include "Array.hip.hpp"
#include "functional.hip.hpp"
__device__
unsigned
get_thread_local_1d_id
()
{
return
threadIdx
.
x
;
}
...
...
@@ -91,54 +95,6 @@ struct vector_type<half, 8>
};
#endif
template
<
class
T
,
T
N
>
struct
integral_constant
{
static
const
T
value
=
N
;
__host__
__device__
constexpr
T
Get
()
const
{
return
value
;
}
};
template
<
unsigned
N
>
using
Number
=
integral_constant
<
unsigned
,
N
>
;
template
<
unsigned
...
Is
>
struct
Sequence
{
using
Type
=
Sequence
<
Is
...
>
;
static
constexpr
unsigned
nDim
=
sizeof
...(
Is
);
const
unsigned
mData
[
nDim
]
=
{
Is
...};
template
<
unsigned
I
>
__host__
__device__
constexpr
unsigned
Get
(
Number
<
I
>
)
const
{
return
mData
[
I
];
}
template
<
unsigned
I0
,
unsigned
I1
,
unsigned
I2
,
unsigned
I3
>
__host__
__device__
constexpr
auto
ReorderByGetNewFromOld
(
Sequence
<
I0
,
I1
,
I2
,
I3
>
)
const
{
constexpr
auto
old_sequence
=
Type
{};
constexpr
unsigned
NR0
=
old_sequence
.
mData
[
I0
];
constexpr
unsigned
NR1
=
old_sequence
.
mData
[
I1
];
constexpr
unsigned
NR2
=
old_sequence
.
mData
[
I2
];
constexpr
unsigned
NR3
=
old_sequence
.
mData
[
I3
];
return
Sequence
<
NR0
,
NR1
,
NR2
,
NR3
>
{};
}
template
<
unsigned
I0
,
unsigned
I1
,
unsigned
I2
,
unsigned
I3
>
__host__
__device__
constexpr
auto
ReorderByPutOldToNew
(
Sequence
<
I0
,
I1
,
I2
,
I3
>
)
const
{
// don't know how to implement this
printf
(
"Sequence::ReorderByPutOldToNew not implemented"
);
assert
(
false
);
}
};
template
<
typename
T
>
__host__
__device__
constexpr
T
max
(
T
a
,
T
b
)
{
...
...
src/include/constant_integral.hip.hpp
0 → 100644
View file @
a0584426
#pragma once
template
<
class
T
,
T
N
>
struct
integral_constant
{
static
const
T
value
=
N
;
__host__
__device__
constexpr
T
Get
()
const
{
return
value
;
}
};
template
<
unsigned
N
>
using
Number
=
integral_constant
<
unsigned
,
N
>
;
src/include/functional.hip.hpp
0 → 100644
View file @
a0584426
#pragma once
#include "constant_integral.hip.hpp"
template
<
unsigned
NLoop
>
struct
static_loop_n
{
template
<
class
F
>
__host__
__device__
void
operator
()(
F
f
)
const
{
static_assert
(
NLoop
>
1
,
"out-of-range"
);
f
(
Number
<
NLoop
-
1
>
{});
static_loop_n
<
NLoop
-
1
>
{}(
f
);
}
};
template
<
>
struct
static_loop_n
<
1
>
{
template
<
class
F
>
__host__
__device__
void
operator
()(
F
f
)
const
{
f
(
Number
<
0
>
{});
}
};
template
<
unsigned
NLoop
>
struct
static_const_reduce_n
{
template
<
class
F
,
class
Reduce
>
__host__
__device__
constexpr
auto
operator
()(
F
f
,
Reduce
r
)
const
{
static_assert
(
NLoop
>
1
,
"out-of-range"
);
constexpr
auto
a
=
f
(
Number
<
NLoop
-
1
>
{});
auto
b
=
static_const_reduce_n
<
NLoop
-
1
>
{}(
f
,
r
);
// cannot use constexpr here, weird
return
r
(
a
,
b
);
}
};
template
<
>
struct
static_const_reduce_n
<
1
>
{
template
<
class
F
,
class
Reduce
>
__host__
__device__
constexpr
auto
operator
()(
F
f
,
Reduce
)
const
{
return
f
(
Number
<
0
>
{});
}
};
src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp
0 → 100644
View file @
a0584426
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_direct_convolution.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "threadwise_direct_convolution.hip.hpp"
template
<
class
Float
,
class
InGlobalDesc
,
class
WeiGlobalDesc
,
class
OutGlobalDesc
,
unsigned
NPerBlock
,
unsigned
KPerBlock
,
unsigned
CPerBlock
,
unsigned
HoPerBlock
,
unsigned
WoPerBlock
,
unsigned
NPerThread
,
unsigned
KPerThread
,
unsigned
CPerThread
,
unsigned
HoPerThread
,
unsigned
WoPerThread
,
unsigned
BlockSize
,
unsigned
GridSize
>
__global__
void
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_global_desc
=
OutGlobalDesc
{};
constexpr
unsigned
Y
=
wei_global_desc
.
GetLength
(
I2
);
constexpr
unsigned
X
=
wei_global_desc
.
GetLength
(
I3
);
constexpr
unsigned
HiPerBlock
=
HoPerBlock
+
Y
-
1
;
constexpr
unsigned
WiPerBlock
=
WoPerBlock
+
X
-
1
;
constexpr
auto
in_block_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
NPerBlock
,
CPerBlock
,
HiPerBlock
,
WiPerBlock
>
{});
constexpr
auto
wei_block_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
KPerBlock
,
CPerBlock
,
Y
,
X
>
{});
// shared mem
constexpr
unsigned
in_block_size
=
in_block_desc
.
GetElementSpace
();
constexpr
unsigned
wei_block_size
=
wei_block_desc
.
GetElementSpace
();
__shared__
Float
p_in_block
[
in_block_size
];
__shared__
Float
p_wei_block
[
wei_block_size
];
// threadwise tensors
constexpr
unsigned
HiPerThread
=
HoPerThread
+
Y
-
1
;
constexpr
unsigned
WiPerThread
=
WoPerThread
+
X
-
1
;
constexpr
auto
in_thread_block_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
NPerThread
,
CPerThread
,
HiPerThread
,
WiPerThread
>
{},
in_block_desc
.
GetStrides
());
constexpr
auto
wei_thread_block_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
KPerThread
,
CPerThread
,
Y
,
X
>
{},
wei_block_desc
.
GetStrides
());
constexpr
auto
out_thread_desc
=
get_convolution_output_default_4d_tensor_descriptor
(
in_thread_block_desc
,
wei_thread_block_desc
);
// register
Float
p_out_thread
[
out_thread_desc
.
GetElementSpace
()];
// divide block work
constexpr
unsigned
NBlockWork
=
(
out_global_desc
.
GetLength
(
I0
)
+
NPerBlock
-
1
)
/
NPerBlock
;
constexpr
unsigned
KBlockWork
=
(
out_global_desc
.
GetLength
(
I1
)
+
KPerBlock
-
1
)
/
KPerBlock
;
constexpr
unsigned
HBlockWork
=
(
out_global_desc
.
GetLength
(
I2
)
+
HoPerBlock
-
1
)
/
HoPerBlock
;
constexpr
unsigned
WBlockWork
=
(
out_global_desc
.
GetLength
(
I3
)
+
WoPerBlock
-
1
)
/
WoPerBlock
;
const
unsigned
block_id
=
blockIdx
.
x
;
unsigned
itmp
=
block_id
;
const
unsigned
n_block_work_id
=
itmp
/
(
KBlockWork
*
HBlockWork
*
WBlockWork
);
itmp
-=
n_block_work_id
*
(
KBlockWork
*
HBlockWork
*
WBlockWork
);
const
unsigned
k_block_work_id
=
itmp
/
(
HBlockWork
*
WBlockWork
);
itmp
-=
k_block_work_id
*
(
HBlockWork
*
WBlockWork
);
const
unsigned
h_block_work_id
=
itmp
/
WBlockWork
;
const
unsigned
w_block_work_id
=
itmp
-
h_block_work_id
*
WBlockWork
;
const
unsigned
n_block_data_begin
=
n_block_work_id
*
NPerBlock
;
const
unsigned
k_block_data_begin
=
k_block_work_id
*
KPerBlock
;
const
unsigned
ho_block_data_begin
=
h_block_work_id
*
HoPerBlock
;
const
unsigned
wo_block_data_begin
=
w_block_work_id
*
WoPerBlock
;
const
unsigned
hi_block_data_begin
=
ho_block_data_begin
;
// minus padding
const
unsigned
wi_block_data_begin
=
wo_block_data_begin
;
// minus padding
// divide thread work
constexpr
unsigned
NThreadWork
=
(
NPerBlock
+
NPerThread
-
1
)
/
NPerThread
;
constexpr
unsigned
KThreadWork
=
(
KPerBlock
+
KPerThread
-
1
)
/
KPerThread
;
constexpr
unsigned
HThreadWork
=
(
HoPerBlock
+
HoPerThread
-
1
)
/
HoPerThread
;
constexpr
unsigned
WThreadWork
=
(
WoPerBlock
+
WoPerThread
-
1
)
/
WoPerThread
;
const
unsigned
thread_id
=
threadIdx
.
x
;
itmp
=
thread_id
;
const
unsigned
n_thread_work_id
=
itmp
/
(
KThreadWork
*
HThreadWork
*
WThreadWork
);
itmp
-=
n_thread_work_id
*
(
KThreadWork
*
HThreadWork
*
WThreadWork
);
const
unsigned
k_thread_work_id
=
itmp
/
(
HThreadWork
*
WThreadWork
);
itmp
-=
k_thread_work_id
*
(
HThreadWork
*
WThreadWork
);
const
unsigned
h_thread_work_id
=
itmp
/
WThreadWork
;
const
unsigned
w_thread_work_id
=
itmp
-
h_thread_work_id
*
WThreadWork
;
const
unsigned
n_thread_data_begin
=
n_thread_work_id
*
NPerThread
;
const
unsigned
k_thread_data_begin
=
k_thread_work_id
*
KPerThread
;
const
unsigned
ho_thread_data_begin
=
h_thread_work_id
*
HoPerThread
;
const
unsigned
wo_thread_data_begin
=
w_thread_work_id
*
WoPerThread
;
const
unsigned
hi_thread_data_begin
=
ho_thread_data_begin
;
const
unsigned
wi_thread_data_begin
=
wo_thread_data_begin
;
constexpr
auto
blockwise_in_copy
=
Blockwise4dTensorCopy1
<
BlockSize
,
Float
,
decltype
(
in_global_desc
),
decltype
(
in_block_desc
),
decltype
(
in_block_desc
.
GetLengths
())
>
{};
constexpr
auto
blockwise_wei_copy
=
Blockwise4dTensorCopy1
<
BlockSize
,
Float
,
decltype
(
wei_global_desc
),
decltype
(
wei_block_desc
),
decltype
(
wei_block_desc
.
GetLengths
())
>
{};
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero
(
out_thread_desc
,
p_out_thread
);
for
(
unsigned
c_block_data_begin
=
0
;
c_block_data_begin
<
in_global_desc
.
GetLength
(
I1
);
c_block_data_begin
+=
CPerBlock
,
__syncthreads
())
{
// copy input tensor to LDS
blockwise_in_copy
.
Run
(
p_in_global
+
in_global_desc
.
Get1dIndex
(
n_block_data_begin
,
c_block_data_begin
,
hi_block_data_begin
,
wi_block_data_begin
),
p_in_block
);
// copy weight tensor to LDS
blockwise_wei_copy
.
Run
(
p_wei_global
+
wei_global_desc
.
Get1dIndex
(
k_block_data_begin
,
c_block_data_begin
,
0
,
0
),
p_wei_block
);
__syncthreads
();
for
(
unsigned
c_thread_data
=
0
;
c_thread_data
<
CPerBlock
;
c_thread_data
+=
CPerThread
)
{
// threadwise convolution
#if 1
threadwise_direct_convolution_2
(
in_thread_block_desc
,
p_in_block
+
in_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
c_thread_data
,
hi_thread_data_begin
,
wi_thread_data_begin
),
wei_thread_block_desc
,
p_wei_block
+
wei_block_desc
.
Get1dIndex
(
k_thread_data_begin
,
c_thread_data
,
0
,
0
),
out_thread_desc
,
p_out_thread
);
#elif 0
threadwise_direct_convolution_3
(
in_thread_block_desc
,
p_in_block
+
in_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
c_thread_data
,
hi_thread_data_begin
,
wi_thread_data_begin
),
wei_thread_block_desc
,
p_wei_block
+
wei_block_desc
.
Get1dIndex
(
k_thread_data_begin
,
c_thread_data
,
0
,
0
),
out_thread_desc
,
p_out_thread
);
#endif
}
}
// copy output tensor from register to global mem
threadwise_4d_tensor_copy
(
out_thread_desc
,
p_out_thread
,
out_global_desc
,
p_out_global
+
out_global_desc
.
Get1dIndex
(
n_block_data_begin
+
n_thread_data_begin
,
k_block_data_begin
+
k_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
),
out_thread_desc
.
GetLengths
());
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment