Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
46a0aec1
"examples/vscode:/vscode.git/clone" did not exist on "f06e4e5579b2c43a35a1678dea9b9ec3d66d365a"
Commit
46a0aec1
authored
Apr 26, 2019
by
Chao Liu
Browse files
trying gemm asm
parent
2603bb0f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
106 additions
and
4 deletions
+106
-4
driver/driver.hip.cpp
driver/driver.hip.cpp
+4
-1
src/include/blockwise_batched_gemm.hip.hpp
src/include/blockwise_batched_gemm.hip.hpp
+96
-3
src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp
...ise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp
+6
-0
No files found.
driver/driver.hip.cpp
View file @
46a0aec1
...
@@ -580,13 +580,16 @@ int main(int argc, char* argv[])
...
@@ -580,13 +580,16 @@ int main(int argc, char* argv[])
#if 0
#if 0
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif
0
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
wei_kcyx
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
#elif 0
#elif 0
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_3
{},
num_thread
);
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_3
{},
num_thread
);
wei_kcyx
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
wei_kcyx
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
#elif 1
#elif 1
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
wei_kcyx
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
wei_kcyx
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
#elif
1
#elif
0
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_2
{
1
,
5
},
num_thread
);
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_2
{
1
,
5
},
num_thread
);
auto
gen_wei
=
[](
auto
...
is
)
{
auto
gen_wei
=
[](
auto
...
is
)
{
...
...
src/include/blockwise_batched_gemm.hip.hpp
View file @
46a0aec1
...
@@ -289,9 +289,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
...
@@ -289,9 +289,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
const
FloatB
*
__restrict__
p_b_block
,
const
FloatB
*
__restrict__
p_b_block
,
FloatC
*
__restrict__
p_c_thread
)
const
FloatC
*
__restrict__
p_c_thread
)
const
{
{
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
False
=
integral_constant
<
bool
,
false
>
{};
constexpr
auto
a_block_mtx
=
BlockMatrixA
{};
constexpr
auto
a_block_mtx
=
BlockMatrixA
{};
constexpr
auto
b_block_mtx
=
BlockMatrixB
{};
constexpr
auto
b_block_mtx
=
BlockMatrixB
{};
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
...
@@ -371,6 +368,102 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
...
@@ -371,6 +368,102 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
}
}
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
Run_asm_v2
(
const
FloatA
*
__restrict__
p_a_block
,
const
FloatB
*
__restrict__
p_b_block
,
FloatC
*
__restrict__
p_c_thread
)
const
{
constexpr
auto
a_block_mtx
=
BlockMatrixA
{};
constexpr
auto
b_block_mtx
=
BlockMatrixB
{};
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
constexpr
index_t
M
=
a_block_mtx
.
NCol
();
constexpr
index_t
N
=
b_block_mtx
.
NCol
();
constexpr
index_t
K
=
a_block_mtx
.
NRow
();
// A is transposed
constexpr
index_t
MPerThread
=
c_thread_mtx
.
NRow
();
constexpr
index_t
NPerThread
=
c_thread_mtx
.
NCol
();
// thread A, B for GEMM
// A is transposed, b is not
constexpr
auto
a_thread_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThread
>
{});
constexpr
auto
b_thread_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThread
>
{});
// thread A-sub, B-sub for copy
constexpr
auto
a_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThreadSubC
>
{},
Number
<
MPerThread
>
{});
constexpr
auto
b_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThreadSubC
>
{},
Number
<
NPerThread
>
{});
FloatA
p_a_thread
[
a_thread_mtx
.
GetElementSpace
()];
FloatB
p_b_thread
[
b_thread_mtx
.
GetElementSpace
()];
constexpr
index_t
MPerLevel1Cluster
=
MPerThreadSubC
*
MLevel0Cluster
*
MLevel1Cluster
;
constexpr
index_t
NPerLevel1Cluster
=
NPerThreadSubC
*
NLevel0Cluster
*
NLevel1Cluster
;
// assertion for inline asm
static_assert
(
is_same
<
FloatA
,
float
>::
value
&&
is_same
<
FloatB
,
float
>::
value
&&
is_same
<
FloatC
,
float
>::
value
,
"Run_asm only deal with float
\n
"
);
static_assert
(
MPerThreadSubC
==
4
&&
NPerThreadSubC
==
4
&&
KPerThreadLoop
==
1
&&
MPerThread
==
8
&&
NPerThread
==
8
,
"Run_asm cannot deal with this GEMM shape yet
\n
"
);
static_assert
(
DataPerReadA
==
4
&&
DataPerReadB
==
4
,
"Run_asm only do float4 read
\n
"
);
static_assert
(
BlockMatrixStrideA
==
0
&&
BatchPerThread
==
1
,
"Run_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == 1 for now
\n
"
);
using
Float4
=
vector_type
<
float
,
4
>::
MemoryType
;
Float4
*
reg_a
=
(
Float4
*
)(
p_a_thread
);
Float4
*
reg_b
=
(
Float4
*
)(
p_b_thread
);
Float4
*
reg_c
=
(
Float4
*
)(
p_c_thread
);
void
*
a_lds_loc
=
(
void
*
)(
p_a_block
+
mMyThreadOffsetA
);
void
*
b_lds_loc
=
(
void
*
)(
p_b_block
+
mMyThreadOffsetB
);
constexpr
index_t
a_lds_row_stride
=
sizeof
(
Float
)
*
M
;
constexpr
index_t
b_lds_row_stride
=
sizeof
(
Float
)
*
N
;
constexpr
index_t
a_lds_cluster_col_stride
=
sizeof
(
Float
)
*
MPerLevel1Cluster
;
constexpr
index_t
b_lds_cluster_col_stride
=
sizeof
(
Float
)
*
NPerLevel1Cluster
;
ds_read_b128
(
reg_a
[
0
],
a_lds_loc
,
0
);
ds_read_b128
(
reg_b
[
0
],
b_lds_loc
,
0
);
ds_read_b128
(
reg_b
[
1
],
b_lds_loc
,
b_lds_cluster_col_stride
);
ds_read_b128
(
reg_a
[
1
],
a_lds_loc
,
a_lds_cluster_col_stride
);
lgkmcnt
(
2
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
lgkmcnt
(
1
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
1
],
reg_c
[
1
],
reg_c
[
3
],
reg_c
[
5
],
reg_c
[
7
]);
#pragma unroll
for
(
index_t
k
=
1
;
k
<
K
;
++
k
)
{
ds_read_b128
(
reg_a
[
0
],
a_lds_loc
,
k
*
a_lds_row_stride
);
lgkmcnt
(
1
);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
ds_read_b128
(
reg_b
[
0
],
b_lds_loc
,
k
*
b_lds_row_stride
);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
ds_read_b128
(
reg_b
[
1
],
b_lds_loc
,
b_lds_cluster_col_stride
+
k
*
b_lds_row_stride
);
ds_read_b128
(
reg_a
[
1
],
a_lds_loc
,
a_lds_cluster_col_stride
+
k
*
a_lds_row_stride
);
lgkmcnt
(
2
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
lgkmcnt
(
1
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
1
],
reg_c
[
1
],
reg_c
[
3
],
reg_c
[
5
],
reg_c
[
7
]);
}
lgkmcnt
(
0
);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
}
#endif
#endif
template
<
class
BlockMatrixC
,
index_t
BlockMatrixStrideC
,
class
FloatC
>
template
<
class
BlockMatrixC
,
index_t
BlockMatrixStrideC
,
class
FloatC
>
...
...
src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp
View file @
46a0aec1
...
@@ -273,7 +273,13 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
...
@@ -273,7 +273,13 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
__syncthreads
();
__syncthreads
();
#if 1
blockwise_batch_gemm
.
Run
(
p_wei_block
,
p_in_block
,
p_out_thread
);
blockwise_batch_gemm
.
Run
(
p_wei_block
,
p_in_block
,
p_out_thread
);
#elif 0
blockwise_batch_gemm
.
Run_asm
(
p_wei_block
,
p_in_block
,
p_out_thread
);
#elif 0
blockwise_batch_gemm
.
Run_asm_v2
(
p_wei_block
,
p_in_block
,
p_out_thread
);
#endif
__syncthreads
();
__syncthreads
();
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment