gaoqiong / composable_kernel

Commit b7c1259f
authored Sep 06, 2021 by ltqin

    init ok

parent 0acd3ebe

Showing 3 changed files with 80 additions and 70 deletions (+80 -70):

  composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp                                        +45 -39
  host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nkhw.hpp   +33 -29
  host/driver_offline/src/conv_wrw_driver_offline.cpp                                                             +2 -2
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
...
@@ -366,18 +366,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
         const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
         const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);

         const auto b_grid_size = CalculateGridSize(M, N);

-        const auto nBatch           = get_block_1d_id() / b_grid_size;
-        const auto blockid_in_batch = get_block_1d_id() % b_grid_size;
+        const auto k_batch_id        = get_block_1d_id() / b_grid_size;
+        const auto block_id_in_batch = get_block_1d_id() % b_grid_size;

         if(get_block_1d_id() == 2000)
-            printf("grid size: %d, Batch: %d block_id: %d k0: %d\n",
+            printf("grid size: %d, k0: %d, blockid: %d, threadid %d, Batch: %d block_id: %d\n",
                    b_grid_size,
-                   nBatch,
-                   blockid_in_batch,
-                   K0);
+                   K0,
+                   get_block_1d_id(),
+                   get_thread_local_1d_id(),
+                   k_batch_id,
+                   block_id_in_batch);

         // divide block work by [M, N]
         const auto block_work_idx =
-            c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(blockid_in_batch));
+            c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(block_id_in_batch));

         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid = ...
...
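The rename from nBatch/blockid_in_batch to k_batch_id/block_id_in_batch makes the split-K scheme explicit: the launch grid is KBatch times larger than a plain [M, N] tile grid, and each workgroup recovers which K-batch it belongs to and which output tile it owns from its flat 1-D block id. A minimal host-side sketch of that decomposition (decompose_block_id, b_grid_size, and k_batch are illustrative stand-ins for the device helpers, not part of the kernel):

```cpp
#include <cassert>
#include <cstdio>

// Illustrative stand-in for the device-side index math above:
//   k_batch_id        = flat_id / b_grid_size
//   block_id_in_batch = flat_id % b_grid_size
struct BlockWork
{
    int k_batch_id;        // which slice of the K0 reduction this block handles
    int block_id_in_batch; // which [M, N] output tile inside that slice
};

BlockWork decompose_block_id(int flat_id, int b_grid_size)
{
    return BlockWork{flat_id / b_grid_size, flat_id % b_grid_size};
}

int main()
{
    const int b_grid_size = 64; // e.g. (M / MPerBlock) * (N / NPerBlock)
    const int k_batch     = 2;  // KBatch: the grid is enlarged by this factor

    for(int id = 0; id < b_grid_size * k_batch; ++id)
    {
        const BlockWork w = decompose_block_id(id, b_grid_size);
        assert(w.k_batch_id < k_batch && w.block_id_in_batch < b_grid_size);
    }
    std::printf("all %d block ids decompose consistently\n", b_grid_size * k_batch);
    return 0;
}
```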
@@ -391,65 +393,69 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
+        constexpr auto a_b_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
+            make_tuple(1, Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+
         constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
             make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
+        constexpr auto b_b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
+            make_tuple(1, Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+
         constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
             make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);

         // A matrix blockwise copy
         auto a_blockwise_copy =
             BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
-                                            Sequence<KPerBlock, MPerBlock, K1>,
+                                            Sequence<1, KPerBlock, MPerBlock, K1>,
                                             ABlockTransferThreadSliceLengths_K0_M_K1,
                                             ABlockTransferThreadClusterLengths_K0_M_K1,
                                             ABlockTransferThreadClusterArrangeOrder,
                                             FloatAB,
                                             FloatAB,
-                                            decltype(a_k0_m_k1_grid_desc),
-                                            decltype(a_k0_m_k1_block_desc),
+                                            decltype(a_b_k0_m_k1_grid_desc),
+                                            decltype(a_b_k0_m_k1_block_desc),
                                             ABlockTransferSrcAccessOrder,
-                                            Sequence<1, 0, 2>,
+                                            Sequence<0, 2, 1, 3>,
                                             ABlockTransferSrcVectorDim,
-                                            2,
+                                            3,
                                             ABlockTransferSrcScalarPerVector,
                                             ABlockTransferDstScalarPerVector_K1,
                                             1,
                                             1,
                                             AThreadTransferSrcResetCoordinateAfterRun,
-                                            true>(
-                a_k0_m_k1_grid_desc,
-                make_multi_index(0, m_block_data_idx_on_grid, 0),
-                a_k0_m_k1_block_desc,
-                make_multi_index(0, 0, 0));
+                                            true>(
+                a_b_k0_m_k1_grid_desc,
+                make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
+                a_b_k0_m_k1_block_desc,
+                make_multi_index(0, 0, 0, 0));

         // B matrix blockwise copy
         auto b_blockwise_copy =
             BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
-                                            Sequence<KPerBlock, NPerBlock, K1>,
+                                            Sequence<1, KPerBlock, NPerBlock, K1>,
                                             BBlockTransferThreadSliceLengths_K0_N_K1,
                                             BBlockTransferThreadClusterLengths_K0_N_K1,
                                             BBlockTransferThreadClusterArrangeOrder,
                                             FloatAB,
                                             FloatAB,
-                                            decltype(b_k0_n_k1_grid_desc),
-                                            decltype(b_k0_n_k1_block_desc),
+                                            decltype(b_b_k0_n_k1_grid_desc),
+                                            decltype(b_b_k0_n_k1_block_desc),
                                             BBlockTransferSrcAccessOrder,
-                                            Sequence<1, 0, 2>,
+                                            Sequence<0, 2, 1, 3>,
                                             BBlockTransferSrcVectorDim,
-                                            2,
+                                            3,
                                             BBlockTransferSrcScalarPerVector,
                                             BBlockTransferDstScalarPerVector_K1,
                                             1,
                                             1,
                                             BThreadTransferSrcResetCoordinateAfterRun,
-                                            true>(
-                b_k0_n_k1_grid_desc,
-                make_multi_index(0, n_block_data_idx_on_grid, 0),
-                b_k0_n_k1_block_desc,
-                make_multi_index(0, 0, 0));
+                                            true>(
+                b_b_k0_n_k1_grid_desc,
+                make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
+                b_b_k0_n_k1_block_desc,
+                make_multi_index(0, 0, 0, 0));

         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
...
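Every descriptor and transfer in this hunk gains a leading batch ("B") axis: the copy slice becomes [1, KPerBlock, M/NPerBlock, K1], the dimension-access order becomes Sequence<0, 2, 1, 3>, the vectorized dimension shifts from 2 to 3, and the source window now starts at (k_batch_id, 0, m/n_block_data_idx_on_grid, 0), so each workgroup reads only its own K-batch. A rough sketch of what that 4-D origin selects, with hypothetical extents in place of the real grid-descriptor lengths:

```cpp
#include <cstdio>

// Hypothetical extents for a [B, K0, M, K1] view of the A matrix;
// in the kernel these come from a_b_k0_m_k1_grid_desc.
constexpr int B = 2, K0 = 8, M = 128, K1 = 8;

// Row-major linear offset, standing in for the descriptor's coordinate math.
constexpr long offset(int b, int k0, int m, int k1)
{
    return ((static_cast<long>(b) * K0 + k0) * M + m) * K1 + k1;
}

int main()
{
    const int k_batch_id = 1, m_block_data_idx_on_grid = 64;

    // Analogue of make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0):
    const long origin = offset(k_batch_id, 0, m_block_data_idx_on_grid, 0);

    std::printf("batch %d, m-tile at %d -> source window starts at element %ld\n",
                k_batch_id, m_block_data_idx_on_grid, origin);
    return 0;
}
```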
@@ -490,8 +496,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
         FloatAB* p_a_block = p_shared_block;
         FloatAB* p_b_block = p_shared_block + a_block_space_size;

-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);

         // hack to control index calculation when iterating over A and B matrix for threadwise copy
         constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
...
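With the batch axis in front, the per-iteration copy step advances along dimension 1 (K0) rather than dimension 0, which is why make_multi_index(KPerBlock, 0, 0) becomes make_multi_index(0, KPerBlock, 0, 0). A toy analogue of how MoveSrcSliceWindow walks the window origin (all names and sizes here are illustrative):

```cpp
#include <array>
#include <cstdio>

int main()
{
    constexpr int KPerBlock = 4, K0 = 16;

    // 4-D window origin over [B, K0, M, K1]; the batch coordinate never moves.
    std::array<int, 4> origin{1 /*k_batch_id*/, 0, 64 /*m offset*/, 0};
    constexpr std::array<int, 4> step{0, KPerBlock, 0, 0}; // a_block_slice_copy_step

    for(int k = 0; k < K0; k += KPerBlock)
    {
        std::printf("copy window origin: (%d, %d, %d, %d)\n",
                    origin[0], origin[1], origin[2], origin[3]);
        for(int d = 0; d < 4; ++d)
            origin[d] += step[d]; // MoveSrcSliceWindow analogue
    }
    return 0;
}
```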
@@ -509,11 +515,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
         // preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
-            b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+            a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+            b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);

-            a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
-            b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
+            a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
+            b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
         }

         // main body
...
@@ -521,25 +527,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
         do
         {
-            a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
+            a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc,
                                                 a_block_slice_copy_step,
                                                 a_k0_m_k1_grid_move_slice_window_step_hack);
-            b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
+            b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc,
                                                 b_block_slice_copy_step,
                                                 b_k0_n_k1_grid_move_slice_window_step_hack);

-            a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+            a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);

             block_sync_lds();

-            b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+            b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);

             blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

             block_sync_lds();

-            a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
-            b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
+            a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
+            b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);

             k_block_data_begin += KPerBlock;
         } while(k_block_data_begin < (K0 - KPerBlock));
...
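Apart from the batched descriptor names, the loop keeps the usual software-pipelined shape: move the source window, issue the next global read into registers, run the block GEMM on the tile already resident in LDS, then commit the freshly read tile to LDS behind a barrier. A stripped-down CPU analogue of just the control flow (the Tile type and the three helpers are placeholders, not CK APIs):

```cpp
#include <cstdio>

struct Tile { int k0_begin = 0; };                               // placeholder tile

void run_read(Tile& reg, int k)            { reg.k0_begin = k; } // global -> regs
void run_write(Tile& lds, const Tile& reg) { lds = reg; }        // regs -> "LDS"
void block_gemm(const Tile& lds) { std::printf("gemm on k0 = %d\n", lds.k0_begin); }

int main()
{
    constexpr int K0 = 16, KPerBlock = 4;
    Tile reg, lds;

    // preload data into LDS
    run_read(reg, 0);
    run_write(lds, reg);

    int k_block_data_begin = 0;
    do
    {
        run_read(reg, k_block_data_begin + KPerBlock); // prefetch next K chunk
        block_gemm(lds);                               // compute on current chunk
        run_write(lds, reg);                           // stage the prefetched chunk
        k_block_data_begin += KPerBlock;
    } while(k_block_data_begin < (K0 - KPerBlock));

    block_gemm(lds); // tail iteration on the last staged chunk
    return 0;
}
```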
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nkhw.hpp
...
@@ -62,21 +62,21 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
     constexpr index_t MRepeat = 2;
     constexpr index_t NRepeat = 2;

-    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 2, 8>;
-    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
+    using GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1   = Sequence<1, 1, 2, 8>;
+    using GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>;

     // using vector load 4, so config's wo*ho must be a multiple of 4
     constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
     constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

-    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 8>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
+    using GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1   = Sequence<1, 1, 2, 8>;
+    using GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>;

     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN  = 1;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

     constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;

-    constexpr index_t KBatch = 96;
+    constexpr index_t KBatch = 2;
 #elif 1
 // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
     constexpr index_t BlockSize = 256;
...
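The transfer layouts pick up the same leading GemmB axis as the kernel, and KBatch drops from 96 to 2. KBatch is the split-K factor: the K0 reduction is divided into KBatch contiguous ranges, one per K-batch of workgroups, and the grid grows by the same factor. A back-of-the-envelope sketch of the split (the 96 here is just an example length; in this driver the reduction length comes from the convolution's N*Ho*Wo):

```cpp
#include <cassert>
#include <cstdio>

int main()
{
    // Example reduction length and a split-K factor like the config above.
    const int K0_total = 96, KBatch = 2;
    assert(K0_total % KBatch == 0); // each batch must get an equal share
    const int K0_per_batch = K0_total / KBatch;

    // Each k-batch owns the half-open range [b * K0_per_batch, (b + 1) * K0_per_batch).
    for(int b = 0; b < KBatch; ++b)
        std::printf("k_batch %d reduces k0 in [%d, %d)\n",
                    b, b * K0_per_batch, (b + 1) * K0_per_batch);
    return 0;
}
```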
@@ -123,20 +123,24 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
     // HACK: hacks that control index calculation when iterating over A, B, C matrix
     constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0>{},   // 0+: GemmK0
-                              Sequence<0, 0, 0, 0, 0>{},   // 1+: GemmM
-                              Sequence<0, 0, 1, 0, 0>{}),  // 2+: GemmK1
-                   make_tuple(Sequence<0, 0, 2, 0, 0>{},   // 0-: GemmK0
-                              Sequence<0, 0, 0, 0, 0>{},   // 1-: GemmM
-                              Sequence<0, 0, 2, 0, 0>{})); // 2-: GemmK1
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: GemmB
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: GemmK0
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: GemmM
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}),  // 3+: GemmK1
+                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: GemmB
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: GemmK0
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: GemmM
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1

     constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
-        make_tuple(
-            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},   // 0+: GemmK0
-                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},   // 1+: GemmN
-                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),  // 2+: GemmK1
-            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},   // 0-: GemmK0
-                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},   // 1-: GemmN
-                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1
+        make_tuple(
+            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: GemmB
+                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: GemmK0
+                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: GemmN
+                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 3+: GemmK1
+            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: GemmB
+                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: GemmK0
+                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: GemmN
+                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1

     constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
         make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
...
@@ -157,10 +161,10 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
                               Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

     constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
-        Sequence<0, 0, 1, 0, 0>{};
+        Sequence<0, 0, 0, 0, 0, 0, 0, 0>{};

     constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
-        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};
+        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};

     for(index_t i = 0; i < 5; ++i)
     {
...
@@ -181,19 +185,19 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
         GemmK1,
         MRepeat,
         NRepeat,
-        GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
-        GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
-        Sequence<1, 0, 2>,
-        Sequence<1, 0, 2>,
-        2,
+        GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1,
+        GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1,
+        Sequence<0, 2, 1, 3>,
+        Sequence<0, 2, 1, 3>,
+        3,
         GemmABlockTransferSrcScalarPerVector_GemmK1,
         GemmABlockTransferDstScalarPerVector_GemmK1,
         false, // don't move back src coordinate after threadwise copy
-        GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
-        GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
-        Sequence<1, 0, 2>,
-        Sequence<1, 0, 2>,
-        2,
+        GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1,
+        GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1,
+        Sequence<0, 2, 1, 3>,
+        Sequence<0, 2, 1, 3>,
+        3,
         GemmBBlockTransferSrcScalarPerVector_GemmN,
         GemmBBlockTransferDstScalarPerVector_GemmK1,
         false, // don't move back src coordinate after threadwise copy
...
host/driver_offline/src/conv_wrw_driver_offline.cpp
...
@@ -233,6 +233,8 @@ int main(int argc, char* argv[])
                                   in_right_pads_dev);
     };

+    // set zero to wei_device
+    wei_device.GenerateTensorValue(GeneratorTensor_0{}, num_thread);
+
 #if USE_CONV_WRW_V4R4R2_XDL_NCHW
     if(algo == ConvBackwardWeightAlgo::V4R4R2XDLNCHW)
     {
...
@@ -267,8 +269,6 @@ int main(int argc, char* argv[])
     {
         throw std::runtime_error("wrong! layout");
     }

-    // set zero to wei_device
-    wei_device.GenerateTensorValue(GeneratorTensor_0{}, num_thread);
-
     const auto tmp = f_make_for_device_nchw();
...
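The driver change simply moves the wei_device zeroing from inside the NCHW branch to before the algorithm dispatch, so the weight buffer is cleared no matter which path runs. That ordering matters for split-K: each K-batch adds its partial result into the same output, so the destination must start from zero. A host-side sketch of the idea, with plain arrays standing in for the Tensor/GeneratorTensor_0 machinery:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    const int n = 8, KBatch = 2;
    std::vector<float> wei(n, -1.0f); // pretend stale data from an earlier run

    // Analogue of wei_device.GenerateTensorValue(GeneratorTensor_0{}, ...):
    // zero the destination *before* any partial results are accumulated.
    std::fill(wei.begin(), wei.end(), 0.0f);

    // Each k-batch accumulates its partial GEMM result into the same buffer
    // (on the GPU, concurrent workgroups would typically accumulate atomically).
    for(int b = 0; b < KBatch; ++b)
        for(int i = 0; i < n; ++i)
            wei[i] += 1.0f; // stand-in for one batch's partial contribution

    std::printf("wei[0] = %.1f after summing %d partials on a zeroed buffer\n",
                wei[0], KBatch);
    return 0;
}
```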