Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c0ffe379
"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "b871670b5ae29aaa6cad1b2d4e004882f716c466"
Commit
c0ffe379
authored
May 18, 2021
by
Jing Zhang
Browse files
add 2x2 pipeline
parent
40016f20
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
83 additions
and
19 deletions
+83
-19
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+83
-19
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
c0ffe379
...
@@ -36,9 +36,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
...
@@ -36,9 +36,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static
constexpr
index_t
N0
=
BBlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
N0
=
BBlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
N1
=
BBlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
N1
=
BBlockDesc
{}.
GetLength
(
I2
);
// static constexpr index_t MPerBlock = M0 * M1; // A is transposed
// static constexpr index_t NPerBlock = N0 * N1;
static
constexpr
index_t
MWaves
=
M1
/
MPerWave
;
static
constexpr
index_t
MWaves
=
M1
/
MPerWave
;
static
constexpr
index_t
NWaves
=
N1
/
NPerWave
;
static
constexpr
index_t
NWaves
=
N1
/
NPerWave
;
...
@@ -101,9 +98,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
...
@@ -101,9 +98,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
}
}
}
}
__device__
static
CIndex
CalculateCThreadOriginDataIndex
(
const
index_t
m_repeat_id
,
__device__
static
CIndex
const
index_t
n_repeat_id
,
CalculateCThreadOriginDataIndex
(
const
index_t
m0
,
const
index_t
n0
,
const
index_t
blk_i
)
const
index_t
blk_i
)
{
{
const
index_t
waveId
=
get_thread_local_1d_id
()
/
WaveSize
;
const
index_t
waveId
=
get_thread_local_1d_id
()
/
WaveSize
;
...
@@ -113,8 +109,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
...
@@ -113,8 +109,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
const
index_t
waveId_m
=
waveId
/
NWaves
;
const
index_t
waveId_m
=
waveId
/
NWaves
;
const
index_t
waveId_n
=
waveId
%
NWaves
;
const
index_t
waveId_n
=
waveId
%
NWaves
;
const
index_t
row
=
m
_repeat_id
*
M1
+
waveId_m
*
MPerWave
+
thread_mtx_on_blk
.
row
;
const
index_t
row
=
m
0
*
M1
+
waveId_m
*
MPerWave
+
thread_mtx_on_blk
.
row
;
const
index_t
col
=
n
_repeat_id
*
N1
+
waveId_n
*
NPerWave
+
thread_mtx_on_blk
.
col
;
const
index_t
col
=
n
0
*
N1
+
waveId_n
*
NPerWave
+
thread_mtx_on_blk
.
col
;
return
CIndex
{
row
,
col
};
return
CIndex
{
row
,
col
};
}
}
...
@@ -148,7 +144,54 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
...
@@ -148,7 +144,54 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
constexpr
index_t
KPerBlock
=
ABlockDesc
{}.
GetLength
(
I0
);
constexpr
index_t
KPerBlock
=
ABlockDesc
{}.
GetLength
(
I0
);
static_for
<
0
,
KPerBlock
,
KPerWave
>
{}([
&
](
auto
k
)
{
// read A_sub_0
a_thread_copy_
.
Run
(
ABlockDesc
{},
make_tuple
(
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
a_thread_buf
);
// read B_sub_0
b_thread_copy_
.
Run
(
BBlockDesc
{},
make_tuple
(
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
);
// read B_sub_1
b_thread_copy_
.
Run
(
BBlockDesc
{},
make_tuple
(
I0
,
I1
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
I1
,
I0
),
b_thread_buf
);
// read A_sub_1
a_thread_copy_
.
Run
(
ABlockDesc
{},
make_tuple
(
I0
,
I1
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
I1
,
I0
),
a_thread_buf
);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
xdlops_gemm
.
template
Run2
<
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
c_thread_desc_
),
0
,
0
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
xdlops_gemm
.
template
Run2
<
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
c_thread_desc_
),
0
,
1
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
static_for
<
KPerWave
,
KPerBlock
,
KPerWave
>
{}([
&
](
auto
k
)
{
// read A_sub_0
a_thread_copy_
.
Run
(
ABlockDesc
{},
a_thread_copy_
.
Run
(
ABlockDesc
{},
make_tuple
(
k
,
I0
,
I0
),
make_tuple
(
k
,
I0
,
I0
),
a_block_buf
,
a_block_buf
,
...
@@ -156,6 +199,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
...
@@ -156,6 +199,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
make_tuple
(
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
),
a_thread_buf
);
a_thread_buf
);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
xdlops_gemm
.
template
Run2
<
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
c_thread_desc_
),
1
,
0
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
// read B_sub_0
b_thread_copy_
.
Run
(
BBlockDesc
{},
b_thread_copy_
.
Run
(
BBlockDesc
{},
make_tuple
(
k
,
I0
,
I0
),
make_tuple
(
k
,
I0
,
I0
),
b_block_buf
,
b_block_buf
,
...
@@ -163,12 +214,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
...
@@ -163,12 +214,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
make_tuple
(
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
);
b_thread_buf
);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
xdlops_gemm
.
template
Run2
<
decltype
(
a_thread_desc_
),
xdlops_gemm
.
template
Run2
<
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
c_thread_desc_
),
decltype
(
c_thread_desc_
),
0
,
1
,
0
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
1
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
// read B_sub_1
b_thread_copy_
.
Run
(
BBlockDesc
{},
b_thread_copy_
.
Run
(
BBlockDesc
{},
make_tuple
(
k
,
I1
,
I0
),
make_tuple
(
k
,
I1
,
I0
),
b_block_buf
,
b_block_buf
,
...
@@ -176,12 +229,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
...
@@ -176,12 +229,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
make_tuple
(
I0
,
I1
,
I0
),
make_tuple
(
I0
,
I1
,
I0
),
b_thread_buf
);
b_thread_buf
);
xdlops_gemm
.
template
Run2
<
decltype
(
a_thread_desc_
),
// read A_sub_1
decltype
(
b_thread_desc_
),
decltype
(
c_thread_desc_
),
0
,
1
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
a_thread_copy_
.
Run
(
ABlockDesc
{},
a_thread_copy_
.
Run
(
ABlockDesc
{},
make_tuple
(
k
,
I1
,
I0
),
make_tuple
(
k
,
I1
,
I0
),
a_block_buf
,
a_block_buf
,
...
@@ -189,18 +237,34 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
...
@@ -189,18 +237,34 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
make_tuple
(
I0
,
I1
,
I0
),
make_tuple
(
I0
,
I1
,
I0
),
a_thread_buf
);
a_thread_buf
);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
xdlops_gemm
.
template
Run2
<
decltype
(
a_thread_desc_
),
xdlops_gemm
.
template
Run2
<
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
c_thread_desc_
),
decltype
(
c_thread_desc_
),
1
,
0
,
0
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
0
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
xdlops_gemm
.
template
Run2
<
decltype
(
a_thread_desc_
),
xdlops_gemm
.
template
Run2
<
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
c_thread_desc_
),
decltype
(
c_thread_desc_
),
1
,
0
,
1
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
1
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
});
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
xdlops_gemm
.
template
Run2
<
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
c_thread_desc_
),
1
,
0
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
xdlops_gemm
.
template
Run2
<
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
c_thread_desc_
),
1
,
1
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
}
}
private:
private:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment