Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
b93d2e1b
Commit
b93d2e1b
authored
Apr 26, 2019
by
Chao Liu
Browse files
fix batch gemm asm bug
parent
46a0aec1
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
108 additions
and
25 deletions
+108
-25
src/include/amd_inline_asm.hip.hpp
src/include/amd_inline_asm.hip.hpp
+57
-1
src/include/blockwise_batched_gemm.hip.hpp
src/include/blockwise_batched_gemm.hip.hpp
+14
-14
src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp
...ise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp
+8
-2
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp
...plicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp
+29
-8
No files found.
src/include/amd_inline_asm.hip.hpp
View file @
b93d2e1b
...
@@ -201,7 +201,7 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
...
@@ -201,7 +201,7 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
if
(
offset
==
0
)
if
(
offset
==
0
)
{
{
asm
volatile
(
"
\n
\
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1
offset:0
\n
\
ds_read_b128 %0, %1
\n
\
"
"
:
"=v"
(
r
)
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
:
"v"
(
__to_local
(
lds
)));
...
@@ -350,6 +350,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
...
@@ -350,6 +350,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
:
"=v"
(
r
)
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
:
"v"
(
__to_local
(
lds
)));
}
}
else
if
(
offset
==
2432
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2432
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
2560
)
else
if
(
offset
==
2560
)
{
{
asm
volatile
(
"
\n
\
asm
volatile
(
"
\n
\
...
@@ -358,6 +366,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
...
@@ -358,6 +366,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
:
"=v"
(
r
)
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
:
"v"
(
__to_local
(
lds
)));
}
}
else
if
(
offset
==
2688
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2688
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
2816
)
else
if
(
offset
==
2816
)
{
{
asm
volatile
(
"
\n
\
asm
volatile
(
"
\n
\
...
@@ -366,6 +382,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
...
@@ -366,6 +382,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
:
"=v"
(
r
)
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
:
"v"
(
__to_local
(
lds
)));
}
}
else
if
(
offset
==
2944
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:2944
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
3072
)
else
if
(
offset
==
3072
)
{
{
asm
volatile
(
"
\n
\
asm
volatile
(
"
\n
\
...
@@ -374,6 +398,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
...
@@ -374,6 +398,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
:
"=v"
(
r
)
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
:
"v"
(
__to_local
(
lds
)));
}
}
else
if
(
offset
==
3200
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3200
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
3328
)
else
if
(
offset
==
3328
)
{
{
asm
volatile
(
"
\n
\
asm
volatile
(
"
\n
\
...
@@ -382,6 +414,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
...
@@ -382,6 +414,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
:
"=v"
(
r
)
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
:
"v"
(
__to_local
(
lds
)));
}
}
else
if
(
offset
==
3456
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3456
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
3584
)
else
if
(
offset
==
3584
)
{
{
asm
volatile
(
"
\n
\
asm
volatile
(
"
\n
\
...
@@ -390,6 +430,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
...
@@ -390,6 +430,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
:
"=v"
(
r
)
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
:
"v"
(
__to_local
(
lds
)));
}
}
else
if
(
offset
==
3712
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3712
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
3840
)
else
if
(
offset
==
3840
)
{
{
asm
volatile
(
"
\n
\
asm
volatile
(
"
\n
\
...
@@ -398,6 +446,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
...
@@ -398,6 +446,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
:
"=v"
(
r
)
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
:
"v"
(
__to_local
(
lds
)));
}
}
else
if
(
offset
==
3968
)
{
asm
volatile
(
"
\n
\
ds_read_b128 %0, %1 offset:3968
\n
\
"
:
"=v"
(
r
)
:
"v"
(
__to_local
(
lds
)));
}
else
if
(
offset
==
4096
)
else
if
(
offset
==
4096
)
{
{
asm
volatile
(
"
\n
\
asm
volatile
(
"
\n
\
...
...
src/include/blockwise_batched_gemm.hip.hpp
View file @
b93d2e1b
...
@@ -293,8 +293,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
...
@@ -293,8 +293,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
constexpr
auto
b_block_mtx
=
BlockMatrixB
{};
constexpr
auto
b_block_mtx
=
BlockMatrixB
{};
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
constexpr
index_t
M
=
a_block_mtx
.
NCol
();
constexpr
index_t
N
=
b_block_mtx
.
NCol
();
constexpr
index_t
K
=
a_block_mtx
.
NRow
();
// A is transposed
constexpr
index_t
K
=
a_block_mtx
.
NRow
();
// A is transposed
constexpr
index_t
MPerThread
=
c_thread_mtx
.
NRow
();
constexpr
index_t
MPerThread
=
c_thread_mtx
.
NRow
();
...
@@ -344,24 +342,26 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
...
@@ -344,24 +342,26 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
reg_a
[
0
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_a_block
[
mMyThreadOffsetA
]);
reg_a
[
0
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_a_block
[
mMyThreadOffsetA
]);
reg_b
[
0
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_b_block
[
mMyThreadOffsetB
]);
reg_b
[
0
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_b_block
[
mMyThreadOffsetB
]);
reg_b
[
1
]
=
reg_b
[
1
]
=
*
reinterpret_cast
<
const
Float4
*>
(
*
reinterpret_cast
<
const
Float4
*>
(
&
p_b_block
[
mMyThreadOffsetB
+
NPerLevel1Cluster
]);
&
p_b_block
[
b_block_mtx
.
Get1dIndex
(
0
,
NPerLevel1Cluster
)
+
mMyThreadOffsetB
]);
reg_a
[
1
]
=
reg_a
[
1
]
=
*
reinterpret_cast
<
const
Float4
*>
(
*
reinterpret_cast
<
const
Float4
*>
(
&
p_a_block
[
mMyThreadOffsetA
+
MPerLevel1Cluster
]);
&
p_a_block
[
a_block_mtx
.
Get1dIndex
(
0
,
MPerLevel1Cluster
)
+
mMyThreadOffsetA
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
1
],
reg_c
[
1
],
reg_c
[
3
],
reg_c
[
5
],
reg_c
[
7
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
1
],
reg_c
[
1
],
reg_c
[
3
],
reg_c
[
5
],
reg_c
[
7
]);
#pragma unroll
#pragma unroll
for
(
index_t
k
=
1
;
k
<
K
;
++
k
)
for
(
index_t
k
=
1
;
k
<
K
;
++
k
)
{
{
reg_a
[
0
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_a_block
[
mMyThreadOffsetA
+
k
*
M
]);
reg_a
[
0
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_a_block
[
a_block_mtx
.
Get1dIndex
(
k
,
0
)
+
mMyThreadOffsetA
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
reg_b
[
0
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_b_block
[
mMyThreadOffsetB
+
k
*
N
]);
reg_b
[
0
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_b_block
[
b_block_mtx
.
Get1dIndex
(
k
,
0
)
+
mMyThreadOffsetB
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
reg_b
[
1
]
=
*
reinterpret_cast
<
const
Float4
*>
(
reg_b
[
1
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_b_block
[
mMyThreadOffsetB
+
k
*
N
+
NPerLevel1Cluster
]);
&
p_b_block
[
b_block_mtx
.
Get1dIndex
(
k
,
NPerLevel1Cluster
)
+
mMyThreadOffsetB
]);
reg_a
[
1
]
=
*
reinterpret_cast
<
const
Float4
*>
(
reg_a
[
1
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_a_block
[
mMyThreadOffsetA
+
k
*
M
+
MPerLevel1Cluster
]);
&
p_a_block
[
a_block_mtx
.
Get1dIndex
(
k
,
MPerLevel1Cluster
)
+
mMyThreadOffsetA
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
1
],
reg_c
[
1
],
reg_c
[
3
],
reg_c
[
5
],
reg_c
[
7
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
1
],
reg_c
[
1
],
reg_c
[
3
],
reg_c
[
5
],
reg_c
[
7
]);
}
}
...
@@ -430,10 +430,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
...
@@ -430,10 +430,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
void
*
a_lds_loc
=
(
void
*
)(
p_a_block
+
mMyThreadOffsetA
);
void
*
a_lds_loc
=
(
void
*
)(
p_a_block
+
mMyThreadOffsetA
);
void
*
b_lds_loc
=
(
void
*
)(
p_b_block
+
mMyThreadOffsetB
);
void
*
b_lds_loc
=
(
void
*
)(
p_b_block
+
mMyThreadOffsetB
);
constexpr
index_t
a_lds_row_stride
=
sizeof
(
F
loat
)
*
M
;
constexpr
index_t
a_lds_row_stride
=
sizeof
(
f
loat
)
*
a_block_mtx
.
RowStride
()
;
constexpr
index_t
b_lds_row_stride
=
sizeof
(
F
loat
)
*
N
;
constexpr
index_t
b_lds_row_stride
=
sizeof
(
f
loat
)
*
b_block_mtx
.
RowStride
()
;
constexpr
index_t
a_lds_cluster_col_stride
=
sizeof
(
F
loat
)
*
MPerLevel1Cluster
;
constexpr
index_t
a_lds_cluster_col_stride
=
sizeof
(
f
loat
)
*
MPerLevel1Cluster
;
constexpr
index_t
b_lds_cluster_col_stride
=
sizeof
(
F
loat
)
*
NPerLevel1Cluster
;
constexpr
index_t
b_lds_cluster_col_stride
=
sizeof
(
f
loat
)
*
NPerLevel1Cluster
;
ds_read_b128
(
reg_a
[
0
],
a_lds_loc
,
0
);
ds_read_b128
(
reg_a
[
0
],
a_lds_loc
,
0
);
ds_read_b128
(
reg_b
[
0
],
b_lds_loc
,
0
);
ds_read_b128
(
reg_b
[
0
],
b_lds_loc
,
0
);
...
...
src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp
View file @
b93d2e1b
...
@@ -213,7 +213,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
...
@@ -213,7 +213,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
// set threadwise output tensor to 0
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero
(
out_k_h_w_n_thread_desc
,
p_out_thread
);
threadwise_4d_tensor_set_zero
(
out_k_h_w_n_thread_desc
,
p_out_thread
);
#if
0
#if
1
const
Float
*
p_in_global_block_offset
=
const
Float
*
p_in_global_block_offset
=
p_in_global
+
p_in_global
+
in_c_h_w_n_global_desc
.
Get1dIndex
(
in_c_h_w_n_global_desc
.
Get1dIndex
(
...
@@ -241,7 +241,13 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
...
@@ -241,7 +241,13 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
__syncthreads
();
__syncthreads
();
#if 1
blockwise_batch_gemm
.
Run
(
p_wei_block
,
p_in_block
,
p_out_thread
);
blockwise_batch_gemm
.
Run
(
p_wei_block
,
p_in_block
,
p_out_thread
);
#elif 0
blockwise_batch_gemm
.
Run_asm
(
p_wei_block
,
p_in_block
,
p_out_thread
);
#elif 1
blockwise_batch_gemm
.
Run_asm_v2
(
p_wei_block
,
p_in_block
,
p_out_thread
);
#endif
__syncthreads
();
__syncthreads
();
}
}
...
@@ -277,7 +283,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
...
@@ -277,7 +283,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
blockwise_batch_gemm
.
Run
(
p_wei_block
,
p_in_block
,
p_out_thread
);
blockwise_batch_gemm
.
Run
(
p_wei_block
,
p_in_block
,
p_out_thread
);
#elif 0
#elif 0
blockwise_batch_gemm
.
Run_asm
(
p_wei_block
,
p_in_block
,
p_out_thread
);
blockwise_batch_gemm
.
Run_asm
(
p_wei_block
,
p_in_block
,
p_out_thread
);
#elif
0
#elif
1
blockwise_batch_gemm
.
Run_asm_v2
(
p_wei_block
,
p_in_block
,
p_out_thread
);
blockwise_batch_gemm
.
Run_asm_v2
(
p_wei_block
,
p_in_block
,
p_out_thread
);
#endif
#endif
...
...
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp
View file @
b93d2e1b
...
@@ -293,8 +293,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
...
@@ -293,8 +293,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
p_wei_register_clipboard
);
p_wei_register_clipboard
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_batch_gemm
.
Run
(
p_wei_block_now
,
p_in_block_now
,
p_out_thread
);
#if 1
blockwise_batch_gemm
.
Run
#elif 0
blockwise_batch_gemm
.
Run_asm
#else
blockwise_batch_gemm
.
Run_asm_v2
#endif
(
p_wei_block_now
,
p_in_block_now
,
p_out_thread
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
...
@@ -321,8 +328,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
...
@@ -321,8 +328,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
p_wei_register_clipboard
);
p_wei_register_clipboard
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_batch_gemm
.
Run
(
p_wei_block_double
,
p_in_block_double
,
p_out_thread
);
#if 1
blockwise_batch_gemm
.
Run
#elif 0
blockwise_batch_gemm
.
Run_asm
#else
blockwise_batch_gemm
.
Run_asm_v2
#endif
(
p_wei_block_double
,
p_in_block_double
,
p_out_thread
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
...
@@ -333,10 +347,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
...
@@ -333,10 +347,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
// odd iteration
// odd iteration
__syncthreads
();
__syncthreads
();
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_batch_gemm
.
Run
(
p_wei_block_double
+
wei_block_space
,
#if 1
p_in_block_double
+
in_block_space
,
blockwise_batch_gemm
.
Run
p_out_thread
);
#elif 0
blockwise_batch_gemm
.
Run_asm
#else
blockwise_batch_gemm
.
Run_asm_v2
#endif
(
p_wei_block_double
+
wei_block_space
,
p_in_block_double
+
in_block_space
,
p_out_thread
);
}
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment