Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
96ee9571
Commit
96ee9571
authored
Apr 10, 2019
by
Chao Liu
Browse files
tuned implicit gemm v1 for 3x3 on AMD to 82%. Fixed a bug in 4d tensor blockwise copy.
parent
edc89778
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
163 additions
and
26 deletions
+163
-26
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp
+35
-2
driver/driver.hip.cpp
driver/driver.hip.cpp
+4
-4
src/include/blockwise_4d_tensor_op.hip.hpp
src/include/blockwise_4d_tensor_op.hip.hpp
+12
-12
src/include/blockwise_batched_gemm.hip.hpp
src/include/blockwise_batched_gemm.hip.hpp
+88
-0
src/include/blockwise_gemm.hip.hpp
src/include/blockwise_gemm.hip.hpp
+1
-0
src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp
...dwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp
+23
-8
No files found.
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp
View file @
96ee9571
...
...
@@ -78,7 +78,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
out_khwn_device_buf
.
ToDevice
(
out_khwn
.
mData
.
data
());
#if 0
// for 3x3, 34x34
// for 3x3, 34x34
, Pascal
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
...
...
@@ -111,6 +111,39 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif
1
// for 3x3, 34x34, Vega 20
constexpr
index_t
NPerBlock
=
16
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
CPerBlock
=
4
;
constexpr
index_t
HoPerBlock
=
2
;
constexpr
index_t
WoPerBlock
=
4
;
constexpr
index_t
NPerThread
=
4
;
constexpr
index_t
KPerThread
=
8
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
2
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
8
;
constexpr
index_t
InBlockCopyDataPerRead
=
2
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
2
;
constexpr
index_t
OutThreadCopyDataPerWrite
=
4
;
constexpr
index_t
BlockSize
=
256
;
#elif 0
// for 5x5, 36x36
constexpr
index_t
NPerBlock
=
16
;
...
...
@@ -264,7 +297,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
OutThreadCopyDataPerWrite
=
4
;
constexpr
index_t
BlockSize
=
128
;
#elif
1
#elif
0
// for 3x3, 28x28, v1, Pacal
constexpr
index_t
NPerBlock
=
32
;
constexpr
index_t
KPerBlock
=
64
;
...
...
driver/driver.hip.cpp
View file @
96ee9571
...
...
@@ -409,13 +409,13 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif
0
#elif
1
// 3x3, 34x34
constexpr
index_t
N
=
64
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
34
;
constexpr
index_t
WI
=
34
;
constexpr
index_t
K
=
64
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
...
...
@@ -511,7 +511,7 @@ int main(int argc, char* argv[])
constexpr
index_t
HPad
=
1
;
constexpr
index_t
WPad
=
1
;
#elif
1
#elif
0
// 3x3 filter, 28x28 image
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
...
...
@@ -681,7 +681,7 @@ int main(int argc, char* argv[])
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 1
device_implicit_gemm_convolution_1_chwn_cyxk_khwn
#elif
0
#elif
1
device_implicit_gemm_convolution_2_chwn_cyxk_khwn
#endif
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
...
...
src/include/blockwise_4d_tensor_op.hip.hpp
View file @
96ee9571
...
...
@@ -646,6 +646,9 @@ struct Blockwise4dTensorCopy3
constexpr
index_t
nloop_d2
=
L2
/
thread_per_d2
;
constexpr
index_t
nloop_d3
=
integer_divide_ceil
(
L3
,
thread_per_d3
*
DataPerRead
);
constexpr
auto
clipboard_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
nloop_d0
,
nloop_d1
,
nloop_d2
,
nloop_d3
*
DataPerRead
>
{});
#pragma unroll
for
(
index_t
iloop_d0
=
0
;
iloop_d0
<
nloop_d0
;
++
iloop_d0
)
{
...
...
@@ -664,13 +667,10 @@ struct Blockwise4dTensorCopy3
iloop_d2
*
thread_per_d2
,
iloop_d3
*
thread_per_d3
*
DataPerRead
);
const
index_t
dst_offset
=
DstDesc
{}.
Get1dIndex
(
iloop_d0
*
thread_per_d0
,
iloop_d1
*
thread_per_d1
,
iloop_d2
*
thread_per_d2
,
iloop_d3
*
thread_per_d3
*
DataPerRead
);
const
index_t
clipboard_offset
=
clipboard_desc
.
Get1dIndex
(
iloop_d0
,
iloop_d1
,
iloop_d2
,
iloop_d3
*
DataPerRead
);
*
(
reinterpret_cast
<
vector_t
*>
(
&
p_clipboard
[
dst
_offset
]))
=
*
(
reinterpret_cast
<
vector_t
*>
(
&
p_clipboard
[
clipboard
_offset
]))
=
*
(
reinterpret_cast
<
const
vector_t
*>
(
&
p_src
[
src_offset
+
mSrcMyThreadOffset
]));
}
...
...
@@ -713,6 +713,9 @@ struct Blockwise4dTensorCopy3
constexpr
index_t
nloop_d2
=
L2
/
thread_per_d2
;
constexpr
index_t
nloop_d3
=
integer_divide_ceil
(
L3
,
thread_per_d3
*
DataPerRead
);
constexpr
auto
clipboard_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
nloop_d0
,
nloop_d1
,
nloop_d2
,
nloop_d3
*
DataPerRead
>
{});
#pragma unroll
for
(
index_t
iloop_d0
=
0
;
iloop_d0
<
nloop_d0
;
++
iloop_d0
)
{
...
...
@@ -725,11 +728,8 @@ struct Blockwise4dTensorCopy3
#pragma unroll
for
(
index_t
iloop_d3
=
0
;
iloop_d3
<
nloop_d3
;
++
iloop_d3
)
{
const
index_t
src_offset
=
SrcDesc
{}.
Get1dIndex
(
iloop_d0
*
thread_per_d0
,
iloop_d1
*
thread_per_d1
,
iloop_d2
*
thread_per_d2
,
iloop_d3
*
thread_per_d3
*
DataPerRead
);
const
index_t
clipboard_offset
=
clipboard_desc
.
Get1dIndex
(
iloop_d0
,
iloop_d1
,
iloop_d2
,
iloop_d3
*
DataPerRead
);
const
index_t
dst_offset
=
DstDesc
{}.
Get1dIndex
(
iloop_d0
*
thread_per_d0
,
...
...
@@ -738,7 +738,7 @@ struct Blockwise4dTensorCopy3
iloop_d3
*
thread_per_d3
*
DataPerRead
);
*
(
reinterpret_cast
<
vector_t
*>
(
&
p_dst
[
dst_offset
+
mDstMyThreadOffset
]))
=
*
(
reinterpret_cast
<
const
vector_t
*>
(
&
p_clipboard
[
src
_offset
]));
*
(
reinterpret_cast
<
const
vector_t
*>
(
&
p_clipboard
[
clipboard
_offset
]));
}
}
}
...
...
src/include/blockwise_batched_gemm.hip.hpp
View file @
96ee9571
...
...
@@ -263,6 +263,94 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
}
}
#if DEVICE_BACKEND_HIP
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
Run_asm
(
const
FloatA
*
__restrict__
p_a_block
,
const
FloatB
*
__restrict__
p_b_block
,
FloatC
*
__restrict__
p_c_thread
)
const
{
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
False
=
integral_constant
<
bool
,
false
>
{};
constexpr
auto
a_block_mtx
=
BlockMatrixA
{};
constexpr
auto
b_block_mtx
=
BlockMatrixB
{};
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
constexpr
index_t
M
=
a_block_mtx
.
NCol
();
constexpr
index_t
N
=
b_block_mtx
.
NCol
();
constexpr
index_t
K
=
a_block_mtx
.
NRow
();
// A is transposed
constexpr
index_t
MPerThread
=
c_thread_mtx
.
NRow
();
constexpr
index_t
NPerThread
=
c_thread_mtx
.
NCol
();
// thread A, B for GEMM
// A is transposed, b is not
constexpr
auto
a_thread_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThread
>
{});
constexpr
auto
b_thread_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThread
>
{});
// thread A-sub, B-sub for copy
constexpr
auto
a_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThreadSubC
>
{},
Number
<
MPerThread
>
{});
constexpr
auto
b_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThreadSubC
>
{},
Number
<
NPerThread
>
{});
FloatA
p_a_thread
[
a_thread_mtx
.
GetElementSpace
()];
FloatB
p_b_thread
[
b_thread_mtx
.
GetElementSpace
()];
constexpr
index_t
MPerLevel1Cluster
=
MPerThreadSubC
*
MLevel0Cluster
*
MLevel1Cluster
;
constexpr
index_t
NPerLevel1Cluster
=
NPerThreadSubC
*
NLevel0Cluster
*
NLevel1Cluster
;
// assertion for inline asm
static_assert
(
is_same
<
FloatA
,
float
>::
value
&&
is_same
<
FloatB
,
float
>::
value
&&
is_same
<
FloatC
,
float
>::
value
,
"Run_asm only deal with float
\n
"
);
static_assert
(
MPerThreadSubC
==
4
&&
NPerThreadSubC
==
4
&&
KPerThreadLoop
==
1
&&
MPerThread
==
8
&&
NPerThread
==
8
,
"Run_asm cannot deal with this GEMM shape yet
\n
"
);
static_assert
(
BlockMatrixStrideA
==
0
&&
BatchPerThread
==
1
,
"Run_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == 1 for now
\n
"
);
using
Float4
=
vector_type
<
float
,
4
>::
MemoryType
;
Float4
*
reg_a
=
(
Float4
*
)(
p_a_thread
);
Float4
*
reg_b
=
(
Float4
*
)(
p_b_thread
);
Float4
*
reg_c
=
(
Float4
*
)(
p_c_thread
);
reg_a
[
0
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_a_block
[
mMyThreadOffsetA
]);
reg_b
[
0
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_b_block
[
mMyThreadOffsetB
]);
reg_b
[
1
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_b_block
[
mMyThreadOffsetB
+
NPerLevel1Cluster
]);
reg_a
[
1
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_a_block
[
mMyThreadOffsetA
+
MPerLevel1Cluster
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
1
],
reg_c
[
1
],
reg_c
[
3
],
reg_c
[
5
],
reg_c
[
7
]);
#pragma unroll
for
(
index_t
k
=
1
;
k
<
K
;
++
k
)
{
reg_a
[
0
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_a_block
[
mMyThreadOffsetA
+
k
*
M
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
reg_b
[
0
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_b_block
[
mMyThreadOffsetB
+
k
*
N
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
reg_b
[
1
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_b_block
[
mMyThreadOffsetB
+
k
*
N
+
NPerLevel1Cluster
]);
reg_a
[
1
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_a_block
[
mMyThreadOffsetA
+
k
*
M
+
MPerLevel1Cluster
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
1
],
reg_c
[
1
],
reg_c
[
3
],
reg_c
[
5
],
reg_c
[
7
]);
}
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
}
#endif
template
<
class
BlockMatrixC
,
index_t
BlockMatrixStrideC
,
class
FloatC
>
__device__
void
CopyThreadMatrixCToBlockMatrixC
(
const
FloatC
*
__restrict__
p_c_thread
,
FloatC
*
__restrict__
p_c_block
)
const
...
...
src/include/blockwise_gemm.hip.hpp
View file @
96ee9571
...
...
@@ -127,6 +127,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
}
#if DEVICE_BACKEND_HIP
// TODO: this is not working correctly
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
Run_asm
(
const
FloatA
*
__restrict__
p_a_block
,
const
FloatB
*
__restrict__
p_b_block
,
...
...
src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp
View file @
96ee9571
...
...
@@ -204,23 +204,38 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
p_wei_global_block_offset
+=
CPerBlock
*
wei_cyxk_global_desc
.
GetStride
(
I0
),
__syncthreads
())
{
// input: global mem to LDS
#if 1
blockwise_in_copy
.
Run
(
p_in_global_block_offset
,
p_in_block
);
// weight: global mem to LDS
blockwise_wei_copy
.
Run
(
p_wei_global_block_offset
,
p_wei_block
);
#else
Float
p_in_register_clipboard
[
blockwise_in_copy
.
GetRegisterClipboardSize
()];
Float
p_wei_register_clipboard
[
blockwise_wei_copy
.
GetRegisterClipboardSize
()];
blockwise_in_copy
.
RunLoadRegisterClipboard
(
p_in_global_block_offset
,
p_in_register_clipboard
);
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
p_wei_register_clipboard
);
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
p_in_block
);
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
p_wei_block
);
#endif
__syncthreads
();
// a series of batched GEMM
#pragma unroll
for
(
index_t
y
=
0
;
y
<
Y
;
++
y
)
{
#pragma unroll
for
(
index_t
x
=
0
;
x
<
X
;
++
x
)
{
blockwise_batch_gemm
.
Run
(
p_wei_block
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_in_block
+
in_chwn_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_out_thread
);
#if 1
blockwise_batch_gemm
.
Run
#else
blockwise_batch_gemm
.
Run_asm
#endif
(
p_wei_block
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_in_block
+
in_chwn_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_out_thread
);
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment