Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
49b926b6
Commit
49b926b6
authored
Jun 02, 2021
by
Chao Liu
Browse files
overhauling fwd-v4r4
parent
b8382727
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
211 additions
and
574 deletions
+211
-574
composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp
...osable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp
+49
-66
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
...l/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
+108
-104
composable_kernel/include/utility/math.hpp
composable_kernel/include/utility/math.hpp
+1
-1
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
...nvolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
+53
-403
No files found.
composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp
View file @
49b926b6
...
@@ -26,21 +26,21 @@ template <index_t BlockSize,
...
@@ -26,21 +26,21 @@ template <index_t BlockSize,
index_t
M1N1ThreadClusterN10
,
index_t
M1N1ThreadClusterN10
,
index_t
M1N1ThreadClusterM11
,
index_t
M1N1ThreadClusterM11
,
index_t
M1N1ThreadClusterN11
,
index_t
M1N1ThreadClusterN11
,
typename
ABlockTransferThreadSliceLengths_K_M
,
typename
ABlockTransferThreadSliceLengths_K_M
0_M1
,
typename
ABlockTransferThreadClusterLengths_K_M
,
typename
ABlockTransferThreadClusterLengths_K_M
0_M1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_M
,
index_t
ABlockTransferDstScalarPerVector_M
1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
typename
BBlockTransferThreadSliceLengths_K_N
,
typename
BBlockTransferThreadSliceLengths_K_N
0_N1
,
typename
BBlockTransferThreadClusterLengths_K_N
,
typename
BBlockTransferThreadClusterLengths_K_N
0_N1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_N
,
index_t
BBlockTransferDstScalarPerVector_N
1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
typename
CThreadTransferSrcDstAccessOrder
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferSrcDstVectorDim
,
...
@@ -69,48 +69,48 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
...
@@ -69,48 +69,48 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
// GEMM
// GEMM
using
GridwiseGemm
=
using
GridwiseGemm
=
GridwiseDynamicGemm_km_kn_m
0m1n0n1
_v1r2
<
BlockSize
,
GridwiseDynamicGemm_km_kn_m
n
_v1r2
<
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
FloatAcc
,
FloatC
,
FloatC
,
CGlobalMemoryDataOperation
,
CGlobalMemoryDataOperation
,
AKMGridDesc
,
AKMGridDesc
,
BKNGridDesc
,
BKNGridDesc
,
CMNGridDesc
,
CMNGridDesc
,
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
KPerBlock
,
KPerBlock
,
M1PerThread
,
M1PerThread
,
N1PerThread
,
N1PerThread
,
KPerThread
,
KPerThread
,
M1N1ThreadClusterM10
,
M1N1ThreadClusterM10
,
M1N1ThreadClusterN10
,
M1N1ThreadClusterN10
,
M1N1ThreadClusterM11
,
M1N1ThreadClusterM11
,
M1N1ThreadClusterN11
,
M1N1ThreadClusterN11
,
ABlockTransferThreadSliceLengths_K_M
,
ABlockTransferThreadSliceLengths_K_M
0_M1
,
ABlockTransferThreadClusterLengths_K_M
,
ABlockTransferThreadClusterLengths_K_M
0_M1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_M
,
ABlockTransferDstScalarPerVector_M
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
AThreadTransferSrcResetCoordinateAfterRun
,
BBlockTransferThreadSliceLengths_K_N
,
BBlockTransferThreadSliceLengths_K_N
0_N1
,
BBlockTransferThreadClusterLengths_K_N
,
BBlockTransferThreadClusterLengths_K_N
0_N1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_N
,
BBlockTransferDstScalarPerVector_N
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
BThreadTransferSrcResetCoordinateAfterRun
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
AGridIteratorHacks
,
AGridIteratorHacks
,
BGridIteratorHacks
,
BGridIteratorHacks
,
CGridIteratorHacks
,
CGridIteratorHacks
,
AGridMoveSliceWindowIteratorHacks
,
AGridMoveSliceWindowIteratorHacks
,
BGridMoveSliceWindowIteratorHacks
>
;
BGridMoveSliceWindowIteratorHacks
>
;
const
auto
M
=
a_k_m_grid_desc
.
GetLength
(
I1
);
const
auto
M
=
a_k_m_grid_desc
.
GetLength
(
I1
);
const
auto
N
=
b_k_n_grid_desc
.
GetLength
(
I1
);
const
auto
N
=
b_k_n_grid_desc
.
GetLength
(
I1
);
...
@@ -118,8 +118,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
...
@@ -118,8 +118,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
if
(
!
GridwiseGemm
::
CheckValidity
(
a_k_m_grid_desc
,
b_k_n_grid_desc
,
c_m_n_grid_desc
))
if
(
!
GridwiseGemm
::
CheckValidity
(
a_k_m_grid_desc
,
b_k_n_grid_desc
,
c_m_n_grid_desc
))
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"wrong! GridwiseDynamicGemm_km_kn_mn_v1r2 has invalid setting"
);
"wrong! GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2 has invalid setting"
);
}
}
const
auto
a_k_m0_m1_grid_desc
=
GridwiseGemm
::
MakeAKM0M1GridDescriptor
(
a_k_m_grid_desc
);
const
auto
a_k_m0_m1_grid_desc
=
GridwiseGemm
::
MakeAKM0M1GridDescriptor
(
a_k_m_grid_desc
);
...
@@ -154,8 +153,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
...
@@ -154,8 +153,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
kernel_dynamic_gemm_v1r2
<
GridwiseGemm
,
kernel_dynamic_gemm_v1r2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AKMGridDesc
>
,
remove_reference_t
<
BKNGridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
...
@@ -172,8 +169,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
...
@@ -172,8 +169,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
a_k_m_grid_desc
,
b_k_n_grid_desc
,
a_k_m0_m1_grid_desc
,
a_k_m0_m1_grid_desc
,
b_k_n0_n1_grid_desc
,
b_k_n0_n1_grid_desc
,
c_m0_m10_m11_n0_n10_n11_grid_desc
,
c_m0_m10_m11_n0_n10_n11_grid_desc
,
...
@@ -185,8 +180,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
...
@@ -185,8 +180,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
kernel_dynamic_gemm_v1r2
<
GridwiseGemm
,
kernel_dynamic_gemm_v1r2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AKMGridDesc
>
,
remove_reference_t
<
BKNGridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
...
@@ -203,8 +196,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
...
@@ -203,8 +196,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
a_k_m_grid_desc
,
b_k_n_grid_desc
,
a_k_m0_m1_grid_desc
,
a_k_m0_m1_grid_desc
,
b_k_n0_n1_grid_desc
,
b_k_n0_n1_grid_desc
,
c_m0_m10_m11_n0_n10_n11_grid_desc
,
c_m0_m10_m11_n0_n10_n11_grid_desc
,
...
@@ -216,8 +207,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
...
@@ -216,8 +207,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
kernel_dynamic_gemm_v1r2
<
GridwiseGemm
,
kernel_dynamic_gemm_v1r2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AKMGridDesc
>
,
remove_reference_t
<
BKNGridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
...
@@ -234,8 +223,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
...
@@ -234,8 +223,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
a_k_m_grid_desc
,
b_k_n_grid_desc
,
a_k_m0_m1_grid_desc
,
a_k_m0_m1_grid_desc
,
b_k_n0_n1_grid_desc
,
b_k_n0_n1_grid_desc
,
c_m0_m10_m11_n0_n10_n11_grid_desc
,
c_m0_m10_m11_n0_n10_n11_grid_desc
,
...
@@ -247,8 +234,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
...
@@ -247,8 +234,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
kernel_dynamic_gemm_v1r2
<
GridwiseGemm
,
kernel_dynamic_gemm_v1r2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AKMGridDesc
>
,
remove_reference_t
<
BKNGridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
...
@@ -265,8 +250,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
...
@@ -265,8 +250,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
a_k_m_grid_desc
,
b_k_n_grid_desc
,
a_k_m0_m1_grid_desc
,
a_k_m0_m1_grid_desc
,
b_k_n0_n1_grid_desc
,
b_k_n0_n1_grid_desc
,
c_m0_m10_m11_n0_n10_n11_grid_desc
,
c_m0_m10_m11_n0_n10_n11_grid_desc
,
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
View file @
49b926b6
This diff is collapsed.
Click to expand it.
composable_kernel/include/utility/math.hpp
View file @
49b926b6
...
@@ -74,7 +74,7 @@ __host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
...
@@ -74,7 +74,7 @@ __host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
template
<
class
X
,
class
Y
>
template
<
class
X
,
class
Y
>
__host__
__device__
constexpr
auto
integer_divide_ceil
(
X
x
,
Y
y
)
__host__
__device__
constexpr
auto
integer_divide_ceil
(
X
x
,
Y
y
)
{
{
return
(
x
+
y
-
1
)
/
y
;
return
(
x
+
y
-
Number
<
1
>
{}
)
/
y
;
}
}
template
<
class
X
,
class
Y
>
template
<
class
X
,
class
Y
>
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
View file @
49b926b6
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment