Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
263c5e41
Commit
263c5e41
authored
Jun 01, 2021
by
Chao Liu
Browse files
refactor
parent
f63f1636
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
36 additions
and
18 deletions
+36
-18
composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp
...osable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp
+7
-9
composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp
...sor_operation/blockwise_dynamic_tensor_slice_transfer.hpp
+2
-2
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
...l/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
+21
-0
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
...nvolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
+6
-7
No files found.
composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp
View file @
263c5e41
...
@@ -66,8 +66,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
...
@@ -66,8 +66,6 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
// GEMM
// GEMM
using
GridwiseGemm
=
using
GridwiseGemm
=
...
@@ -134,11 +132,11 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
...
@@ -134,11 +132,11 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
using
CBlockClusterAdaptor
=
decltype
(
c_block_cluster_adaptor
);
using
CBlockClusterAdaptor
=
decltype
(
c_block_cluster_adaptor
);
const
auto
G
rid
S
ize
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
const
index_t
g
rid
_s
ize
=
GridwiseGemm
::
CalculateGridSize
(
M
,
N
);
const
bool
has_main_k_block_loop
=
(
K
+
KPerBlock
)
/
(
2
*
KPerBlock
)
>
1
;
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
)
;
const
bool
has_double_tail_k_block_loop
=
(
K
/
KPerBlock
)
%
2
==
0
;
const
bool
has_double_tail_k_block_loop
=
GridwiseGemm
::
CalculateHasDoubleTailKBlockLoop
(
K
)
;
float
ave_time
=
0
;
float
ave_time
=
0
;
...
@@ -156,7 +154,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
...
@@ -156,7 +154,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
nrepeat
,
dim3
(
G
rid
S
ize
),
dim3
(
g
rid
_s
ize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
0
,
...
@@ -182,7 +180,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
...
@@ -182,7 +180,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
nrepeat
,
dim3
(
G
rid
S
ize
),
dim3
(
g
rid
_s
ize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
0
,
...
@@ -208,7 +206,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
...
@@ -208,7 +206,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
nrepeat
,
dim3
(
G
rid
S
ize
),
dim3
(
g
rid
_s
ize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
0
,
...
@@ -234,7 +232,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
...
@@ -234,7 +232,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
nrepeat
,
dim3
(
G
rid
S
ize
),
dim3
(
g
rid
_s
ize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
0
,
...
...
composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp
View file @
263c5e41
...
@@ -31,8 +31,8 @@ template <index_t BlockSize,
...
@@ -31,8 +31,8 @@ template <index_t BlockSize,
index_t
DstScalarPerVector
,
index_t
DstScalarPerVector
,
index_t
SrcScalarStrideInVector
,
index_t
SrcScalarStrideInVector
,
index_t
DstScalarStrideInVector
,
index_t
DstScalarStrideInVector
,
index_t
ThreadTransferSrcResetCoordinateAfterRun
,
bool
ThreadTransferSrcResetCoordinateAfterRun
,
index_t
ThreadTransferDstResetCoordinateAfterRun
>
bool
ThreadTransferDstResetCoordinateAfterRun
>
struct
BlockwiseDynamicTensorSliceTransfer_v4
struct
BlockwiseDynamicTensorSliceTransfer_v4
{
{
static
constexpr
index_t
nDim
=
remove_reference_t
<
SrcDesc
>::
GetNumOfDimension
();
static
constexpr
index_t
nDim
=
remove_reference_t
<
SrcDesc
>::
GetNumOfDimension
();
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
View file @
263c5e41
...
@@ -126,6 +126,27 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
...
@@ -126,6 +126,27 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
return
2
*
(
a_block_space_size
+
b_block_space_size
)
*
sizeof
(
FloatAB
);
return
2
*
(
a_block_space_size
+
b_block_space_size
)
*
sizeof
(
FloatAB
);
}
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
{
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
bool
has_main_k_block_loop
=
(
K
+
KPerBlock
)
/
(
2
*
KPerBlock
)
>
1
;
return
has_main_k_block_loop
;
}
__host__
__device__
static
constexpr
bool
CalculateHasDoubleTailKBlockLoop
(
index_t
K
)
{
const
bool
has_double_tail_k_block_loop
=
(
K
/
KPerBlock
)
%
2
==
0
;
return
has_double_tail_k_block_loop
;
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeAKM0M1GridDescriptor
(
const
AKMGridDesc
&
a_k_m_grid_desc
)
MakeAKM0M1GridDescriptor
(
const
AKMGridDesc
&
a_k_m_grid_desc
)
{
{
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
View file @
263c5e41
...
@@ -482,23 +482,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
...
@@ -482,23 +482,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
const
auto
in_gemmk_gemmn_grid_desc
=
descs
[
I1
];
const
auto
in_gemmk_gemmn_grid_desc
=
descs
[
I1
];
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
// hack t
o
control index calculation when iterating over
wei_gemmk_gemmm_grid tensor
//
HACK:
hack
s
t
hat
control index calculation when iterating over
A, B, C matrix
constexpr
auto
wei_gemmk_gemmm_grid_iterator_hacks
=
constexpr
auto
wei_gemmk_gemmm_grid_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
constexpr
auto
wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks
=
Sequence
<
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over in_gemmk_gemmn_grid tensor
constexpr
auto
in_gemmk_gemmn_grid_iterator_hacks
=
constexpr
auto
in_gemmk_gemmn_grid_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
>
{}));
constexpr
auto
in_gemmk_gemmn_grid_move_slice_window_iterator_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
>
{};
constexpr
auto
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks
=
constexpr
auto
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
...
@@ -513,6 +507,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
...
@@ -513,6 +507,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
constexpr
auto
wei_gemmk_gemmm_grid_move_slice_window_iterator_hacks
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
in_gemmk_gemmn_grid_move_slice_window_iterator_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
{
float
ave_time
=
driver_dynamic_gemm_v1r2
<
float
ave_time
=
driver_dynamic_gemm_v1r2
<
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment