gaoqiong / composable_kernel · Commits

Commit 4ef5865b, authored Mar 17, 2021 by root

    clean code

Parent: caa91db0
Showing 2 changed files with 44 additions and 62 deletions:

    composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp          +7  -11
    composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp   +37 -51
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp

@@ -19,8 +19,6 @@ template <index_t BlockSize,
           index_t HPerThread,
           index_t WPerThread,
           index_t CYXPerThreadLoop,
-          index_t HThreadCluster,
-          index_t WThreadCluster,
           index_t ThreadGemmADataPerRead_K,
           index_t ThreadGemmBDataPerRead_W>
 struct BlockwiseGemm_km_kn_m0m1n0n1_v3
@@ -46,11 +44,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
     constexpr auto I2 = Number<2>{};
     constexpr auto I3 = Number<3>{};

-    // constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
-    // MLevel1ThreadCluster * NLevel1ThreadCluster;
-
-    static_assert(BlockSize == HThreadCluster * WThreadCluster, "wrong! wrong blocksize\n");
-
     static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
                   "wrong! K dimension not consistent\n");
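The retained assert enforces that the A and B block matrices agree on the length of the shared K dimension (index I0), which is what makes the blockwise GEMM contraction well-formed. A stripped-down illustration of the same compile-time shape check, using an invented Desc stand-in for the real descriptor types:

#include <cstdint>

using index_t = int32_t;

// Hypothetical minimal stand-in for a compile-time tensor descriptor;
// the real BlockMatrixA/BlockMatrixB types carry much more machinery.
template <index_t... Lengths>
struct Desc
{
    static constexpr index_t lengths[] = {Lengths...};
    static constexpr index_t GetLength(index_t i) { return lengths[i]; }
};

// A is K x M, B is K x N: dimension 0 (K) must match for a GEMM.
using BlockMatrixA = Desc<8, 64>;
using BlockMatrixB = Desc<8, 128>;

static_assert(BlockMatrixA::GetLength(0) == BlockMatrixB::GetLength(0),
              "wrong! K dimension not consistent");

int main() { return 0; }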
@@ -59,10 +52,13 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
     constexpr index_t H = BlockMatrixB{}.GetLength(I2);
     constexpr index_t W = BlockMatrixB{}.GetLength(I3);

-    static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0,
-                  "wrong! Cannot evenly divide work among\n");
+    constexpr auto HThreadCluster = H / HPerThread;
+    constexpr auto WThreadCluster = W / WPerThread;
+
+    static_assert(BlockSize == HThreadCluster * WThreadCluster, "wrong! wrong blocksize\n");
+
+    static_assert(K % (KPerThread) == 0 &&
+                      (N * H * W) % (HPerThread * WPerThread * HThreadCluster * WThreadCluster) == 0,
+                  "wrong! Cannot evenly divide work among\n");

     auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
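The net effect of these two hunks is that HThreadCluster and WThreadCluster stop being template parameters and are instead derived inside the struct from the block tile as H / HPerThread and W / WPerThread, with static_asserts guarding the division. A minimal standalone sketch of that derivation, with invented tile sizes and the Number<>/descriptor machinery elided:

#include <cstdint>

using index_t = int32_t;

// Hypothetical tile sizes, for illustration only.
constexpr index_t BlockSize  = 64; // threads per block
constexpr index_t H          = 32; // block-tile height
constexpr index_t W          = 16; // block-tile width
constexpr index_t HPerThread = 4;  // per-thread tile height
constexpr index_t WPerThread = 2;  // per-thread tile width

// Derive the thread-cluster shape from the tile shapes, as the commit
// now does, instead of passing it in as two template parameters.
constexpr index_t HThreadCluster = H / HPerThread; // 8
constexpr index_t WThreadCluster = W / WPerThread; // 8

static_assert(H % HPerThread == 0 && W % WPerThread == 0,
              "block tile must divide evenly into per-thread tiles");
static_assert(BlockSize == HThreadCluster * WThreadCluster,
              "thread cluster must exactly fill the block");

int main() { return 0; }

Deriving the cluster shape removes a redundant degree of freedom: callers can no longer pass a cluster that disagrees with the tile sizes, so the old blocksize assert now checks a computed quantity rather than a user-supplied one.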
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp

@@ -117,23 +117,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         const index_t h_block_work_id = hw_block_work_id / w_block_work_num;
         const index_t w_block_work_id = hw_block_work_id - h_block_work_id * w_block_work_num;

-        constexpr auto h_num_threads = HPerBlock / HPerThread;
-        constexpr auto w_num_threads = WPerBlock / WPerThread;
-
         static_assert(KPerBlock == KPerThread, "");

-        const auto k_thread_id = 0;
-        const auto h_thread_id = get_thread_local_1d_id() / w_num_threads;
-        const auto w_thread_id = get_thread_local_1d_id() % w_num_threads;
-
-        const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
-        const index_t h_block_data_on_global = h_block_work_id * HPerBlock;
-        const index_t w_block_data_on_global = w_block_work_id * WPerBlock;
-
-        const index_t k_thread_data_on_global = k_block_data_on_global + k_thread_id * KPerThread;
-        const index_t h_thread_data_on_global = h_block_data_on_global + h_thread_id * HPerThread;
-        const index_t w_thread_data_on_global = w_block_data_on_global + w_thread_id * WPerThread;
-
         // lds max alignment
         constexpr auto max_lds_align =
             math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, Number<KPerThread>{});
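The kept context at the top of this hunk unpacks the block's 2-D work id from a linear id with a divide followed by a multiply-subtract (an explicit remainder). A small self-contained sketch of that decomposition, with an invented w_block_work_num value:

#include <cassert>
#include <cstdint>

using index_t = int32_t;

struct BlockWork { index_t h, w; };

// Unpack a linear block id into (h, w) work coordinates, mirroring the
// h_block_work_id / w_block_work_id computation in the hunk above.
BlockWork unpack_block_work(index_t hw_block_work_id, index_t w_block_work_num)
{
    const index_t h = hw_block_work_id / w_block_work_num;
    // Multiply-subtract instead of %: same value, and it reuses the
    // quotient already computed on the line above.
    const index_t w = hw_block_work_id - h * w_block_work_num;
    return {h, w};
}

int main()
{
    // With 3 block columns, linear id 7 maps to row 2, column 1.
    const BlockWork bw = unpack_block_work(7, 3);
    assert(bw.h == 2 && bw.w == 1);
    return 0;
}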
@@ -149,6 +134,39 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
             make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
                 Number<CYXPerBlock>{}, Number<1>{}, Number<HPerBlock>{}, Number<WPerBlock>{}));

+        // c_thread_mtx definition: this is a mess
+        // TODO:: more elegent way of defining c_thread_mtx
+        constexpr auto c_k_n_h_w_thread_desc =
+            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+                Number<KPerThread>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
+
+        const auto blockwise_gemm =
+            BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
+                                            decltype(a_cyx_k_block_desc),
+                                            decltype(b_cyx_n_h_w_block_desc),
+                                            decltype(c_k_n_h_w_thread_desc),
+                                            KPerThread,   // KPerThreadSubC
+                                            HPerThread,   // HPerThreadSubC
+                                            WPerThread,   // WPerThreadSubC
+                                            CYXPerThread, // CYXPerThreadLoop
+                                            1,            // ThreadGemmADataPerRead_K
+                                            1             // ThreadGemmBDataPerRead_W
+                                            >{};
+
+        auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+
+        const auto k_thread_id = c_thread_mtx_index.k;
+        const auto h_thread_id = c_thread_mtx_index.h;
+        const auto w_thread_id = c_thread_mtx_index.w;
+
+        const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
+        const index_t h_block_data_on_global = h_block_work_id * HPerBlock;
+        const index_t w_block_data_on_global = w_block_work_id * WPerBlock;
+
+        const index_t k_thread_data_on_global = k_block_data_on_global + k_thread_id * KPerThread;
+        const index_t h_thread_data_on_global = h_block_data_on_global + h_thread_id * HPerThread;
+        const index_t w_thread_data_on_global = w_block_data_on_global + w_thread_id * WPerThread;
+
         // A matrix blockwise copy
         auto a_blockwise_copy =
             BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
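The relocated code composes each thread's global starting coordinate as block offset plus thread offset along every dimension; the thread ids now come from blockwise_gemm.GetBeginOfThreadMatrixC instead of being recomputed by hand from the linear thread id. A compact sketch of the two-level tiling arithmetic behind the *_thread_data_on_global lines, with invented tile sizes:

#include <cstdint>

using index_t = int32_t;

// Two-level tiling: a dimension is split into blocks, and each block into
// per-thread tiles. A thread's global start is the sum of both offsets,
// matching k/h/w_thread_data_on_global in the hunk above.
constexpr index_t global_tile_origin(index_t block_work_id,
                                     index_t per_block,
                                     index_t thread_id,
                                     index_t per_thread)
{
    return block_work_id * per_block + thread_id * per_thread;
}

// Block 2 of size 64, thread 3 within it with an 8-wide tile:
// global start = 2 * 64 + 3 * 8 = 152.
static_assert(global_tile_origin(2, 64, 3, 8) == 152, "tiling arithmetic");

int main() { return 0; }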
@@ -182,7 +200,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
             make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
                 Number<CYXPerBlock>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));

-        using ThreadwiseTensorSliceTransferB = ThreadwiseDynamicTensorSliceTransfer_v2<
+        auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2<
             Float,
             Float,
             decltype(b_cyx_n_h_w_global_desc),
@@ -195,34 +213,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
             AddressSpace::Vgpr,
             InMemoryDataOperation::Set,
             1,
-            true>;
-
-        ThreadwiseTensorSliceTransferB b_threadwise_transfer(
+            true>(
             b_cyx_n_h_w_global_desc,
-            make_multi_index(0, 0, h_thread_data_on_global, w_thread_data_on_global));
+            make_multi_index(
+                k_thread_data_on_global, 0, h_thread_data_on_global, w_thread_data_on_global));

-        // c_thread_mtx definition: this is a mess
-        // TODO:: more elegent way of defining c_thread_mtx
-        constexpr auto c_k_n_h_w_thread_desc =
-            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
-                Number<KPerThread>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
-
-#if 1
-        const auto blockwise_gemm =
-            BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
-                                            decltype(a_cyx_k_block_desc),
-                                            decltype(b_cyx_n_h_w_block_desc),
-                                            decltype(c_k_n_h_w_thread_desc),
-                                            KPerThread,    // KPerThreadSubC
-                                            HPerThread,    // HPerThreadSubC
-                                            WPerThread,    // WPerThreadSubC
-                                            CYXPerThread,  // CYXPerThreadLoop
-                                            h_num_threads, // HThreadCluster
-                                            w_num_threads, // WThreadCluster
-                                            1,             // ThreadGemmADataPerRead_K
-                                            1              // ThreadGemmBDataPerRead_W
-                                            >{};
-#endif
-
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
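Taken together, the last two hunks replace a named alias plus a separate declaration with a single auto initialization, and the transfer's starting index now carries the real k offset instead of a hard-coded 0. A toy illustration of the same refactor pattern, with a hypothetical Transfer template standing in for ThreadwiseDynamicTensorSliceTransfer_v2:

#include <cstdio>

// Hypothetical stand-in for ThreadwiseDynamicTensorSliceTransfer_v2,
// for illustration only.
template <typename T, bool Flag>
struct Transfer
{
    explicit Transfer(T origin) : origin_(origin) {}
    T origin_;
};

int main()
{
    // Before: name the full specialization, then construct it separately.
    using TransferB = Transfer<int, true>;
    TransferB before(0);

    // After: construct in one expression and let auto deduce the type; the
    // constructor argument can now also carry the real k offset.
    const int k_thread_data_on_global = 152;
    auto after = Transfer<int, true>(k_thread_data_on_global);

    std::printf("%d %d\n", before.origin_, after.origin_);
    return 0;
}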
@@ -267,14 +261,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
                 b_cyx_n_h_w_global_iterator_hacks);

             a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_double);

-#if 0
-            __syncthreads();
-            p_c_thread[0] += p_b_thread_double[0] + p_b_thread_double[1] + p_b_thread_double[2];
-            p_c_thread[0] += p_b_thread_double[3] + p_b_thread_double[4] + p_b_thread_double[5];
-            p_c_thread[0] += p_b_thread_double[6] + p_b_thread_double[7] + p_b_thread_double[8];
-#endif
         }

 #if 1