yangql / composable_kernel-1 · Commits

Commit d6d9a8e4
Authored Mar 28, 2019 by Chao Liu

    Jing's ds_read inline asm

parent 766b0a9e
Showing 6 changed files with 126 additions and 25 deletions (+126 -25)
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp                                 +32 -2
src/include/blockwise_gemm.hip.hpp                                                           +68 -5
src/include/common.hip.hpp                                                                    +2 -0
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp                       +1 -1
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp     +8 -8
src/include/threadwise_gemm.hip.hpp                                                          +15 -9
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
...
@@ -190,8 +190,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     constexpr index_t WeiBlockCopyDataPerRead = 4;

     constexpr index_t BlockSize = 256;
-#elif 1 // 1x1, 14x14, Vega 10
+#elif 0 // 1x1, 14x14, Vega 20
     constexpr index_t BPerBlock = 64;
     constexpr index_t KPerBlock = 128;
     constexpr index_t CPerBlock = 8;
...
...
@@ -219,6 +219,36 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     constexpr index_t InBlockCopyDataPerRead  = 4;
     constexpr index_t WeiBlockCopyDataPerRead = 4;

     constexpr index_t BlockSize = 128;
+#elif 1 // 1x1, 14x14, Vega 20, hack CPerBlock = 1
+    constexpr index_t BPerBlock = 64;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 1;
+
+    constexpr index_t BPerThread = 8;
+    constexpr index_t KPerThread = 8;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmThreadPerColumnPerCluster = 8;
+    constexpr index_t GemmThreadPerRowPerCluster    = 8;
+
+    constexpr index_t InBlockCopyThreadPerDim0 = 4;
+    constexpr index_t InBlockCopyThreadPerDim1 = 16;
+
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
+
+    constexpr index_t InBlockCopyDataPerRead  = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
+
+    constexpr index_t BlockSize = 128;
 #endif
...
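For orientation, the newly added "Vega 20, hack CPerBlock = 1" parameter set is self-consistent: 128 threads, each accumulating an 8x8 sub-result, exactly tile the 128x64 (KPerBlock x BPerBlock) block tile, and the per-thread sub-tile, cluster, and repeat factors recompose into the same dimensions. A standalone sanity-check sketch, not part of the commit (MRepeat/NRepeat of 2 are taken from the static_assert added in blockwise_gemm.hip.hpp below; everything else is copied from this hunk):

    // sanity_check.cpp -- hypothetical standalone check of the tile arithmetic
    #include <cstdint>
    using index_t = std::uint32_t;

    constexpr index_t BPerBlock = 64, KPerBlock = 128, BlockSize = 128;
    constexpr index_t BPerThread = 8, KPerThread = 8;
    constexpr index_t GemmMPerThreadSubC = 4, GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4, GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4, GemmNLevel1Cluster = 4;
    constexpr index_t MRepeat = 2, NRepeat = 2; // from the static_assert in Run_asm

    // 128 threads * (8 x 8) accumulators == 128 x 64 block tile
    static_assert(BlockSize * KPerThread * BPerThread == KPerBlock * BPerBlock, "");

    // M side (the K dimension of the convolution): 4 * 4 * 4 * 2 == 128 == KPerBlock
    static_assert(GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster * MRepeat == KPerBlock, "");

    // N side (the B dimension): 4 * 2 * 4 * 2 == 64 == BPerBlock
    static_assert(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster * NRepeat == BPerBlock, "");

    int main() {}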
src/include/blockwise_gemm.hip.hpp
...
@@ -420,9 +420,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
     }

     template <class FloatA, class FloatB, class FloatC, class Accumulator>
-    __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
-                            const FloatB* __restrict__ p_b_block,
-                            FloatC* __restrict__ p_c_thread,
+    __device__ void Run_asm(const FloatA* const __restrict__ p_a_block,
+                            const FloatB* const __restrict__ p_b_block,
+                            FloatC* const __restrict__ p_c_thread,
                             Accumulator f_accum) const
     {
         constexpr auto True = integral_constant<bool, true>{};
...
@@ -462,11 +462,18 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
         constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

+        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && MRepeat == 2 &&
+                          NRepeat == 2 && KPerThreadLoop == 1 && K == 1,
+                      "asm is not for this mtx shape");
+
+        const FloatA* const p_a_block_thread_offset = p_a_block + mMyThreadOffsetA;
+
 #pragma unroll
         // loop over k
         for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
         {
-            //#pragma unroll
+#if 0
+#pragma unroll
             // copy A-sub to form A
             for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
             {
...
@@ -475,9 +482,65 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                     p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
                         mMyThreadOffsetA,
                     a_thread_mtx,
-                    p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
-                    a_thread_sub_mtx.NCol());
+                    p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
+                    a_thread_sub_mtx.GetLengths());
             }
+#elif 1
+            // this produces the right result
+            using vectorA_t = typename vector_type<FloatA, 4>::MemoryType; // this is float4*
+
+            asm volatile("\n \
+                    ds_read_b128 %0, %1 \n \
+                    s_waitcnt lgkmcnt(0)"
+                         : "=v"(*(reinterpret_cast<vectorA_t*>(
+                               p_a_thread + a_thread_mtx.Get1dIndex(0, 0))))
+                         : "v"(__to_local((void*)(p_a_block +
+                                                  a_block_mtx.Get1dIndex(k_begin, 0) +
+                                                  mMyThreadOffsetA))));
+
+            asm volatile("\n \
+                    ds_read_b128 %0, %1 \n \
+                    s_waitcnt lgkmcnt(0)"
+                         : "=v"(*(reinterpret_cast<vectorA_t*>(
+                               p_a_thread + a_thread_mtx.Get1dIndex(0, MPerThreadSubC))))
+                         : "v"(__to_local((void*)(p_a_block +
+                                                  a_block_mtx.Get1dIndex(k_begin, MPerLevel1Cluster) +
+                                                  mMyThreadOffsetA))));
+#elif 0
+            // this produces the wrong result
+            using vectorA_t = typename vector_type<FloatA, 4>::MemoryType; // this is float4*
+
+            asm volatile("\n \
+                    ds_read_b128 %0, %2 \n \
+                    ds_read_b128 %1, %3 \n \
+                    s_waitcnt lgkmcnt(0)"
+                         : "=v"(*(reinterpret_cast<vectorA_t*>(
+                               p_a_thread + a_thread_mtx.Get1dIndex(0, 0)))),
+                           "=v"(*(reinterpret_cast<vectorA_t*>(
+                               p_a_thread + a_thread_mtx.Get1dIndex(0, MPerThreadSubC))))
+                         : "v"(__to_local((void*)(p_a_block +
+                                                  a_block_mtx.Get1dIndex(k_begin, 0) +
+                                                  mMyThreadOffsetA))),
+                           "v"(__to_local((void*)(p_a_block +
+                                                  a_block_mtx.Get1dIndex(k_begin, MPerLevel1Cluster) +
+                                                  mMyThreadOffsetA))));
+#elif 1
+            // this produces the wrong result
+            using vectorA_t = typename vector_type<FloatA, 4>::MemoryType; // this is float4*
+
+            asm volatile("\n \
+                    ds_read_b128 %0, %1 \n \
+                    s_waitcnt lgkmcnt(0)"
+                         : "=v"(*(reinterpret_cast<vectorA_t*>(
+                               p_a_thread + a_thread_mtx.Get1dIndex(0, 0))))
+                         : "v"(__to_local((void*)(p_a_block_thread_offset))));
+
+            asm volatile("\n \
+                    ds_read_b128 %0, %1 offset:16 \n \
+                    s_waitcnt lgkmcnt(0)"
+                         : "=v"(*(reinterpret_cast<vectorA_t*>(
+                               p_a_thread + a_thread_mtx.Get1dIndex(0, MPerThreadSubC))))
+                         : "v"(__to_local((void*)(p_a_block_thread_offset))));
+#endif

             //#pragma unroll
             // copy B-sub to form B
...
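For context on the idiom this commit adds: ds_read_b128 is the GCN LDS instruction that loads 128 bits into four consecutive VGPRs, and s_waitcnt lgkmcnt(0) stalls until all outstanding LDS reads have landed, which is what makes the result safe to use immediately. A minimal sketch of the pattern as a reusable helper (hypothetical name lds_read_b128; __to_local is the helper declared in common.hip.hpp below):

    // Hypothetical helper, not part of the commit: one 128-bit LDS -> register
    // read using the same inline-asm idiom as Run_asm above.
    __device__ inline void lds_read_b128(const float* p_lds_src, float4& r_dst)
    {
        asm volatile("\n \
            ds_read_b128 %0, %1 \n \
            s_waitcnt lgkmcnt(0)"
                     : "=v"(r_dst)                          // four consecutive VGPRs
                     : "v"(__to_local((void*)p_lds_src)));  // LDS address operand
    }

A plausible reading of why the last "#elif" branch misbehaves: offset:16 encodes a fixed 16-byte immediate displacement (four floats past p_a_block_thread_offset), whereas the second A sub-tile actually sits MPerLevel1Cluster elements further along the row; the working branch computes that address explicitly via a_block_mtx.Get1dIndex(k_begin, MPerLevel1Cluster).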
src/include/common.hip.hpp
...
@@ -5,6 +5,8 @@
 #include "Array.hip.hpp"
 #include "functional.hip.hpp"

+extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]];
+
 __device__ index_t get_thread_local_1d_id() { return threadIdx.x; }

 __device__ index_t get_block_1d_id() { return blockIdx.x; }
...
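The new declaration is the only glue the inline asm needs: address_space(3) is the LDS (local) address space in the AMDGPU backend, and [[hc]] marks the declaration for the HCC-era accelerator compiler this code targets, so __to_local returns the same pointer re-qualified as an LDS address. A hedged usage sketch (the kernel around it is made up; the point is that the result of __to_local is what gets bound to the "v" address operand of a ds_read):

    // usage sketch, not part of the commit
    __global__ void to_local_usage_sketch(float* p_out)
    {
        __shared__ float lds_tile[64];
        lds_tile[threadIdx.x] = static_cast<float>(threadIdx.x);
        __syncthreads();

        float r;
        // a single 32-bit LDS read through the same inline-asm idiom
        asm volatile("ds_read_b32 %0, %1 \n s_waitcnt lgkmcnt(0)"
                     : "=v"(r)
                     : "v"(__to_local((void*)(lds_tile + threadIdx.x))));
        p_out[threadIdx.x] = r;
    }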
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp
...
@@ -238,7 +238,7 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
     auto f_accum = [](auto& acc, const auto&& v) { acc += v; };

 #if 0
     blockwise_gemm.Run
-#elif 0
+#elif 1
     blockwise_gemm.Run_asm
 #elif 1
     blockwise_gemm.Run_RegisterDoubleBuffer
...
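The f_accum lambda is the accumulation hook threaded through every Run* variant; the callee applies it element-wise to the C tile. A simplified sketch of how such an Accumulator is consumed (hypothetical free function; the real threadwise GEMM in threadwise_gemm.hip.hpp indexes through matrix descriptors instead of raw strides):

    // Sketch: how an Accumulator functor like f_accum is applied in the inner
    // product of a per-thread GEMM (row-major C, K loop elided for brevity).
    template <class Accumulator>
    __device__ void threadwise_gemm_sketch(const float* p_a, const float* p_b, float* p_c,
                                           index_t M, index_t N, Accumulator f_accum)
    {
        for(index_t m = 0; m < M; ++m)
            for(index_t n = 0; n < N; ++n)
                f_accum(p_c[m * N + n], p_a[m] * p_b[n]); // acc += v
    }

Note the lambda's second parameter is const auto&&, so it binds the freshly computed product (a prvalue) without a copy.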
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp
...
@@ -289,10 +289,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
 #else
             blockwise_gemm.Run_RegisterDoubleBuffer
 #endif
-            (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
-             p_in_block_now + y * Wi + x,
-             p_out_thread,
-             f_accum);
+                (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
+                 p_in_block_now + y * Wi + x,
+                 p_out_thread,
+                 f_accum);
         }
     }
...
@@ -319,10 +319,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
 #else
             blockwise_gemm.Run_RegisterDoubleBuffer
 #endif
-            (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
-             p_in_block_now + y * Wi + x,
-             p_out_thread,
-             f_accum);
+                (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
+                 p_in_block_now + y * Wi + x,
+                 p_out_thread,
+                 f_accum);
         }
     }
 }
...
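The surrounding file implements the classic LDS double-buffering schedule: while the blockwise GEMM consumes tile i from one LDS buffer, the copy of tile i+1 lands in the other, so global-memory latency overlaps compute. A minimal compilable skeleton of that schedule, assuming one block of TileSize threads (all names hypothetical, not the actual kernel):

    #include <hip/hip_runtime.h>

    constexpr unsigned TileSize = 256;

    __global__ void double_buffer_sketch(const float* p_global, float* p_out, unsigned NTiles)
    {
        __shared__ float lds[2][TileSize];
        unsigned cur = 0;
        float acc = 0.f;

        lds[cur][threadIdx.x] = p_global[threadIdx.x]; // prefetch tile 0
        __syncthreads();

        for(unsigned t = 0; t < NTiles; ++t)
        {
            if(t + 1 < NTiles) // prefetch the next tile into the other buffer
                lds[cur ^ 1][threadIdx.x] = p_global[(t + 1) * TileSize + threadIdx.x];

            acc += lds[cur][threadIdx.x]; // "compute" on the current buffer

            __syncthreads(); // next buffer filled, current buffer fully consumed
            cur ^= 1;        // swap buffers
        }
        p_out[threadIdx.x] = acc;
    }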
src/include/threadwise_gemm.hip.hpp
...
@@ -10,7 +10,7 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
     constexpr auto src_mtx = SrcMatrix{};
     constexpr auto dst_mtx = DstMatrix{};

-#if 0
+#if 1
     for(index_t i = 0; i < NRow; ++i)
     {
         for(index_t j = 0; j < NCol; ++j)
...
@@ -21,7 +21,7 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
             p_dst[dst_index] = p_src[src_index];
         }
     }
-#elif 1
+#elif 0
     static_assert(NCol == 4, "only for NCol == 4");

     using vector_t = typename vector_type<Float, 4>::MemoryType;
...
@@ -31,15 +31,21 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
         const index_t src_index = src_mtx.Get1dIndex(i, 0);
         const index_t dst_index = dst_mtx.Get1dIndex(i, 0);

-#if 1
-        *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
-            *(reinterpret_cast<const vector_t*>(p_src + src_index));
-#elif 1
-        asm volatile("\n \
-                ds_read_b128 %0, %1, offset:0 \n \
-                "
-                     : "=v"(*(reinterpret_cast<vector_t*>(p_dst + dst_index)))
-                     : "v"((uint32_t)(p_src + src_index)));
+#if 0
+        *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
+            *(reinterpret_cast<const vector_t*>(&p_src[src_index]));
+#elif 0
+        asm volatile("\n \
+                ds_read2_b64 %0, %1 offset1:1 \n \
+                s_waitcnt lgkmcnt(0)"
+                     : "=v"(*(reinterpret_cast<vector_t*>(&p_dst[dst_index])))
+                     : "v"(__to_local((void*)(&p_src[src_index]))));
+#elif 1
+        asm volatile("\n \
+                ds_read_b128 %0, %1 \n \
+                s_waitcnt lgkmcnt(0)"
+                     : "=v"(*(reinterpret_cast<vector_t*>(&p_dst[dst_index])))
+                     : "v"(__to_local((void*)(&p_src[src_index]))));
 #endif
     }
 #endif
...
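On the two instruction choices in the new branches: ds_read2_b64 %0, %1 offset1:1 performs two 64-bit LDS reads from addr + 8*offset0 and addr + 8*offset1, so with offset0:0 offset1:1 it fetches the same contiguous 16 bytes as a single ds_read_b128 while only requiring 8-byte alignment. A CPU-side analogue of what any of the branches moves per i-iteration, assuming a contiguous row of four floats (standalone sketch, names made up):

    // row_copy_sketch.cpp -- what one iteration of the i-loop moves: 16 bytes.
    #include <cstring>

    struct vec4f { float v[4]; }; // stand-in for vector_type<float, 4>::MemoryType

    void copy_row_of_four(const float* p_src, float* p_dst)
    {
        // the GPU branches do this with one ds_read_b128 (or a paired ds_read2_b64)
        // followed by s_waitcnt lgkmcnt(0); memcpy sidesteps the strict-aliasing
        // caveats of the reinterpret_cast in the disabled #if 0 branch.
        vec4f tmp;
        std::memcpy(&tmp, p_src, sizeof tmp);
        std::memcpy(p_dst, &tmp, sizeof tmp);
    }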