Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
269adde8
Commit
269adde8
authored
Mar 15, 2021
by
root
Browse files
fixed a bug
parent
a46a17fb
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
43 additions
and
34 deletions
+43
-34
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp
...ble_kernel/include/tensor_operation/blockwise_gemm_v3.hpp
+5
-9
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
...nel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
+13
-6
composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp
...le_kernel/include/tensor_operation/threadwise_gemm_v3.hpp
+17
-11
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
+6
-6
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+2
-2
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp
View file @
269adde8
...
@@ -137,15 +137,11 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
...
@@ -137,15 +137,11 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
constexpr
auto
a_thread_mtx
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
constexpr
auto
a_thread_mtx
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
CYXPerThreadLoop
>
{},
Number
<
KPerThread
>
{}));
make_tuple
(
Number
<
CYXPerThreadLoop
>
{},
Number
<
KPerThread
>
{}));
constexpr
auto
b_thread_mtx
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
constexpr
auto
b_thread_mtx
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
// make_tuple(Number<CYXPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{},
Number
<
CYXPerThreadLoop
>
{},
Number
<
1
>
{},
Number
<
HPerThread
>
{},
Number
<
WPerThread
>
{}));
// Number<WPerThread>{}));
make_tuple
(
Number
<
CYXPerThreadLoop
>
{},
Number
<
1
>
{}));
constexpr
auto
c_thread_mtx
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
1
>
{},
Number
<
HPerThread
>
{},
Number
<
WPerThread
>
{}));
constexpr
auto
c_thread_mtx
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
// make_tuple(Number<KPerThread>{}, Number<1>{},
// Number<HPerThread>{}, Number<WPerThread>{}));
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
1
>
{}));
FloatA
p_a_thread
[
a_thread_mtx
.
GetElementSpaceSize
()];
FloatA
p_a_thread
[
a_thread_mtx
.
GetElementSpaceSize
()];
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
View file @
269adde8
...
@@ -95,11 +95,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
...
@@ -95,11 +95,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// divide block work by [M, N]
// divide block work by [M, N]
#if 1
#if 1
const
auto
m_block_work_num
=
K
/
Number
<
KPerBlock
>
{};
const
auto
m_block_work_num
=
K
/
Number
<
KPerBlock
>
{};
const
auto
hw_block_work_num
=
(
N
*
H
*
W
)
/
(
Number
<
HPerBlock
>
{}
*
Number
<
WPerBlock
>
{});
const
auto
h_block_work_num
=
H
/
Number
<
HPerBlock
>
{};
const
auto
w_block_work_num
=
W
/
Number
<
WPerBlock
>
{};
const
auto
hw_block_work_num
=
h_block_work_num
*
w_block_work_num
;
const
index_t
k_block_work_id
=
get_block_1d_id
()
/
hw_block_work_num
;
const
index_t
k_block_work_id
=
get_block_1d_id
()
/
hw_block_work_num
;
const
index_t
hw_block_work_id
=
get_block_1d_id
()
-
k_block_work_id
*
hw_block_work_num
;
const
index_t
hw_block_work_id
=
get_block_1d_id
()
-
k_block_work_id
*
hw_block_work_num
;
const
index_t
h_block_work_id
=
hw_block_work_id
/
w_block_work_num
;
const
index_t
w_block_work_id
=
hw_block_work_id
-
h_block_work_id
*
w_block_work_num
;
constexpr
auto
h_num_threads
=
HPerBlock
/
HPerThread
;
constexpr
auto
h_num_threads
=
HPerBlock
/
HPerThread
;
constexpr
auto
w_num_threads
=
WPerBlock
/
WPerThread
;
constexpr
auto
w_num_threads
=
WPerBlock
/
WPerThread
;
...
@@ -119,8 +124,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
...
@@ -119,8 +124,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
const
index_t
m_block_data_on_global
=
k_block_work_id
*
KPerBlock
;
const
index_t
m_block_data_on_global
=
k_block_work_id
*
KPerBlock
;
const
index_t
h_block_data_on_global
=
h
w
_block_work_id
*
HPerBlock
;
const
index_t
h_block_data_on_global
=
h_block_work_id
*
HPerBlock
;
const
index_t
w_block_data_on_global
=
h
w_block_work_id
*
WPerBlock
;
const
index_t
w_block_data_on_global
=
w_block_work_id
*
WPerBlock
;
// lds max alignment
// lds max alignment
constexpr
auto
max_lds_align
=
constexpr
auto
max_lds_align
=
...
@@ -187,8 +192,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
...
@@ -187,8 +192,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
ThreadwiseTensorSliceTransferB
b_threadwise_transfer
(
ThreadwiseTensorSliceTransferB
b_threadwise_transfer
(
b_cyx_n_h_w_global_desc
,
b_cyx_n_h_w_global_desc
,
make_multi_index
(
make_multi_index
(
0
,
0
,
0
,
h_block_data_on_global
+
h_thread_id
,
w_block_data_on_global
+
w_thread_id
));
0
,
h_block_data_on_global
+
h_thread_id
*
HPerThread
,
w_block_data_on_global
+
w_thread_id
*
WPerThread
));
// c_thread_mtx definition: this is a mess
// c_thread_mtx definition: this is a mess
// TODO:: more elegent way of defining c_thread_mtx
// TODO:: more elegent way of defining c_thread_mtx
...
@@ -426,7 +433,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
...
@@ -426,7 +433,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
Float
,
Float
,
decltype
(
c_k_n_h_w_thread_desc
),
decltype
(
c_k_n_h_w_thread_desc
),
decltype
(
c_k_n_h_w_global_desc
),
decltype
(
c_k_n_h_w_global_desc
),
Sequence
<
KPerThread
,
1
,
1
,
1
>
,
Sequence
<
KPerThread
,
1
,
HPerThread
,
WPerThread
>
,
Sequence
<
3
,
2
,
0
,
1
>
,
// CThreadTransferSrcDstAccessOrder
Sequence
<
3
,
2
,
0
,
1
>
,
// CThreadTransferSrcDstAccessOrder
3
,
// CThreadTransferSrcDstVectorDim
3
,
// CThreadTransferSrcDstVectorDim
1
,
// CThreadTransferDstScalarPerVector,
1
,
// CThreadTransferDstScalarPerVector,
...
...
composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp
View file @
269adde8
...
@@ -47,20 +47,26 @@ struct ThreadwiseGemm_km_kn_mn_v3
...
@@ -47,20 +47,26 @@ struct ThreadwiseGemm_km_kn_mn_v3
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
M
=
CDesc
{}.
GetLength
(
I0
);
constexpr
auto
H
=
BDesc
{}.
GetLength
(
I2
);
constexpr
auto
N
=
CDesc
{}.
GetLength
(
I1
);
constexpr
auto
W
=
BDesc
{}.
GetLength
(
I3
);
constexpr
auto
K
=
ADesc
{}.
GetLength
(
I0
);
static_for
<
0
,
K
,
1
>
{}([
&
](
auto
k
)
{
constexpr
auto
CYX
=
ADesc
{}.
GetLength
(
I0
);
static_for
<
0
,
M
,
1
>
{}([
&
](
auto
m
)
{
constexpr
auto
K
=
ADesc
{}.
GetLength
(
I1
);
static_for
<
0
,
N
,
1
>
{}([
&
](
auto
n
)
{
constexpr
auto
a_offset
=
ADesc
{}.
CalculateOffset
(
make_tuple
(
k
,
m
));
static_for
<
0
,
CYX
,
1
>
{}([
&
](
auto
e
)
{
constexpr
auto
b_offset
=
BDesc
{}.
CalculateOffset
(
make_tuple
(
k
,
n
));
static_for
<
0
,
K
,
1
>
{}([
&
](
auto
k
)
{
constexpr
auto
c_offset
=
CDesc
{}.
CalculateOffset
(
make_tuple
(
m
,
n
));
static_for
<
0
,
H
,
1
>
{}([
&
](
auto
h
)
{
static_for
<
0
,
W
,
1
>
{}([
&
](
auto
w
)
{
constexpr
auto
a_offset
=
ADesc
{}.
CalculateOffset
(
make_tuple
(
e
,
k
));
constexpr
auto
b_offset
=
BDesc
{}.
CalculateOffset
(
make_tuple
(
e
,
0
,
h
,
w
));
constexpr
auto
c_offset
=
CDesc
{}.
CalculateOffset
(
make_tuple
(
k
,
0
,
h
,
w
));
p_c
[
c_offset
]
+=
p_c
[
c_offset
]
+=
inner_product_with_conversion
<
FloatC
>
{}(
p_a
[
a_offset
],
p_b
[
b_offset
]);
inner_product_with_conversion
<
FloatC
>
{}(
p_a
[
a_offset
],
p_b
[
b_offset
]);
});
});
});
});
});
});
});
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
View file @
269adde8
...
@@ -72,21 +72,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
...
@@ -72,21 +72,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
constexpr
index_t
KPerBlock
=
16
;
constexpr
index_t
KPerBlock
=
16
;
constexpr
index_t
HPerBlock
=
8
;
constexpr
index_t
HPerBlock
=
8
;
constexpr
index_t
WPerBlock
=
8
;
constexpr
index_t
WPerBlock
=
16
;
constexpr
index_t
CYXPerBlock
=
4
*
3
*
3
;
constexpr
index_t
CYXPerBlock
=
4
;
constexpr
index_t
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
16
;
constexpr
index_t
HPerThread
=
1
;
constexpr
index_t
HPerThread
=
1
;
constexpr
index_t
WPerThread
=
1
;
constexpr
index_t
WPerThread
=
2
;
constexpr
index_t
CYXPerThread
=
4
*
3
*
3
;
constexpr
index_t
CYXPerThread
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
9
,
1
>
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
16
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
16
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
36
,
1
>
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
4
,
1
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
1
,
64
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
1
,
64
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
1
;
...
...
driver/src/conv_driver.cpp
View file @
269adde8
...
@@ -82,8 +82,8 @@ int main(int argc, char* argv[])
...
@@ -82,8 +82,8 @@ int main(int argc, char* argv[])
#elif 1
#elif 1
constexpr
index_t
N
=
1
;
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
4
;
constexpr
index_t
C
=
4
;
constexpr
index_t
HI
=
8
;
constexpr
index_t
HI
=
1080
;
constexpr
index_t
WI
=
8
;
constexpr
index_t
WI
=
1920
;
constexpr
index_t
K
=
16
;
constexpr
index_t
K
=
16
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
constexpr
index_t
X
=
3
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment