gaoqiong / composable_kernel · Commits

Commit f63f1636
authored Jun 01, 2021 by Chao Liu

refactor

parent 51cdcee6
Showing 3 changed files with 87 additions and 65 deletions (+87 -65)
composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp                              +34 -51
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp                  +51 -12
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp  +2  -2
composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp

@@ -69,44 +69,8 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
-    constexpr auto I2 = Number<2>{};
-    constexpr auto I3 = Number<3>{};
-
-    const auto M = a_k_m_grid_desc.GetLength(I1);
-    const auto N = b_k_n_grid_desc.GetLength(I1);
-    const auto K = a_k_m_grid_desc.GetLength(I0);
-
-    if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
-    {
-        throw std::runtime_error("wrong! GEMM size no divisible");
-    }
-
-    constexpr auto M1 = Number<MPerBlock>{};
-    constexpr auto N1 = Number<NPerBlock>{};
-
-    const auto M0 = M / M1;
-    const auto N0 = N / N1;
-
-    constexpr auto M11 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
-    constexpr auto N11 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};
-
-    constexpr auto M10 = M1 / M11;
-    constexpr auto N10 = N1 / N11;
-
-    const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
-        c_m_n_grid_desc,
-        make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
-                   make_unmerge_transform(make_tuple(N0, N10, N11))),
-        make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
-
-    using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);
-
-    // out_gemm_block_cluster_desc
-    const auto c_block_cluster_desc = make_cluster_descriptor_v2(
-        make_tuple(M / Number<MPerBlock>{}, N / Number<NPerBlock>{}));
-
-    using CBlockClusterDesc = decltype(c_block_cluster_desc);
-
     // GEMM
-    using gridwise_gemm =
+    using GridwiseGemm =
         GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2<BlockSize,
                                                 FloatAB,
                                                 FloatAcc,
@@ -114,8 +78,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                                 CGlobalMemoryDataOperation,
                                                 AKMGridDesc,
                                                 BKNGridDesc,
-                                                CM0M10M11N0N10N11GridDesc,
-                                                CBlockClusterDesc,
+                                                CMNGridDesc,
                                                 MPerBlock,
                                                 NPerBlock,
                                                 KPerBlock,
@@ -151,6 +114,26 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                                 AGridMoveSliceWindowIteratorHacks,
                                                 BGridMoveSliceWindowIteratorHacks>;

+    const auto M = a_k_m_grid_desc.GetLength(I1);
+    const auto N = b_k_n_grid_desc.GetLength(I1);
+    const auto K = a_k_m_grid_desc.GetLength(I0);
+
+    if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
+    {
+        throw std::runtime_error("wrong! GEMM size no divisible");
+    }
+
+    // c_m0_m10_m11_n0_n10_n11_grid_desc
+    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
+        GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
+
+    using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);
+
+    // c_block_cluster_adaptor
+    const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
+
+    using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
+
     const auto GridSize = (M / MPerBlock) * (N / NPerBlock);

     const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;
@@ -161,13 +144,13 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
     if(has_main_k_block_loop && has_double_tail_k_block_loop)
     {
         const auto kernel =
-            kernel_dynamic_gemm_v1r2<gridwise_gemm,
+            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                      FloatAB,
                                      FloatC,
                                      remove_reference_t<AKMGridDesc>,
                                      remove_reference_t<BKNGridDesc>,
                                      remove_reference_t<CM0M10M11N0N10N11GridDesc>,
-                                     remove_reference_t<CBlockClusterDesc>,
+                                     remove_reference_t<CBlockClusterAdaptor>,
                                      true,
                                      true>;
@@ -183,17 +166,17 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                           a_k_m_grid_desc,
                           b_k_n_grid_desc,
                           c_m0_m10_m11_n0_n10_n11_grid_desc,
-                          c_block_cluster_desc);
+                          c_block_cluster_adaptor);
     }
     else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
     {
         const auto kernel =
-            kernel_dynamic_gemm_v1r2<gridwise_gemm,
+            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                      FloatAB,
                                      FloatC,
                                      remove_reference_t<AKMGridDesc>,
                                      remove_reference_t<BKNGridDesc>,
                                      remove_reference_t<CM0M10M11N0N10N11GridDesc>,
-                                     remove_reference_t<CBlockClusterDesc>,
+                                     remove_reference_t<CBlockClusterAdaptor>,
                                      true,
                                      false>;
@@ -209,17 +192,17 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                           a_k_m_grid_desc,
                           b_k_n_grid_desc,
                           c_m0_m10_m11_n0_n10_n11_grid_desc,
-                          c_block_cluster_desc);
+                          c_block_cluster_adaptor);
     }
     else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
     {
         const auto kernel =
-            kernel_dynamic_gemm_v1r2<gridwise_gemm,
+            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                      FloatAB,
                                      FloatC,
                                      remove_reference_t<AKMGridDesc>,
                                      remove_reference_t<BKNGridDesc>,
                                      remove_reference_t<CM0M10M11N0N10N11GridDesc>,
-                                     remove_reference_t<CBlockClusterDesc>,
+                                     remove_reference_t<CBlockClusterAdaptor>,
                                      false,
                                      true>;
@@ -235,17 +218,17 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                           a_k_m_grid_desc,
                           b_k_n_grid_desc,
                           c_m0_m10_m11_n0_n10_n11_grid_desc,
-                          c_block_cluster_desc);
+                          c_block_cluster_adaptor);
     }
     else
     {
         const auto kernel =
-            kernel_dynamic_gemm_v1r2<gridwise_gemm,
+            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                      FloatAB,
                                      FloatC,
                                      remove_reference_t<AKMGridDesc>,
                                      remove_reference_t<BKNGridDesc>,
                                      remove_reference_t<CM0M10M11N0N10N11GridDesc>,
-                                     remove_reference_t<CBlockClusterDesc>,
+                                     remove_reference_t<CBlockClusterAdaptor>,
                                      false,
                                      false>;
@@ -261,7 +244,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                           a_k_m_grid_desc,
                           b_k_n_grid_desc,
                           c_m0_m10_m11_n0_n10_n11_grid_desc,
-                          c_block_cluster_desc);
+                          c_block_cluster_adaptor);
     }

     return ave_time;
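In short, the driver no longer builds the C grid descriptor and the block-cluster mapping itself; it asks the gridwise GEMM type for them via static factory functions. The sketch below mirrors that call pattern only, using simplified, hypothetical stand-in types (CMNDesc, GridwiseGemmSketch); it is not the composable_kernel API itself.

    #include <stdexcept>

    // Hypothetical stand-ins for the real tensor-descriptor types used in this commit.
    struct CMNDesc              { int M; int N; };
    struct CBlockClusterAdaptor { int M0; int N0; };

    template <int MPerBlock, int NPerBlock>
    struct GridwiseGemmSketch
    {
        // After the refactor, descriptor construction lives behind static factory
        // functions, so the driver does not need the block-decomposition details.
        static constexpr CBlockClusterAdaptor MakeCBlockClusterAdaptor(const CMNDesc& c)
        {
            return {c.M / MPerBlock, c.N / NPerBlock};
        }
    };

    int main()
    {
        const CMNDesc c{1024, 2048};

        // Same divisibility check the driver performs before launching the kernel.
        if(!(c.M % 128 == 0 && c.N % 256 == 0))
            throw std::runtime_error("GEMM size not divisible by block size");

        // Driver-side usage after the refactor: ask the gridwise GEMM for the adaptor.
        const auto adaptor = GridwiseGemmSketch<128, 256>::MakeCBlockClusterAdaptor(c);
        return (adaptor.M0 == 8 && adaptor.N0 == 8) ? 0 : 1;
    }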
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp

@@ -18,7 +18,7 @@ template <typename GridwiseGemm,
           typename AKMGridDesc,
           typename BKNGridDesc,
           typename CM0M10M11N0N10N11GridDesc,
-          typename CBlockClusterDesc,
+          typename CBlockClusterAdaptor,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
 __global__ void
@@ -31,7 +31,7 @@ __global__ void
             const AKMGridDesc a_k_m_grid_desc,
             const BKNGridDesc b_k_n_grid_desc,
             const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
-            const CBlockClusterDesc c_block_cluster_desc)
+            const CBlockClusterAdaptor c_block_cluster_desc)
 {
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -57,8 +57,7 @@ template <index_t BlockSize,
           InMemoryDataOperation CGlobalMemoryDataOperation,
           typename AKMGridDesc,
           typename BKNGridDesc,
-          typename CM0M10M11N0N10N11GridDesc,
-          typename CBlockClusterDesc,
+          typename CMNGridDesc,
           index_t MPerBlock,
           index_t NPerBlock,
           index_t KPerBlock,
@@ -163,15 +162,55 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         return b_k_n0_n1_block_clusterized_grid_desc;
     }

-#if 0
     __host__ __device__ static constexpr auto
-    MakeCM0M10M11N0N10N11GridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
+    MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
+    {
+        const auto M = c_m_n_grid_desc.GetLength(I0);
+        const auto N = c_m_n_grid_desc.GetLength(I1);
+
+        constexpr auto M1 = Number<MPerBlock>{};
+        constexpr auto N1 = Number<NPerBlock>{};
+
+        const auto M0 = M / M1;
+        const auto N0 = N / N1;
+
+        constexpr auto M11 = Number<M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThread>{};
+        constexpr auto N11 = Number<M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThread>{};
+
+        constexpr auto M10 = M1 / M11;
+        constexpr auto N10 = N1 / N11;
+
+        const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
+            c_m_n_grid_desc,
+            make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
+                       make_unmerge_transform(make_tuple(N0, N10, N11))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
+
+        return c_m0_m10_m11_n0_n10_n11_grid_desc;
+    }
+
+    __host__ __device__ static constexpr auto
+    MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
+    {
+        const auto M = c_m_n_grid_desc.GetLength(I0);
+        const auto N = c_m_n_grid_desc.GetLength(I1);
+
+        constexpr auto M1 = Number<MPerBlock>{};
+        constexpr auto N1 = Number<NPerBlock>{};
+
+        const auto M0 = M / M1;
+        const auto N0 = N / N1;
+
+        const auto c_block_cluster_adaptor = make_cluster_descriptor_v2(make_tuple(M0, N0));
+
+        return c_block_cluster_adaptor;
+    }
-#endif

-    using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{}));
-    using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{}));
+    using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{}));
+    using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{}));
+    using CM0M10M11N0N10N11GridDesc =
+        decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
+    using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
@@ -181,7 +220,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                                const AKMGridDesc& a_k_m_grid_desc,
                                const BKNGridDesc& b_k_n_grid_desc,
                                const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
-                               const CBlockClusterDesc& c_block_cluster_desc,
+                               const CBlockClusterAdaptor& c_block_cluster_desc,
                                integral_constant<bool, HasMainKBlockLoop>,
                                integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
@@ -506,8 +545,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                        1,
                        c_m10_n10_m11_n11_thread_tensor_lengths[I2],
                        c_m10_n10_m11_n11_thread_tensor_lengths[I3]>,
-            Sequence<3, 4, 5, 0, 1, 2>, // TODO: CThreadTransferSrcDstAccessOrder
-            5,                          // TODO: CThreadTransferSrcDstVectorDim
+            CThreadTransferSrcDstAccessOrder,
+            CThreadTransferSrcDstVectorDim,
             CThreadTransferDstScalarPerVector,
             CGlobalMemoryDataOperation,
             1,
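Both factory functions above start from the same block decomposition of the C matrix. A minimal numeric sketch of that arithmetic follows; the tile sizes are illustrative assumptions, not values taken from this commit.

    // Assumed sizes, chosen only to make the decomposition concrete.
    constexpr int M                     = 1024;
    constexpr int MPerBlock             = 128; // M1
    constexpr int M1PerThread           = 4;
    constexpr int M1N1ThreadClusterM100 = 2;
    constexpr int M1N1ThreadClusterM101 = 8;

    constexpr int M1  = MPerBlock;
    constexpr int M11 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThread; // 64
    constexpr int M10 = M1 / M11;                                                    // 2
    constexpr int M0  = M / M1;                                                      // 8

    // The M dimension is split as M = M0 * M10 * M11, i.e. 1024 = 8 * 2 * 64,
    // which is what make_unmerge_transform(make_tuple(M0, M10, M11)) encodes;
    // the N dimension is split the same way.
    static_assert(M == M0 * M10 * M11, "unmerge must preserve the M length");

    int main() { return 0; }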
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp

@@ -551,8 +551,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
         GemmBBlockTransferDstScalarPerVector_GemmN,
         false, // don't move back src coordinate after threadwise copy, which will be fused with
                // MoveSrcSliceWindow() to save addr computation
-        Sequence<2, 3, 0, 1>,
-        3,
+        Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
+        5,                          // CThreadTransferSrcDstVectorDim
         GemmCThreadTransferDstScalarPerVector_GemmN1,
         decltype(wei_gemmk_gemmm_grid_iterator_hacks),
         decltype(in_gemmk_gemmn_grid_iterator_hacks),
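This +2/-2 change supplies the access order and vector dimension that gridwise_dynamic_gemm_v1r2.hpp now accepts as template parameters instead of hardcoding them. A minimal sketch of that kind of parameterization, using hypothetical simplified types rather than the real ck::Sequence and threadwise-transfer classes:

    #include <cstddef>

    // Hypothetical stand-in for a compile-time index sequence.
    template <std::size_t... Is>
    struct SequenceSketch
    {
        static constexpr std::size_t Size() { return sizeof...(Is); }
    };

    // The access order and vector dimension are now template parameters supplied
    // by the caller (as the convolution driver does above), not fixed constants.
    template <typename CThreadTransferSrcDstAccessOrder, std::size_t CThreadTransferSrcDstVectorDim>
    struct ThreadwiseTransferSketch
    {
        static constexpr std::size_t NDim      = CThreadTransferSrcDstAccessOrder::Size();
        static constexpr std::size_t VectorDim = CThreadTransferSrcDstVectorDim;
    };

    // Mirrors the arguments passed in the hunk above: Sequence<3, 4, 5, 0, 1, 2> and 5.
    using TransferConfig = ThreadwiseTransferSketch<SequenceSketch<3, 4, 5, 0, 1, 2>, 5>;

    static_assert(TransferConfig::NDim == 6, "six dimensions in the access order");
    static_assert(TransferConfig::VectorDim == 5, "vectorize along the last dimension");

    int main() { return 0; }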