Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
2058bec8
Commit
2058bec8
authored
Mar 28, 2019
by
Jing Zhang
Browse files
fused functions
parent
766b0a9e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
55 additions
and
4 deletions
+55
-4
src/include/blockwise_gemm.hip.hpp
src/include/blockwise_gemm.hip.hpp
+48
-2
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp
...idwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp
+1
-1
src/include/threadwise_gemm.hip.hpp
src/include/threadwise_gemm.hip.hpp
+6
-1
No files found.
src/include/blockwise_gemm.hip.hpp
View file @
2058bec8
...
...
@@ -379,8 +379,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
// loop over k
for
(
index_t
k_begin
=
0
;
k_begin
<
K
;
k_begin
+=
KPerThreadLoop
)
{
#pragma unroll
// copy A-sub to form A
#if 0
#pragma unroll
// MRepeat = 2
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
...
...
@@ -391,9 +393,22 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths());
}
#else
{
auto
src_index
=
a_block_mtx
.
Get1dIndex
(
k_begin
,
0
)
+
mMyThreadOffsetA
;
auto
dst_index
=
a_thread_sub_mtx
.
Get1dIndex
(
0
,
0
);
#pragma unroll
const
float4
*
loc
=
(
const
float4
*
)(
p_a_block
+
src_index
);
float4
*
reg
=
(
float4
*
)(
p_a_thread
+
dst_index
);
reg
[
0
]
=
loc
[
0
];
reg
[
MPerThreadSubC
/
4
]
=
loc
[
MPerLevel1Cluster
/
4
];
}
#endif
#if 0
// copy B-sub to form B
#pragma unroll
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
...
...
@@ -404,8 +419,21 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths());
}
#else
{
auto
src_index
=
b_block_mtx
.
Get1dIndex
(
k_begin
,
0
)
+
mMyThreadOffsetB
;
auto
dst_index
=
b_thread_sub_mtx
.
Get1dIndex
(
0
,
0
);
const
float4
*
loc
=
(
const
float4
*
)(
p_b_block
+
src_index
);
float4
*
reg
=
(
float4
*
)(
p_b_thread
+
dst_index
);
reg
[
0
]
=
loc
[
0
];
reg
[
NPerThreadSubC
/
4
]
=
loc
[
NPerLevel1Cluster
/
4
];
}
#endif
// C = A * B
#if 0
threadwise_gemm(a_thread_mtx,
True,
p_a_thread,
...
...
@@ -416,6 +444,24 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
False,
p_c_thread,
f_accum);
#else
for
(
index_t
k
=
0
;
k
<
1
;
++
k
)
{
// M = 8
for
(
index_t
i
=
0
;
i
<
8
;
++
i
)
{
// N = 8
for
(
index_t
j
=
0
;
j
<
8
;
++
j
)
{
const
index_t
aindex
=
a_thread_sub_mtx
.
Get1dIndex
(
k
,
i
);
// A is transposed
const
index_t
bindex
=
b_thread_sub_mtx
.
Get1dIndex
(
k
,
j
);
const
index_t
cindex
=
c_thread_mtx
.
Get1dIndex
(
i
,
j
);
p_c_thread
[
cindex
]
+=
p_a_thread
[
aindex
]
*
p_b_thread
[
bindex
];
}
}
}
#endif
}
}
...
...
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp
View file @
2058bec8
...
...
@@ -236,7 +236,7 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
for
(
index_t
x
=
0
;
x
<
X
;
++
x
)
{
auto
f_accum
=
[](
auto
&
acc
,
const
auto
&&
v
)
{
acc
+=
v
;
};
#if
0
#if
1
blockwise_gemm
.
Run
#elif 0
blockwise_gemm
.
Run_asm
...
...
src/include/threadwise_gemm.hip.hpp
View file @
2058bec8
...
...
@@ -10,9 +10,11 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
constexpr
auto
src_mtx
=
SrcMatrix
{};
constexpr
auto
dst_mtx
=
DstMatrix
{};
#if 0
#if 1
//NRow = 1
for
(
index_t
i
=
0
;
i
<
NRow
;
++
i
)
{
//NCol = 4
for
(
index_t
j
=
0
;
j
<
NCol
;
++
j
)
{
const
index_t
src_index
=
src_mtx
.
Get1dIndex
(
i
,
j
);
...
...
@@ -76,10 +78,13 @@ __device__ void threadwise_gemm(MatrixA,
constexpr
index_t
N
=
c_mtx
.
NCol
();
constexpr
index_t
K
=
a_mtx
.
NRow
();
// A is transposed
// K = 1
for
(
index_t
k
=
0
;
k
<
K
;
++
k
)
{
// M = 8
for
(
index_t
i
=
0
;
i
<
M
;
++
i
)
{
// N = 8
for
(
index_t
j
=
0
;
j
<
N
;
++
j
)
{
const
index_t
aindex
=
a_mtx
.
Get1dIndex
(
k
,
i
);
// A is transposed
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment