gaoqiong / composable_kernel / Commits

Commit fdaaaa50, authored Mar 22, 2019 by Chao Liu

Merge branch 'direct_fp16'

Parents: 2c9b8c24, 18a81e35

Showing 3 changed files with 39 additions and 50 deletions (+39 -50)
src/include/threadwise_4d_tensor_op.hip.hpp        +19 -13
src/include/threadwise_direct_convolution.hip.hpp  +18 -35
src/include/threadwise_nd_tensor_op.hip.hpp         +2  -2
src/include/threadwise_4d_tensor_op.hip.hpp
...
@@ -37,7 +37,8 @@ __device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __re
 // TODO: in order to optimize mem access for different mem type,
 // need to write specialized version
-template <class Float,
+template <class SrcData,
+          class DstData,
           class SrcDesc,
           class DstDesc,
           class SrcOpLengths,
...
@@ -45,9 +46,9 @@ template <class Float,
           class F>
 __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
     SrcDesc,
-    const Float* __restrict__ p_src,
+    const SrcData* __restrict__ p_src,
     DstDesc,
-    Float* __restrict__ p_dst,
+    DstData* __restrict__ p_dst,
     SrcOpLengths,
     DstFromSrcReorder,
     F f)
...
@@ -88,33 +89,38 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d
         }
     }
 }

-template <class Float, class Desc>
-__device__ void threadwise_4d_tensor_set_zero(Desc, Float* __restrict__ p)
+template <class Data, class Desc>
+__device__ void threadwise_4d_tensor_set_zero(Desc, Data* __restrict__ p)
 {
-    auto f_set_zero = [](Float& v) { v = Float(0); };
+    auto f_set_zero = [](Data& v) { v = Data(0); };

-    threadwise_4d_tensor_pointwise_operation_unary<Float, Desc, decltype(f_set_zero)>(
+    threadwise_4d_tensor_pointwise_operation_unary<Data, Desc, decltype(f_set_zero)>(
         Desc{}, p, f_set_zero);
 }

-template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, class DstFromSrcReorder>
+template <class SrcData,
+          class DstData,
+          class SrcDesc,
+          class DstDesc,
+          class SrcOpLengths,
+          class DstFromSrcReorder>
 __device__ void
 threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
-                                                      const Float* __restrict__ p_src,
+                                                      const SrcData* __restrict__ p_src,
                                                       DstDesc,
-                                                      Float* __restrict__ p_dst,
+                                                      DstData* __restrict__ p_dst,
                                                       SrcOpLengths,
                                                       DstFromSrcReorder)
 {
-    auto f_copy = [](const Float& src, Float& dst) { dst = src; };
+    auto f_copy = [](const SrcData& src, DstData& dst) { dst = static_cast<DstData>(src); };

     threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
         SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
 }

-template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
+template <class SrcData, class DstData, class SrcDesc, class DstDesc, class SrcOpLengths>
 __device__ void threadwise_4d_tensor_copy(SrcDesc,
-                                          const Float* __restrict__ p_src,
+                                          const SrcData* __restrict__ p_src,
                                           DstDesc,
-                                          Float* __restrict__ p_dst,
+                                          DstData* __restrict__ p_dst,
                                           SrcOpLengths)
 {
     auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{};
...
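With the copy helpers templated on separate SrcData and DstData, a single call can now convert element types in flight: f_copy's static_cast<DstData>(src) does the per-element conversion. A minimal sketch of such a use, assuming a hypothetical fp16-to-fp32 staging wrapper (the wrapper name and the use of HIP's half/float types are illustrative, not part of this commit):

// Hypothetical wrapper, not in this diff: stage an fp16 tile into an fp32
// buffer. SrcData/DstData are deduced as half/float from the pointers, and
// the copy's static_cast<DstData>(src) performs the widening.
template <class SrcDesc, class DstDesc>
__device__ void copy_tile_fp16_to_fp32(SrcDesc,
                                       const half* __restrict__ p_src,
                                       DstDesc,
                                       float* __restrict__ p_dst)
{
    threadwise_4d_tensor_copy(SrcDesc{}, p_src, DstDesc{}, p_dst, SrcDesc{}.GetLengths());
}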
src/include/threadwise_direct_convolution.hip.hpp
...
@@ -2,13 +2,13 @@
 #include "ConstantTensorDescriptor.hip.hpp"

 // optimized for scenario if p_in, p_wei, p_out are in register
-template <class Float, class InDesc, class WeiDesc, class OutDesc>
+template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
 __device__ void threadwise_direct_convolution_1(InDesc,
-                                                Float* const __restrict__ p_in,
+                                                TInWei* const __restrict__ p_in,
                                                 WeiDesc,
-                                                Float* const __restrict__ p_wei,
+                                                TInWei* const __restrict__ p_wei,
                                                 OutDesc,
-                                                Float* __restrict__ p_out)
+                                                TOut* __restrict__ p_out)
 {
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
...
@@ -51,25 +51,8 @@ __device__ void threadwise_direct_convolution_1(InDesc,
                     const unsigned out_index = out_desc.Get1dIndex(n, k, ho, wo);

-                    p_out[out_index] += p_wei[wei_index] * p_in[in_index];
-
-#if 0
-                    // if(threadIdx.x == 0)
-                    {
-                        printf("threadwise_direct_convolution: \t"
-                               "threadIdx.x %u\t"
-                               "out_index %u, p_out[out_index] %f, \t"
-                               "wei_index %u, p_wei[wei_index] %f, \t"
-                               "in_index %u, p_in[in_index] %f\n",
-                               threadIdx.x,
-                               out_index,
-                               p_out[out_index],
-                               wei_index,
-                               p_wei[wei_index],
-                               in_index,
-                               p_in[in_index]);
-                    }
-#endif
+                    fused_multiply_accumulate(
+                        p_out[out_index], p_wei[wei_index], p_in[in_index]);
                 }
             }
         }
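The inner product now goes through fused_multiply_accumulate instead of a raw `+=`, so the multiply-accumulate can be specialized per element type. Its definition is not part of this diff; a hedged sketch of the generic shape such a helper could take (an assumption about the direct_fp16 branch, not code from this commit):

// Plausible generic fallback: widen both operands and accumulate in the
// output's type, so fp16 inputs can feed an fp32 accumulator.
template <class TOut, class TWei, class TIn>
__device__ void fused_multiply_accumulate(TOut& acc, const TWei& wei, const TIn& in)
{
    acc += static_cast<TOut>(wei) * static_cast<TOut>(in);
}

Routing the operation through a named function also leaves room to drop in a hardware fma (e.g. __fmaf_rn for fp32) behind the same call site.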
...
@@ -81,13 +64,13 @@ __device__ void threadwise_direct_convolution_1(InDesc,
 // Optimized for scenario if p_in and p_wei are in LDS, p_out are in register
 // Copy in and wei into register before doing convolution
-template <class Float, class InDesc, class WeiDesc, class OutDesc>
+template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
 __device__ void threadwise_direct_convolution_2(InDesc,
-                                                Float* const __restrict__ p_in,
+                                                TInWei* const __restrict__ p_in,
                                                 WeiDesc,
-                                                Float* const __restrict__ p_wei,
+                                                TInWei* const __restrict__ p_wei,
                                                 OutDesc,
-                                                Float* __restrict__ p_out)
+                                                TOut* __restrict__ p_out)
 {
     constexpr auto in_desc = InDesc{};
     constexpr auto wei_desc = WeiDesc{};
...
@@ -97,8 +80,8 @@ __device__ void threadwise_direct_convolution_2(InDesc,
     constexpr auto wei_reg_desc = make_ConstantTensorDescriptor(wei_desc.GetLengths());

     // register
-    Float p_in_reg[in_reg_desc.GetElementSpace()];
-    Float p_wei_reg[wei_reg_desc.GetElementSpace()];
+    TInWei p_in_reg[in_reg_desc.GetElementSpace()];
+    TInWei p_wei_reg[wei_reg_desc.GetElementSpace()];

     // copy input tensor into register
     threadwise_4d_tensor_copy(in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths());
...
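Input and weights share one template parameter (TInWei) because they feed the same multiply, while the output keeps its own type (TOut) so the accumulator can be wider. An illustrative instantiation under that reading (the variable names and half/float choice are hypothetical, not from this diff):

// fp16 input and weights staged through registers, fp32 accumulation:
// TInWei = half and TOut = float are deduced from the pointer arguments.
threadwise_direct_convolution_2(
    in_desc, p_in_half, wei_desc, p_wei_half, out_desc, p_out_float);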
@@ -114,13 +97,13 @@ __device__ void threadwise_direct_convolution_2(InDesc,
...
@@ -114,13 +97,13 @@ __device__ void threadwise_direct_convolution_2(InDesc,
// optimized for scenario where p_in and p_wei are in LDS, p_out is in register
// optimized for scenario where p_in and p_wei are in LDS, p_out is in register
// break down a non-1x1 convolution into a sequence of 1x1 convolutions,
// break down a non-1x1 convolution into a sequence of 1x1 convolutions,
// load 1x1 weight into register, and do 1x1 convolution in register.
// load 1x1 weight into register, and do 1x1 convolution in register.
template
<
class
Flo
at
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
D
at
a
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
__device__
void
threadwise_direct_convolution_3
(
InDesc
,
__device__
void
threadwise_direct_convolution_3
(
InDesc
,
Flo
at
*
const
__restrict__
p_in
,
D
at
a
*
const
__restrict__
p_in
,
WeiDesc
,
WeiDesc
,
Flo
at
*
const
__restrict__
p_wei
,
D
at
a
*
const
__restrict__
p_wei
,
OutDesc
,
OutDesc
,
Flo
at
*
__restrict__
p_out
)
D
at
a
*
__restrict__
p_out
)
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
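Spelled out, the decomposition that the comment above describes (index form, for illustration only):

// out[n, k, ho, wo] = sum_{c, y, x} wei[k, c, y, x] * in[n, c, ho + y, wo + x]
//                   = sum_{y, x} ( 1x1 convolution with the weight slice
//                     wei[:, :, y, x] over the input shifted by (y, x) )
// Each (y, x) term is a 1x1 convolution, which is why wei_reg_desc in the
// next hunk is built with spatial lengths <..., 1, 1>.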
...
@@ -139,8 +122,8 @@ __device__ void threadwise_direct_convolution_3(InDesc,
     constexpr auto wei_reg_desc = make_ConstantTensorDescriptor(
         Sequence<wei_desc.GetLength(I0), wei_desc.GetLength(I1), 1, 1>{});

-    Float p_in_reg[in_reg_desc.GetElementSpace()];
-    Float p_wei_reg[wei_reg_desc.GetElementSpace()];
+    Data p_in_reg[in_reg_desc.GetElementSpace()];
+    Data p_wei_reg[wei_reg_desc.GetElementSpace()];

     constexpr unsigned in_w_new_read = 1;
...
src/include/threadwise_nd_tensor_op.hip.hpp
...
@@ -10,7 +10,7 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc,
                                           SrcOpLengths,
                                           Number<DataPerRead>)
 {
-    using vector_t = typename vector_type<Float, DataPerRead>::type;
+    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

     static_assert(SrcDesc{}.GetDimension() == 6 && DstDesc{}.GetDimension() == 6 &&
                       SrcOpLengths::nDim == 6,
...
@@ -80,7 +80,7 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc,
                                           SrcOpLengths,
                                           Number<DataPerRead>)
 {
-    using vector_t = typename vector_type<Float, DataPerRead>::type;
+    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

     static_assert(SrcDesc{}.GetDimension() == 8 && DstDesc{}.GetDimension() == 8 &&
                       SrcOpLengths::nDim == 8,
...
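Both copies now read the vectorized load type from vector_type<Float, DataPerRead>::MemoryType rather than ::type. The trait itself is defined elsewhere in the repo and is not shown in this diff; a hedged sketch of the shape the rename implies (the float4 specialization is an assumption for illustration):

// Primary template left undefined: only supported (type, width) pairs compile.
template <class T, unsigned N>
struct vector_type;

template <>
struct vector_type<float, 4>
{
    using MemoryType = float4; // the alias this diff renames from ::type
};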