Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
0d6aa311
"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "b6571d2295bf4b4d5c6ae159b1d8b4dc36e300a5"
Commit
0d6aa311
authored
Apr 03, 2019
by
Jing Zhang
Browse files
inline asm
parent
753b98b5
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
33 deletions
+38
-33
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
+1
-1
src/include/blockwise_gemm.hip.hpp
src/include/blockwise_gemm.hip.hpp
+27
-30
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp
...idwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp
+1
-1
src/include/threadwise_gemm.hip.hpp
src/include/threadwise_gemm.hip.hpp
+9
-1
No files found.
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
View file @
0d6aa311
...
@@ -190,7 +190,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
...
@@ -190,7 +190,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
#elif
0
#elif
1
// 1x1, 14x14, Vega 20, disable lds_double_buffer, enable register double buffer
// 1x1, 14x14, Vega 20, disable lds_double_buffer, enable register double buffer
constexpr
index_t
BPerBlock
=
64
;
constexpr
index_t
BPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
128
;
...
...
src/include/blockwise_gemm.hip.hpp
View file @
0d6aa311
...
@@ -332,12 +332,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
...
@@ -332,12 +332,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
n_repeat
*
NPerLevel1Cluster
+
n_in_sub_c
};
n_repeat
*
NPerLevel1Cluster
+
n_in_sub_c
};
}
}
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
,
class
Accumulator
,
index_t
block_off
>
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
,
class
Accumulator
>
__device__
void
Run_asm
(
const
FloatA
*
__restrict__
p_a_block
,
__device__
void
Run_asm
(
const
FloatA
*
__restrict__
p_a_block
,
const
FloatB
*
__restrict__
p_b_block
,
const
FloatB
*
__restrict__
p_b_block
,
FloatC
*
__restrict__
p_c_thread
,
FloatC
*
__restrict__
p_c_thread
,
Accumulator
f_accum
,
Accumulator
f_accum
)
const
Number
<
block_off
>
)
const
{
{
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
False
=
integral_constant
<
bool
,
false
>
{};
constexpr
auto
False
=
integral_constant
<
bool
,
false
>
{};
...
@@ -378,45 +377,43 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
...
@@ -378,45 +377,43 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr
index_t
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
index_t
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
index_t
NRepeat
=
NPerThread
/
NPerThreadSubC
;
constexpr
index_t
NRepeat
=
NPerThread
/
NPerThreadSubC
;
//auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
//auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;
Float4
*
reg_a
=
(
Float4
*
)(
p_a_thread
);
Float4
*
reg_b
=
(
Float4
*
)(
p_b_thread
);
Float4
*
reg_c
=
(
Float4
*
)(
p_c_thread
);
void
*
a_loc
=
(
void
*
)(
p_a_block
+
mMyThreadOffsetA
);
void
*
b_loc
=
(
void
*
)(
p_b_block
+
mMyThreadOffsetB
);
#pragma unroll
#pragma unroll
// loop over k
// loop over k
for
(
index_t
k_begin
=
0
;
k_begin
<
K
;
k_begin
+=
KPerThreadLoop
)
for
(
index_t
k_begin
=
0
;
k_begin
<
K
;
k_begin
+=
KPerThreadLoop
)
{
{
#if 0
auto
a_src_index
=
a_block_mtx
.
Get1dIndex
(
k_begin
,
0
)
+
mMyThreadOffsetA
;
auto
b_src_index
=
b_block_mtx
.
Get1dIndex
(
k_begin
,
0
)
+
mMyThreadOffsetB
;
Float4
*
reg_a
=
(
Float4
*
)(
p_a_thread
);
Float4
*
reg_b
=
(
Float4
*
)(
p_b_thread
);
void
*
a_loc
=
(
void
*
)(
p_a_block
+
a_src_index
);
void
*
b_loc
=
(
void
*
)(
p_b_block
+
b_src_index
);
//asm volatile("\n \
//ds_read_b128 %0, %2 \n \
//ds_read_b128 %1, %2 offset:256\n \
//"
//: "=v"(reg_a[0]), "=v"(reg_a[1])
//: "v"(__to_local(a_loc))
//);
ds_read_b128(reg_a[0], a_loc, 0);
ds_read_b128(reg_a[0], a_loc, 0);
ds_read_b128(reg_a[1], a_loc, 256);
ds_read_b128(reg_a[1], a_loc, 256);
ds_read_b128(reg_b[0], b_loc, 0);
ds_read_b128(reg_b[0], b_loc, 0);
ds_read_b128(reg_b[1], b_loc, 128);
ds_read_b128(reg_b[1], b_loc, 128);
lgkmcnt(0);
lgkmcnt(0);
threadwise_gemm
(
a_thread_mtx
,
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
True
,
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
p_a_thread
,
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
b_thread_mtx
,
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
False
,
#else
p_b_thread
,
ds_read_b128
(
reg_a
[
0
],
a_loc
,
k_begin
*
512
);
c_thread_mtx
,
ds_read_b128
(
reg_b
[
0
],
b_loc
,
k_begin
*
256
);
False
,
ds_read_b128
(
reg_b
[
1
],
b_loc
,
128
+
k_begin
*
256
);
p_c_thread
,
ds_read_b128
(
reg_a
[
1
],
a_loc
,
256
+
k_begin
*
512
);
f_accum
);
lgkmcnt
(
2
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
lgkmcnt
(
1
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
1
],
reg_c
[
1
],
reg_c
[
3
],
reg_c
[
5
],
reg_c
[
7
]);
lgkmcnt
(
0
);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
#endif
}
}
}
}
...
...
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp
View file @
0d6aa311
...
@@ -323,7 +323,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
...
@@ -323,7 +323,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
(
p_wei_block
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
(
p_wei_block
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_in_block
+
y
*
Wi
+
x
,
p_in_block
+
y
*
Wi
+
x
,
p_out_thread
,
p_out_thread
,
f_accum
,
Number
<
in_block_element_space
>
()
);
f_accum
);
}
}
}
}
}
}
...
...
src/include/threadwise_gemm.hip.hpp
View file @
0d6aa311
...
@@ -12,7 +12,7 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
...
@@ -12,7 +12,7 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
constexpr
auto
src_mtx
=
SrcMatrix
{};
constexpr
auto
src_mtx
=
SrcMatrix
{};
constexpr
auto
dst_mtx
=
DstMatrix
{};
constexpr
auto
dst_mtx
=
DstMatrix
{};
#if
0
#if
1
for
(
index_t
i
=
0
;
i
<
NRow
;
++
i
)
for
(
index_t
i
=
0
;
i
<
NRow
;
++
i
)
{
{
for
(
index_t
j
=
0
;
j
<
NCol
;
++
j
)
for
(
index_t
j
=
0
;
j
<
NCol
;
++
j
)
...
@@ -72,6 +72,7 @@ __device__ void threadwise_gemm(MatrixA,
...
@@ -72,6 +72,7 @@ __device__ void threadwise_gemm(MatrixA,
for
(
index_t
k
=
0
;
k
<
K
;
++
k
)
for
(
index_t
k
=
0
;
k
<
K
;
++
k
)
{
{
#if 1
for
(
index_t
i
=
0
;
i
<
M
;
i
+=
4
)
for
(
index_t
i
=
0
;
i
<
M
;
i
+=
4
)
{
{
const
index_t
aindex
=
a_mtx
.
Get1dIndex
(
k
,
i
);
// A is transposed
const
index_t
aindex
=
a_mtx
.
Get1dIndex
(
k
,
i
);
// A is transposed
...
@@ -88,6 +89,13 @@ __device__ void threadwise_gemm(MatrixA,
...
@@ -88,6 +89,13 @@ __device__ void threadwise_gemm(MatrixA,
outerProduct4x4
(
a_vec
[
0
],
b_vec
[
0
],
c_vec
[
0
],
c_vec
[
2
],
c_vec
[
4
],
c_vec
[
6
]);
outerProduct4x4
(
a_vec
[
0
],
b_vec
[
0
],
c_vec
[
0
],
c_vec
[
2
],
c_vec
[
4
],
c_vec
[
6
]);
}
}
}
}
#else
const
Float4
*
a_vec
=
(
const
Float4
*
)
p_a_thread
;
const
Float4
*
b_vec
=
(
const
Float4
*
)
p_b_thread
;
Float4
*
c_vec
=
(
Float4
*
)
p_c_thread
;
outerProduct8x8
(
a_vec
,
b_vec
,
c_vec
);
#endif
}
}
}
}
else
else
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment