Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
760a234f
Commit
760a234f
authored
Dec 03, 2020
by
Chao Liu
Browse files
use StaticallyIndexedArray for buffer in threadwise copy, in order to get rid of alloca in IR
parent
70d06fa9
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
508 additions
and
1395 deletions
+508
-1395
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+35
-90
composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp
...sor_operation/blockwise_dynamic_tensor_slice_transfer.hpp
+62
-546
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
...kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
+69
-448
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
...or_operation/threadwise_dynamic_tensor_slice_transfer.hpp
+332
-299
driver/include/conv_common.hpp
driver/include/conv_common.hpp
+9
-11
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+1
-1
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
760a234f
...
@@ -173,39 +173,41 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
...
@@ -173,39 +173,41 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
// GEMM
// GEMM
#if 1
using
gridwise_gemm
=
GridwiseDynamicGemm_km_kn_mn_v1
<
using
gridwise_gemm
=
BlockSize
,
GridwiseDynamicGemm_km_kn_mn_v1
<
BlockSize
,
Float
,
Float
,
AccFloat
,
AccFloat
,
InMemoryDataOperation
::
Set
,
InMemoryDataOperation
::
Set
,
GemmMPerBlock
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmMPerThread
,
GemmNPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmNLevel1Cluster
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
0
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferDstScalarPerVector_GemmM
,
GemmABlockTransferDstScalarPerVector_GemmM
,
true
,
// move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
1
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
Sequence
<
2
,
3
,
0
,
1
>
,
false
,
// don't move back src coordinate after threadwise copy, which will be fused with
3
,
// MoveSrcSliceWindow() to save addr computation
GemmCThreadTransferDstScalarPerVector_GemmN1
>
;
Sequence
<
2
,
3
,
0
,
1
>
,
3
,
GemmCThreadTransferDstScalarPerVector_GemmN1
>
;
const
index_t
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
const
index_t
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
...
@@ -261,63 +263,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
...
@@ -261,63 +263,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
p_out_global
,
p_out_global
,
integral_constant
<
bool
,
false
>
{});
integral_constant
<
bool
,
false
>
{});
}
}
#else
using
gridwise_gemm
=
GridwiseDynamicGemm_km_kn_mn_v2
<
BlockSize
,
Float
,
AccFloat
,
InMemoryDataOperation
::
Set
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferDstScalarPerVector_GemmM
,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
Sequence
<
2
,
3
,
0
,
1
>
,
3
,
GemmCThreadTransferDstScalarPerVector_GemmN1
>
;
const
index_t
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
);
#endif
}
}
};
};
...
...
composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp
View file @
760a234f
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
View file @
760a234f
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
View file @
760a234f
This diff is collapsed.
Click to expand it.
driver/include/conv_common.hpp
View file @
760a234f
...
@@ -51,26 +51,24 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(
...
@@ -51,26 +51,24 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(
}
}
template
<
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
constexpr
std
::
size_t
calculate_convolution_flops
(
InDesc
,
WeiDesc
,
OutDesc
)
constexpr
std
::
size_t
calculate_convolution_flops
(
const
InDesc
&
in_desc
,
const
WeiDesc
&
wei_desc
,
const
OutDesc
&
out_desc
)
{
{
using
namespace
ck
;
using
namespace
ck
;
constexpr
auto
wei_desc
=
WeiDesc
{};
constexpr
auto
out_desc
=
OutDesc
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
const
expr
index_t
N
=
out_desc
.
GetLength
(
I0
);
const
index_t
N
=
out_desc
.
GetLength
(
I0
);
const
expr
index_t
K
=
out_desc
.
GetLength
(
I1
);
const
index_t
K
=
out_desc
.
GetLength
(
I1
);
const
expr
index_t
Ho
=
out_desc
.
GetLength
(
I2
);
const
index_t
Ho
=
out_desc
.
GetLength
(
I2
);
const
expr
index_t
Wo
=
out_desc
.
GetLength
(
I3
);
const
index_t
Wo
=
out_desc
.
GetLength
(
I3
);
const
expr
index_t
C
=
wei_desc
.
GetLength
(
I1
);
const
index_t
C
=
wei_desc
.
GetLength
(
I1
);
const
expr
index_t
Y
=
wei_desc
.
GetLength
(
I2
);
const
index_t
Y
=
wei_desc
.
GetLength
(
I2
);
const
expr
index_t
X
=
wei_desc
.
GetLength
(
I3
);
const
index_t
X
=
wei_desc
.
GetLength
(
I3
);
return
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
;
return
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
;
}
}
...
...
driver/src/conv_driver.cpp
View file @
760a234f
...
@@ -577,7 +577,7 @@ int main(int argc, char* argv[])
...
@@ -577,7 +577,7 @@ int main(int argc, char* argv[])
LeftPads{},
LeftPads{},
RightPads{},
RightPads{},
nrepeat);
nrepeat);
#elif
0
#elif
1
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment