gaoqiong / composable_kernel

Commit d990eff6 authored Apr 21, 2021 by Chao Liu
clean
parent 437c996a
Showing 1 changed file with 37 additions and 38 deletions (+37 −38).
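This is a small cleanup of the 2x2 pipelined blockwise GEMM: the verbose Number<0>{} arguments passed to make_tuple are replaced by the shorter alias I0, and the static_assert message is updated to describe the 2x2-pipeline restriction, as the diff below shows. The definition of I0 is not part of this diff; presumably it is a pre-declared compile-time constant along the lines of the following sketch (an assumption for illustration, not taken from this commit):

    // Assumed convention: Number<N> is a compile-time integral constant, and
    // I0 is a shorthand object for Number<0>{} so that index tuples can be
    // written as make_tuple(I0, I0) instead of make_tuple(Number<0>{}, Number<0>{}).
    static constexpr auto I0 = Number<0>{};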
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp
@@ -529,8 +529,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
         constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

-        static_assert(MRepeat == 2 && NRepeat == 2,
-                      "wrong! inline asm cannot deal with this GEMM config yet");
+        static_assert(MRepeat == 2 && NRepeat == 2, "wrong! only support 2x2 pipeline");

         // thread A-sub, B-sub
         constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
@@ -557,83 +556,83 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         // read A_sub_0
         a_thread_copy_.Run(BlockMatrixA{},
-                           make_tuple(Number<0>{}, Number<0>{}),
+                           make_tuple(I0, I0),
                            a_block_buf,
                            a_thread_mtx_desc_,
-                           make_tuple(Number<0>{}, Number<0>{}),
+                           make_tuple(I0, I0),
                            a_thread_buf);

         // read B_sub_0
         b_thread_copy_.Run(BlockMatrixB{},
-                           make_tuple(Number<0>{}, Number<0>{}),
+                           make_tuple(I0, I0),
                            b_block_buf,
                            b_thread_mtx_desc_,
-                           make_tuple(Number<0>{}, Number<0>{}),
+                           make_tuple(I0, I0),
                            b_thread_buf);

         // read B_sub_1
         b_thread_copy_.Run(BlockMatrixB{},
-                           make_tuple(Number<0>{}, Number<NPerLevel1Cluster>{}),
+                           make_tuple(I0, Number<NPerLevel1Cluster>{}),
                            b_block_buf,
                            b_thread_mtx_desc_,
-                           make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                           make_tuple(I0, Number<NPerThreadSubC>{}),
                            b_thread_buf);

         // read A_sub_1
         a_thread_copy_.Run(BlockMatrixA{},
-                           make_tuple(Number<0>{}, Number<MPerLevel1Cluster>{}),
+                           make_tuple(I0, Number<MPerLevel1Cluster>{}),
                            a_block_buf,
                            a_thread_mtx_desc_,
-                           make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                           make_tuple(I0, Number<MPerThreadSubC>{}),
                            a_thread_buf);

         // C_sub_00 += transpose(A_sub_0) * B_sub_0
         threadwise_gemm.Run(a_thread_buf,
-                            make_tuple(Number<0>{}, Number<0>{}),
+                            make_tuple(I0, I0),
                             b_thread_buf,
-                            make_tuple(Number<0>{}, Number<0>{}),
+                            make_tuple(I0, I0),
                             c_thread_buf,
-                            make_tuple(Number<0>{}, Number<0>{}));
+                            make_tuple(I0, I0));

         // C_sub_01 += transpose(A_sub_0) * B_sub_1
         threadwise_gemm.Run(a_thread_buf,
-                            make_tuple(Number<0>{}, Number<0>{}),
+                            make_tuple(I0, I0),
                             b_thread_buf,
-                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                            make_tuple(I0, Number<NPerThreadSubC>{}),
                             c_thread_buf,
-                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}));
+                            make_tuple(I0, Number<NPerThreadSubC>{}));

         // loop over rest of k
         static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
             // read A_sub_0
             a_thread_copy_.Run(BlockMatrixA{},
-                               make_tuple(k, Number<0>{}),
+                               make_tuple(k, I0),
                                a_block_buf,
                                a_thread_mtx_desc_,
-                               make_tuple(Number<0>{}, Number<0>{}),
+                               make_tuple(I0, I0),
                                a_thread_buf);

             // C_sub_10 += transpose(A_sub_1) * B_sub_0
             threadwise_gemm.Run(a_thread_buf,
-                                make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                                make_tuple(I0, Number<MPerThreadSubC>{}),
                                 b_thread_buf,
-                                make_tuple(Number<0>{}, Number<0>{}),
+                                make_tuple(I0, I0),
                                 c_thread_buf,
-                                make_tuple(Number<MPerThreadSubC>{}, Number<0>{}));
+                                make_tuple(Number<MPerThreadSubC>{}, I0));

             // read B_sub_0
             b_thread_copy_.Run(BlockMatrixB{},
-                               make_tuple(k, Number<0>{}),
+                               make_tuple(k, I0),
                                b_block_buf,
                                b_thread_mtx_desc_,
-                               make_tuple(Number<0>{}, Number<0>{}),
+                               make_tuple(I0, I0),
                                b_thread_buf);

             // C_sub_11 += transpose(A_sub_1) * B_sub_1
             threadwise_gemm.Run(a_thread_buf,
-                                make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                                make_tuple(I0, Number<MPerThreadSubC>{}),
                                 b_thread_buf,
-                                make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                                make_tuple(I0, Number<NPerThreadSubC>{}),
                                 c_thread_buf,
                                 make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
@@ -642,7 +641,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                                make_tuple(k, Number<NPerLevel1Cluster>{}),
                                b_block_buf,
                                b_thread_mtx_desc_,
-                               make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                               make_tuple(I0, Number<NPerThreadSubC>{}),
                                b_thread_buf);

             // read A_sub_1
@@ -650,39 +649,39 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                                make_tuple(k, Number<MPerLevel1Cluster>{}),
                                a_block_buf,
                                a_thread_mtx_desc_,
-                               make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                               make_tuple(I0, Number<MPerThreadSubC>{}),
                                a_thread_buf);

             // C_sub_00 += transpose(A_sub_0) * B_sub_0
             threadwise_gemm.Run(a_thread_buf,
-                                make_tuple(Number<0>{}, Number<0>{}),
+                                make_tuple(I0, I0),
                                 b_thread_buf,
-                                make_tuple(Number<0>{}, Number<0>{}),
+                                make_tuple(I0, I0),
                                 c_thread_buf,
-                                make_tuple(Number<0>{}, Number<0>{}));
+                                make_tuple(I0, I0));

             // C_sub_01 += transpose(A_sub_0) * B_sub_1
             threadwise_gemm.Run(a_thread_buf,
-                                make_tuple(Number<0>{}, Number<0>{}),
+                                make_tuple(I0, I0),
                                 b_thread_buf,
-                                make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                                make_tuple(I0, Number<NPerThreadSubC>{}),
                                 c_thread_buf,
-                                make_tuple(Number<0>{}, Number<NPerThreadSubC>{}));
+                                make_tuple(I0, Number<NPerThreadSubC>{}));
         });

         // C_sub_10 += transpose(A_sub_1) * B_sub_0
         threadwise_gemm.Run(a_thread_buf,
-                            make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                            make_tuple(I0, Number<MPerThreadSubC>{}),
                             b_thread_buf,
-                            make_tuple(Number<0>{}, Number<0>{}),
+                            make_tuple(I0, I0),
                             c_thread_buf,
-                            make_tuple(Number<MPerThreadSubC>{}, Number<0>{}));
+                            make_tuple(Number<MPerThreadSubC>{}, I0));

         // C_sub_11 += transpose(A_sub_1) * B_sub_1
         threadwise_gemm.Run(a_thread_buf,
-                            make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                            make_tuple(I0, Number<MPerThreadSubC>{}),
                             b_thread_buf,
-                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                            make_tuple(I0, Number<NPerThreadSubC>{}),
                             c_thread_buf,
                             make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
     }
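The function body in this diff is a 2x2 register-tile pipeline: each K step reads one A sub-tile and one B sub-tile while the four C sub-tiles (C_sub_00, C_sub_01, C_sub_10, C_sub_11) are accumulated with threadwise_gemm.Run. As a plain-C++ illustration of what a single "C_sub_xy += transpose(A_sub_x) * B_sub_y" update amounts to, here is a scalar sketch; the tile sizes, the K-major layout of the A/B sub-tiles, and the function name are assumptions for illustration, not the actual ThreadwiseGemm implementation:

    #include <array>
    #include <cstdio>

    // Illustrative tile sizes (assumed, not taken from the commit).
    constexpr int KPerThreadLoop = 2;
    constexpr int MPerThreadSubC = 4;
    constexpr int NPerThreadSubC = 4;

    // C_sub[m][n] += sum over k of A_sub[k][m] * B_sub[k][n]
    // (A_sub is assumed K-major, which is why the comments say transpose(A_sub)).
    void threadwise_gemm_sub(
        const std::array<std::array<float, MPerThreadSubC>, KPerThreadLoop>& a_sub,
        const std::array<std::array<float, NPerThreadSubC>, KPerThreadLoop>& b_sub,
        std::array<std::array<float, NPerThreadSubC>, MPerThreadSubC>& c_sub)
    {
        for(int k = 0; k < KPerThreadLoop; ++k)
            for(int m = 0; m < MPerThreadSubC; ++m)
                for(int n = 0; n < NPerThreadSubC; ++n)
                    c_sub[m][n] += a_sub[k][m] * b_sub[k][n];
    }

    int main()
    {
        std::array<std::array<float, MPerThreadSubC>, KPerThreadLoop> a{};
        std::array<std::array<float, NPerThreadSubC>, KPerThreadLoop> b{};
        std::array<std::array<float, NPerThreadSubC>, MPerThreadSubC> c{};
        a[0][0] = 1.0f;
        b[0][1] = 2.0f;
        threadwise_gemm_sub(a, b, c); // one of the four C_sub updates done per K step
        std::printf("c[0][1] = %f\n", c[0][1]); // prints 2.000000
    }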