Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
631d9892
Commit
631d9892
authored
Jan 29, 2021
by
Jing Zhang
Browse files
clean code
parent
961556eb
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
13 deletions
+12
-13
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_v2.hpp
...sor_operation/threadwise_generic_tensor_slice_copy_v2.hpp
+5
-6
driver/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
...tion_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+2
-2
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+5
-5
No files found.
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_v2.hpp
View file @
631d9892
...
@@ -122,10 +122,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
...
@@ -122,10 +122,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
__device__
static
auto
buffer_vector_load
(
const
SrcData
*
p_src
,
const
SrcCoord
src_coord_begin
)
__device__
static
auto
buffer_vector_load
(
const
SrcData
*
p_src
,
const
SrcCoord
src_coord_begin
)
{
{
auto
src_offset
=
src_coord_begin
.
GetOffset
();
auto
src_offset
=
src_coord_begin
.
GetOffset
();
auto
r
=
GetRegBuffer
<
SrcData
,
SrcDataPerAccess
>
();
return
amd_buffer_load
<
SrcData
,
SrcDataPerAccess
>
(
p_src
,
src_offset
,
true
,
SrcDataRange
);
r
.
GetVector
(
Number
<
SrcDataPerAccess
>
{})(
Number
<
0
>
{})
=
amd_buffer_load
<
SrcData
,
SrcDataPerAccess
>
(
p_src
,
src_offset
,
true
,
SrcDataRange
);
return
r
;
}
}
template
<
typename
DstData
,
index_t
DstDataPerAccess
>
template
<
typename
DstData
,
index_t
DstDataPerAccess
>
...
@@ -187,8 +184,10 @@ struct ThreadwiseGenericTensorSliceCopy_v5
...
@@ -187,8 +184,10 @@ struct ThreadwiseGenericTensorSliceCopy_v5
// load data from src to the long-vector buffer
// load data from src to the long-vector buffer
const
auto
src_coord
=
mSrcSliceOrigin
+
to_multi_index
(
long_vector_data_begin_id
);
const
auto
src_coord
=
mSrcSliceOrigin
+
to_multi_index
(
long_vector_data_begin_id
);
auto
src_buff
=
buffer_vector_load
<
SrcDataPerRead
,
SrcDesc
::
GetElementSpace
()
>
(
auto
src_buff
=
GetRegBuffer
<
SrcData
,
SrcDataPerRead
>
();
p_src
,
src_coord
);
src_buff
.
GetVector
(
Number
<
SrcDataPerRead
>
{})(
Number
<
0
>
{})
=
buffer_vector_load
<
SrcDataPerRead
,
SrcDesc
::
GetElementSpace
()
>
(
p_src
,
src_coord
);
static_for
<
0
,
SrcDataPerRead
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
SrcDataPerRead
,
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
vector_id
=
long_vector_data_begin_id
.
Modify
(
constexpr
auto
vector_id
=
long_vector_data_begin_id
.
Modify
(
...
...
driver/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
View file @
631d9892
...
@@ -115,7 +115,7 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
...
@@ -115,7 +115,7 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
// B matrix Copy
// B matrix Copy
constexpr
index_t
GemmBBlockCopyClusterLengths_GemmK
=
4
;
constexpr
index_t
GemmBBlockCopyClusterLengths_GemmK
=
4
;
constexpr
index_t
GemmBBlockCopyClusterLengths_GemmN
=
4
;
constexpr
index_t
GemmBBlockCopyClusterLengths_GemmN
=
4
;
constexpr
index_t
GemmBBlockCopyClusterLengths_GemmKPack
=
4
;
constexpr
index_t
GemmBBlockCopyClusterLengths_GemmKPack
=
1
;
constexpr
index_t
GemmBBlockCopyThreadSliceLengths_GemmK
=
constexpr
index_t
GemmBBlockCopyThreadSliceLengths_GemmK
=
GemmKPerBlock
/
GemmBBlockCopyClusterLengths_GemmK
;
GemmKPerBlock
/
GemmBBlockCopyClusterLengths_GemmK
;
...
@@ -141,7 +141,7 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
...
@@ -141,7 +141,7 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
using
GemmBBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
,
2
,
3
>
;
// [GemmG, GemmK, GemmN, GemmKPack]
using
GemmBBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
,
2
,
3
>
;
// [GemmG, GemmK, GemmN, GemmKPack]
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
4
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
4
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmKPack
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmKPack
=
4
;
// gridwise GEMM
// gridwise GEMM
constexpr
auto
wkgrp_schd_order
=
NBlock1MBlock0
;
constexpr
auto
wkgrp_schd_order
=
NBlock1MBlock0
;
...
...
driver/src/conv_driver.cpp
View file @
631d9892
...
@@ -24,11 +24,11 @@ int main(int argc, char* argv[])
...
@@ -24,11 +24,11 @@ int main(int argc, char* argv[])
using
namespace
ck
;
using
namespace
ck
;
// 1x1, 56x56
// 1x1, 56x56
constexpr
index_t
N
=
4
;
constexpr
index_t
N
=
6
4
;
constexpr
index_t
C
=
32
;
constexpr
index_t
C
=
128
;
constexpr
index_t
HI
=
2
;
constexpr
index_t
HI
=
56
;
constexpr
index_t
WI
=
2
;
constexpr
index_t
WI
=
56
;
constexpr
index_t
K
=
32
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
constexpr
index_t
X
=
1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment