Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
8fcf3f1e
Commit
8fcf3f1e
authored
May 23, 2019
by
Chao Liu
Browse files
added implicit GEMM v3
parent
8ce14804
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
100 additions
and
106 deletions
+100
-106
driver/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp
...er/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp
+3
-3
driver/device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp
...er/device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp
+0
-1
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp
...ise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp
+76
-93
src/include/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hip.hpp
...dwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hip.hpp
+21
-9
No files found.
driver/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp
View file @
8fcf3f1e
...
@@ -57,7 +57,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
...
@@ -57,7 +57,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
wei_cyxk_device_buf
.
ToDevice
(
wei_cyxk
.
mData
.
data
());
wei_cyxk_device_buf
.
ToDevice
(
wei_cyxk
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
#if
1
#if
0
// for 3x3, 34x34, v1r3, Pascal
// for 3x3, 34x34, v1r3, Pascal
constexpr index_t BlockSize = 128;
constexpr index_t BlockSize = 128;
...
@@ -127,7 +127,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
...
@@ -127,7 +127,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
constexpr
index_t
WeiBlockCopyDataPerRead_K
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead_K
=
4
;
constexpr
index_t
OutThreadCopyDataPerWrite_W
=
4
;
constexpr
index_t
OutThreadCopyDataPerWrite_W
=
4
;
#elif
0
#elif
1
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 16
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -313,7 +313,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
...
@@ -313,7 +313,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
{
constexpr
auto
gridwise_conv
=
constexpr
auto
gridwise_conv
=
#if
0
#if
1
GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
#else
#else
GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
...
...
driver/device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp
View file @
8fcf3f1e
...
@@ -62,7 +62,6 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
...
@@ -62,7 +62,6 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
constexpr
index_t
B
=
(
N
*
Ho
*
Wo
)
/
(
N1
*
N2
);
constexpr
index_t
B
=
(
N
*
Ho
*
Wo
)
/
(
N1
*
N2
);
#if 1
#if 1
// for 3x3, 28x28, v3
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BPerBlock
=
16
;
constexpr
index_t
BPerBlock
=
16
;
...
...
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp
View file @
8fcf3f1e
...
@@ -83,7 +83,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -83,7 +83,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
constexpr
index_t
HBlockWork
=
mod_conv
::
integer_divide_ceil
(
Ho
,
HoPerBlock
);
constexpr
index_t
HBlockWork
=
mod_conv
::
integer_divide_ceil
(
Ho
,
HoPerBlock
);
constexpr
index_t
WBlockWork
=
mod_conv
::
integer_divide_ceil
(
Wo
,
WoPerBlock
);
constexpr
index_t
WBlockWork
=
mod_conv
::
integer_divide_ceil
(
Wo
,
WoPerBlock
);
constexpr
auto
block_work_desc
=
make_ConstantTensorDescriptor
(
constexpr
auto
block_work_desc
=
make_ConstantTensorDescriptor
_default_rank_packed
(
Sequence
<
NBlockWork
,
KBlockWork
,
HBlockWork
,
WBlockWork
>
{});
Sequence
<
NBlockWork
,
KBlockWork
,
HBlockWork
,
WBlockWork
>
{});
const
auto
block_work_multi_id
=
const
auto
block_work_multi_id
=
...
@@ -99,7 +99,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -99,7 +99,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
// global tensor view
// global tensor view
constexpr
auto
wei_c_k_global_desc
=
constexpr
auto
wei_c_k_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
,
K
>
{},
Sequence
<
Y
*
X
*
K
,
1
>
{});
make_ConstantTensorDescriptor
_default_rank
(
Sequence
<
C
,
K
>
{},
Sequence
<
Y
*
X
*
K
,
1
>
{});
// LDS tensor view
// LDS tensor view
// be careful of alignment
// be careful of alignment
...
@@ -108,7 +108,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -108,7 +108,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
GemmDataPerReadA
,
GemmDataPerReadA
,
GemmDataPerReadB
);
GemmDataPerReadB
);
constexpr
auto
in_c_h_w_n_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
in_c_h_w_n_block_desc
=
make_ConstantTensorDescriptor_
default_rank_
aligned
(
Sequence
<
CPerBlock
,
HoPerBlock
,
WoPerBlock
,
NPerBlock
>
{},
Sequence
<
CPerBlock
,
HoPerBlock
,
WoPerBlock
,
NPerBlock
>
{},
Number
<
InBlockReorderDataPerWrite_N
>
{});
Number
<
InBlockReorderDataPerWrite_N
>
{});
...
@@ -117,12 +117,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -117,12 +117,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
static_assert
(
in_c_h_w_n_block_desc
.
GetStride
(
I1
)
%
GemmDataPerReadB
==
0
,
static_assert
(
in_c_h_w_n_block_desc
.
GetStride
(
I1
)
%
GemmDataPerReadB
==
0
,
"GemmDataPerReadB alignment requirement is not meet"
);
"GemmDataPerReadB alignment requirement is not meet"
);
constexpr
auto
wei_c_k_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
wei_c_k_block_desc
=
make_ConstantTensorDescriptor_
default_rank_
aligned
(
Sequence
<
CPerBlock
,
KPerBlock
>
{},
Sequence
<
CPerBlock
,
KPerBlock
>
{},
Number
<
mod_conv
::
max
(
WeiBlockCopyDataPerRead_K
,
GemmDataPerReadA
)
>
{});
Number
<
mod_conv
::
max
(
WeiBlockCopyDataPerRead_K
,
GemmDataPerReadA
)
>
{});
// tensor view of threadwise output in register
// tensor view of threadwise output in register
constexpr
auto
out_k_h_w_n_thread_desc
=
make_ConstantTensorDescriptor
(
constexpr
auto
out_k_h_w_n_thread_desc
=
make_ConstantTensorDescriptor
_default_rank_packed
(
Sequence
<
KPerThread
,
HoPerThread
,
WoPerThread
,
NPerThread
>
{});
Sequence
<
KPerThread
,
HoPerThread
,
WoPerThread
,
NPerThread
>
{});
// blockwise copy
// blockwise copy
...
@@ -140,7 +140,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -140,7 +140,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
decltype
(
map_chwn2nchw
),
decltype
(
map_chwn2nchw
),
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW
,
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW
,
InBlockReorderDataPerRead_W
,
InBlockReorderDataPerRead_W
,
InBlockReorderDataPerWrite_N
>
{}
;
InBlockReorderDataPerWrite_N
>
({
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
})
;
// blockwise wei copy
// blockwise wei copy
// format is [CPerBlock, KPerBlock]
// format is [CPerBlock, KPerBlock]
...
@@ -150,7 +150,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -150,7 +150,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
decltype
(
wei_c_k_global_desc
),
decltype
(
wei_c_k_global_desc
),
decltype
(
wei_c_k_block_desc
),
decltype
(
wei_c_k_block_desc
),
decltype
(
wei_c_k_block_desc
.
GetLengths
()),
decltype
(
wei_c_k_block_desc
.
GetLengths
()),
WeiBlockCopyDataPerRead_K
>
{}
;
WeiBlockCopyDataPerRead_K
>
({
0
,
0
},
{
0
,
0
})
;
// a series of blockwise batched GEMM
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// C_matrix += transpose(A_matrix) * B_matrix
...
@@ -194,7 +194,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -194,7 +194,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
// choose GEMM implementation here
// choose GEMM implementation here
const
auto
run_blockwise_batch_gemm
=
[
&
](
auto
...
Xs
)
{
const
auto
run_blockwise_batch_gemm
=
[
&
](
auto
...
Xs
)
{
#if
0
#if
1
return
blockwise_batch_gemm
.
Run
(
Xs
...);
return
blockwise_batch_gemm
.
Run
(
Xs
...);
#elif 0
#elif 0
return
blockwise_batch_gemm
.
Run_asm
(
Xs
...);
return
blockwise_batch_gemm
.
Run_asm
(
Xs
...);
...
@@ -249,7 +249,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -249,7 +249,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
{
{
for(index_t x = 0; x < X; ++x)
for(index_t x = 0; x < X; ++x)
{
{
#if 1
blockwise_in_copy_reorder.Run(p_in_global_block_offset +
blockwise_in_copy_reorder.Run(p_in_global_block_offset +
in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x),
in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x),
p_in_block);
p_in_block);
...
@@ -257,23 +256,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -257,23 +256,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
blockwise_wei_copy.Run(p_wei_global_block_offset +
blockwise_wei_copy.Run(p_wei_global_block_offset +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_wei_block);
p_wei_block);
#else
Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy_reorder.RunLoadRegisterClipboard(
p_in_global_block_offset + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x),
p_in_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(
p_wei_global_block_offset + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_wei_clipboard);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block);
blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_clipboard, p_in_block);
#endif
__syncthreads();
__syncthreads();
...
@@ -304,24 +286,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -304,24 +286,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
p_wei_global_block_offset
+=
p_wei_global_block_offset
+=
CPerBlock
*
wei_c_y_x_k_global_desc
.
GetStride
(
I0
))
CPerBlock
*
wei_c_y_x_k_global_desc
.
GetStride
(
I0
))
{
{
#if 0
blockwise_in_copy_reorder
.
Run
(
p_in_global_block_offset
,
p_in_block
);
blockwise_in_copy_reorder.Run(p_in_global_block_offset,
p_in_block);
blockwise_wei_copy.Run(p_wei_global_block_offset,
p_wei_block);
#else
Float
p_in_clipboard
[
blockwise_in_copy_reorder
.
GetRegisterClipboardSize
()];
Float
p_wei_clipboard
[
blockwise_wei_copy
.
GetRegisterClipboardSize
()];
blockwise_in_copy_reorder
.
RunLoadRegisterClipboard
(
p_in_global_block_offset
,
p_in_clipboard
);
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
p_wei_clipboard
);
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_clipboard
,
p_wei_block
);
blockwise_wei_copy
.
Run
(
p_wei_global_block_offset
,
p_wei_block
);
blockwise_in_copy_reorder
.
RunStoreRegisterClipboard
(
p_in_clipboard
,
p_in_block
);
#endif
__syncthreads
();
__syncthreads
();
...
@@ -342,13 +309,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -342,13 +309,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
const
index_t
wo_thread_data_begin
=
c_thread_mtx_begin
.
col
/
NPerBlock
;
const
index_t
wo_thread_data_begin
=
c_thread_mtx_begin
.
col
/
NPerBlock
;
const
index_t
n_thread_data_begin
=
c_thread_mtx_begin
.
col
%
NPerBlock
;
const
index_t
n_thread_data_begin
=
c_thread_mtx_begin
.
col
%
NPerBlock
;
static_if
<
GemmNPerThreadSubC
<=
NPerBlock
>
{}([
&
](
auto
f_dummy
)
{
// f_dummy do nothing but
static_if
<
GemmNPerThreadSubC
<=
NPerBlock
>
{}([
&
](
auto
fwd
)
{
// perfect forwarding.
// fwd do nothing but perfect forwarding.
// Using this trick to
// Using this trick to make this lambda a generic lambda, so it won't be compiled until
// make this lambda a generic lambda, so it won't be compiled until
// begin instantiated here
// instantiated
static_assert
(
static_assert
(
(
f
_dummy
(
GemmNPerThreadSubC
)
<=
NPerBlock
&&
NPerBlock
%
GemmNPerThreadSubC
==
0
),
(
f
wd
(
GemmNPerThreadSubC
)
<=
NPerBlock
&&
NPerBlock
%
GemmNPerThreadSubC
==
0
),
"wrong!"
);
"wrong!"
);
// output is a 10d tensor
// output is a 10d tensor
...
@@ -356,38 +322,33 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -356,38 +322,33 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
constexpr
index_t
N1
=
NPerBlock
/
N2
;
constexpr
index_t
N1
=
NPerBlock
/
N2
;
constexpr
index_t
W2
=
constexpr
index_t
W2
=
(
GemmNLevel0Cluster
*
GemmNLevel1Cluster
)
/
f
_dummy
(
NPerBlock
/
GemmNPerThreadSubC
);
(
GemmNLevel0Cluster
*
GemmNLevel1Cluster
)
/
f
wd
(
NPerBlock
/
GemmNPerThreadSubC
);
constexpr
index_t
W1
=
WoPerBlock
/
W2
;
constexpr
index_t
W1
=
WoPerBlock
/
W2
;
constexpr
index_t
K2
=
GemmMPerThreadSubC
;
constexpr
index_t
K2
=
GemmMPerThreadSubC
;
constexpr
index_t
K1
=
KPerBlock
/
KPerThread
;
constexpr
index_t
K1
=
KPerBlock
/
KPerThread
;
constexpr
auto
out_10d_global_desc
=
constexpr
auto
out_10d_global_desc
=
fwd
(
out_n_k_h_w_global_desc
)
make_ConstantTensorDescriptor
(
Sequence
<
N
/
f_dummy
(
N1
*
N2
),
.
Fold
(
I3
,
Number
<
W1
>
{},
Number
<
W2
>
{})
N1
,
.
Fold
(
I1
,
Number
<
K1
>
{},
Number
<
K2
>
{})
N2
,
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{});
K
/
(
K1
*
K2
),
K1
,
constexpr
auto
out_10d_thread_desc
=
fwd
(
out_k_h_w_n_thread_desc
)
K2
,
.
Fold
(
I3
,
Number
<
1
>
{},
Number
<
N2
>
{})
Ho
,
.
Fold
(
I2
,
Number
<
W1
>
{},
Number
<
1
>
{})
Wo
/
(
W1
*
W2
),
.
Fold
(
I0
,
Number
<
1
>
{},
Number
<
K2
>
{});
W1
,
W2
>
{});
constexpr
auto
out_10d_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
KPerThread
/
K2
,
1
,
K2
,
HoPerThread
,
1
,
W1
,
1
,
1
,
1
,
N2
>
{});
#if 0
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"
out_k_h_w_n_thread_desc");
"a:
out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "
a:
out_10d_thread_desc");
print_ConstantTensorDescriptor(out_k_h_w_
n_
global_desc,
print_ConstantTensorDescriptor(out_
n_
k_h_w_global_desc,
"
out_k_h_w_
n_
global_desc");
"a:
out_
n_
k_h_w_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "
a:
out_10d_global_desc");
}
}
#endif
#endif
constexpr
auto
map_out_global2thread
=
Sequence
<
7
,
8
,
9
,
0
,
1
,
2
,
3
,
4
,
5
,
6
>
{};
constexpr
auto
map_out_global2thread
=
Sequence
<
7
,
8
,
9
,
0
,
1
,
2
,
3
,
4
,
5
,
6
>
{};
...
@@ -405,8 +366,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -405,8 +366,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
out_10d_thread_desc
.
GetLengths
(),
out_10d_thread_desc
.
GetLengths
(),
map_out_global2thread
);
map_out_global2thread
);
// Number<OutThreadCopyDataPerWrite_W>{});
// Number<OutThreadCopyDataPerWrite_W>{});
}).
else_
([
&
](
auto
f
_dummy
)
{
}).
else_
([
&
](
auto
f
wd
)
{
static_assert
(
f
_dummy
(
GemmNPerThreadSubC
)
>=
NPerBlock
&&
NPerThread
==
NPerBlock
&&
static_assert
(
f
wd
(
GemmNPerThreadSubC
)
>=
NPerBlock
&&
NPerThread
==
NPerBlock
&&
GemmNPerThreadSubC
%
NPerThread
==
0
,
GemmNPerThreadSubC
%
NPerThread
==
0
,
"wrong!"
);
"wrong!"
);
...
@@ -415,34 +376,40 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -415,34 +376,40 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
constexpr
index_t
W3
=
GemmNPerThreadSubC
/
NPerBlock
;
constexpr
index_t
W3
=
GemmNPerThreadSubC
/
NPerBlock
;
constexpr
index_t
W2
=
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
constexpr
index_t
W2
=
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
constexpr
index_t
W1
=
WoPerBlock
/
f
_dummy
(
W2
*
W3
);
constexpr
index_t
W1
=
WoPerBlock
/
f
wd
(
W2
*
W3
);
constexpr
index_t
K2
=
GemmMPerThreadSubC
;
constexpr
index_t
K2
=
GemmMPerThreadSubC
;
constexpr
index_t
K1
=
KPerBlock
/
KPerThread
;
constexpr
index_t
K1
=
KPerBlock
/
KPerThread
;
constexpr
auto
out_10d_global_desc
=
make_ConstantTensorDescriptor
(
constexpr
auto
out_10d_global_desc
=
Sequence
<
N
/
N1
,
N1
,
K
/
(
K1
*
K2
),
K1
,
K2
,
Ho
,
Wo
/
(
W1
*
W2
*
W3
),
W1
,
W2
,
W3
>
{});
fwd
(
out_n_k_h_w_global_desc
)
.
Fold
(
I3
,
Number
<
W1
>
{},
Number
<
W2
>
{},
Number
<
W3
>
{})
.
Fold
(
I1
,
Number
<
K1
>
{},
Number
<
K2
>
{})
.
Fold
(
I0
,
Number
<
N1
>
{});
constexpr
auto
out_10d_thread_desc
=
make_ConstantTensorDescriptor
(
constexpr
auto
out_10d_thread_desc
=
Sequence
<
KPerThread
/
K2
,
1
,
K2
,
HoPerThread
,
1
,
W1
,
1
,
W3
,
1
,
N1
>
{});
fwd
(
out_k_h_w_n_thread_desc
)
.
Fold
(
I3
,
Number
<
N1
>
{})
.
Fold
(
I2
,
Number
<
W1
>
{},
Number
<
1
>
{},
Number
<
W3
>
{})
.
Fold
(
I0
,
Number
<
1
>
{},
Number
<
K2
>
{});
#if 0
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"out_k_h_w_n_thread_desc");
"b: out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
"out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
print_ConstantTensorDescriptor(out_n_k_h_w_global_desc,
"b: out_n_k_h_w_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
}
#endif
#endif
constexpr
auto
map_out_global2thread
=
Sequence
<
8
,
9
,
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
{};
constexpr
auto
map_out_global2thread
=
Sequence
<
8
,
9
,
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
{};
threadwise_tensor_slice_copy_reorder_given_dst2src_v2
(
#if 0
threadwise_tensor_slice_copy_reorder_given_dst2src_v3(
out_10d_thread_desc,
out_10d_thread_desc,
p_out_thread,
p_out_thread,
out_10d_global_desc,
out_10d_global_desc,
...
@@ -453,8 +420,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
...
@@ -453,8 +420,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
ho_block_data_begin + ho_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
wo_block_data_begin + wo_thread_data_begin),
out_10d_thread_desc.GetLengths(),
out_10d_thread_desc.GetLengths(),
map_out_global2thread
);
map_out_global2thread,
// Number<OutThreadCopyDataPerWrite_W>{});
Number<OutThreadCopyDataPerWrite_W>{});
#else
threadwise_tensor_slice_copy_generic
(
out_10d_thread_desc
.
ReorderGivenNew2Old
(
map_out_global2thread
),
p_out_thread
,
make_zero_array
<
index_t
,
10
>
(),
out_10d_global_desc
,
p_out_global
+
out_n_k_h_w_global_desc
.
GetOffsetFromMultiIndex
(
n_block_data_begin
+
n_thread_data_begin
,
k_block_data_begin
+
k_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
),
make_zero_array
<
index_t
,
10
>
(),
out_10d_thread_desc
.
GetLengths
().
ReorderGivenNew2Old
(
map_out_global2thread
),
arithmetic_sequence_gen
<
0
,
10
,
1
>::
SeqType
{});
#endif
});
});
}
}
};
};
src/include/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hip.hpp
View file @
8fcf3f1e
...
@@ -151,6 +151,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
...
@@ -151,6 +151,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
// slice a tensor, and copy it into another tensor
// slice a tensor, and copy it into another tensor
// this copy operator already have blockwise offset built-in
// this copy operator already have blockwise offset built-in
const
auto
blockwise_wei_copy
=
const
auto
blockwise_wei_copy
=
#if 0
BlockwiseTensorSliceCopy_generic_v1<BlockSize,
BlockwiseTensorSliceCopy_generic_v1<BlockSize,
Float,
Float,
decltype(wei_c_k_global_desc),
decltype(wei_c_k_global_desc),
...
@@ -164,15 +165,26 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
...
@@ -164,15 +165,26 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
WeiBlockCopyDataPerAccess_K,
WeiBlockCopyDataPerAccess_K,
WeiBlockCopyDataPerAccess_K>(
WeiBlockCopyDataPerAccess_K>(
{0, k_block_data_on_global}, {0, 0});
{0, k_block_data_on_global}, {0, 0});
#else
// GEMM definition
Blockwise2dTensorCopy3
<
BlockSize
,
// c_mtx += transpose(a_mtx) * b_mtx
Float
,
// a_mtx[CPerBlock, KPerBlock] is in LDS
decltype
(
wei_c_k_global_desc
),
// b_mtx[CPerBlocl, N1 * BPerBlock * N2] is in LDS
decltype
(
wei_c_k_block_desc
),
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
decltype
(
wei_c_k_block_desc
.
GetLengths
()),
// register
WeiBlockCopyDataPerAccess_K
>
({
0
,
k_block_data_on_global
},
constexpr
auto
a_c_k_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
{
0
,
0
});
Number
<
CPerBlock
>
{},
Number
<
KPerBlock
>
{},
Number
<
wei_c_k_block_desc
.
GetStride
(
I0
)
>
{});
#endif
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[CPerBlock, KPerBlock] is in LDS
// b_mtx[CPerBlocl, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr
auto
a_c_k_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
Number
<
KPerBlock
>
{},
Number
<
wei_c_k_block_desc
.
GetStride
(
I0
)
>
{});
constexpr
auto
b_c_n1bn2_block_mtx_desc
=
constexpr
auto
b_c_n1bn2_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment