Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
17cd5c7d
Commit
17cd5c7d
authored
Oct 02, 2021
by
Jing Zhang
Browse files
split h and w
parent
0e77b53e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
834 additions
and
234 deletions
+834
-234
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp
...ernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp
+206
-175
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
...ution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
+2
-2
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
...ution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
+626
-57
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp
View file @
17cd5c7d
...
...
@@ -16,22 +16,22 @@ template <typename GridwiseGemm,
typename
FloatAB
,
typename
FloatC
,
typename
AGridDesc_E0_E1_K0_K1_E2
,
typename
BGridDesc_E0_E1_N_H
o_Wo
_E2
,
typename
CGridDesc_K_N_H
o_Wo
,
typename
CBlockIdToBlockClusterAdaptor_K_N_H
o
_W
o
,
typename
BGridDesc_E0_E1_N_H
0_H1_H2_W0_W1_W2
_E2
,
typename
CGridDesc_K_N_H
0_H1_H2_W0_W1_W2
,
typename
CBlockIdToBlockClusterAdaptor_K_N_H_W
,
bool
HasMainE0BlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_dlops_v2
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_
b
_grid
,
Float
C
*
__restrict__
p_
c
_grid
,
const
AGridDesc_E0_E1_K0_K1_E2
a_e0_e1_k0_k1_e2_grid_desc
,
const
B
GridDesc_E0_E1_
N_Ho_Wo
_E2
b_e
0_
e
1_
n_ho_wo_e
2_grid_desc
,
const
C
GridDesc_
K_N_Ho_Wo
c_k_n_ho_wo
_grid_desc
,
const
C
BlockIdToBlockClusterAdaptor_K_N_Ho_Wo
c_blockid_to_k_n_h
o
_w
o
_block_cluster_adaptor
)
kernel_gemm_dlops_v2
(
const
FloatAB
*
__restrict__
p_
a
_grid
,
const
Float
AB
*
__restrict__
p_
b
_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
A
GridDesc_E0_E1_
K0_K1
_E2
A_E
0_
E
1_
K0_K1_E
2_grid_desc
,
const
B
GridDesc_
E0_E1_N_H0_H1_H2_W0_W1_W2_E2
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2
_grid_desc
,
const
C
GridDesc_K_N_H0_H1_H2_W0_W1_W2
c_k_n_h0_h1_h2_w0_w1_w2_grid_desc
,
const
CBlockIdToBlockClusterAdaptor_K_N_H_W
c_blockid_to_k_n_h_w_block_cluster_adaptor
)
{
constexpr
index_t
shared_block_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
...
...
@@ -43,9 +43,9 @@ __global__ void
p_c_grid
,
p_shared_block
,
a_e0_e1_k0_k1_e2_grid_desc
,
b_e0_e1_n_h
o_wo
_e2_grid_desc
,
c_k_n_h
o_wo
_grid_desc
,
c_blockid_to_k_n_h
o
_w
o
_block_cluster_adaptor
,
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc
,
c_k_n_h
0_h1_h2_w0_w1_w2
_grid_desc
,
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
integral_constant
<
bool
,
HasMainE0BlockLoop
>
{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
...
...
@@ -56,9 +56,9 @@ template <typename GridwiseGemm,
typename
FloatAB
,
typename
FloatC
,
typename
AGridDesc_E0_E1_K0_K1_E2
,
typename
BGridDesc_E0_E1_N_H
o_Wo
_E2
,
typename
CGridDesc_K_N_H
o_Wo
,
typename
CBlockIdToBlockClusterAdaptor_K_N_H
o
_W
o
,
typename
BGridDesc_E0_E1_N_H
0_H1_H2_W0_W1_W2
_E2
,
typename
CGridDesc_K_N_H
0_H1_H2_W0_W1_W2
,
typename
CBlockIdToBlockClusterAdaptor_K_N_H_W
,
bool
HasMainE0BlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
...
...
@@ -68,22 +68,24 @@ __global__ void
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
void
CONSTANT
*
p_a_e0_e1_k0_k1_e2_grid_desc
,
const
void
CONSTANT
*
p_b_e0_e1_n_h
o_wo
_e2_grid_desc
,
const
void
CONSTANT
*
p_c_k_n_h
o_wo
_grid_desc
,
const
void
CONSTANT
*
p_c_blockid_to_k_n_h
o
_w
o
_block_cluster_adaptor
)
const
void
CONSTANT
*
p_b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc
,
const
void
CONSTANT
*
p_c_k_n_h
0_h1_h2_w0_w1_w2
_grid_desc
,
const
void
CONSTANT
*
p_c_blockid_to_k_n_h_w_block_cluster_adaptor
)
{
// first cast const void CONSTANT* to void*
// second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4)
const
auto
a_e0_e1_k0_k1_e2_grid_desc
=
*
reinterpret_cast
<
const
AGridDesc_E0_E1_K0_K1_E2
*>
(
cast_pointer_to_generic_address_space
(
p_a_e0_e1_k0_k1_e2_grid_desc
));
const
auto
b_e0_e1_n_ho_wo_e2_grid_desc
=
*
reinterpret_cast
<
const
BGridDesc_E0_E1_N_Ho_Wo_E2
*>
(
cast_pointer_to_generic_address_space
(
p_b_e0_e1_n_ho_wo_e2_grid_desc
));
const
auto
c_k_n_ho_wo_grid_desc
=
*
reinterpret_cast
<
const
CGridDesc_K_N_Ho_Wo
*>
(
cast_pointer_to_generic_address_space
(
p_c_k_n_ho_wo_grid_desc
));
const
auto
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
=
*
reinterpret_cast
<
const
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
*>
(
cast_pointer_to_generic_address_space
(
p_c_blockid_to_k_n_ho_wo_block_cluster_adaptor
));
const
auto
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
=
*
reinterpret_cast
<
const
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
*>
(
cast_pointer_to_generic_address_space
(
p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
));
const
auto
c_k_n_h0_h1_h2_w0_w1_w2_grid_desc
=
*
reinterpret_cast
<
const
CGridDesc_K_N_H0_H1_H2_W0_W1_W2
*>
(
cast_pointer_to_generic_address_space
(
p_c_k_n_h0_h1_h2_w0_w1_w2_grid_desc
));
const
auto
c_blockid_to_k_n_h_w_block_cluster_adaptor
=
*
reinterpret_cast
<
const
CBlockIdToBlockClusterAdaptor_K_N_H_W
*>
(
cast_pointer_to_generic_address_space
(
p_c_blockid_to_k_n_h_w_block_cluster_adaptor
));
constexpr
index_t
shared_block_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
...
...
@@ -95,9 +97,9 @@ __global__ void
p_c_grid
,
p_shared_block
,
a_e0_e1_k0_k1_e2_grid_desc
,
b_e0_e1_n_h
o_wo
_e2_grid_desc
,
c_k_n_h
o_wo
_grid_desc
,
c_blockid_to_k_n_h
o
_w
o
_block_cluster_adaptor
,
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc
,
c_k_n_h
0_h1_h2_w0_w1_w2
_grid_desc
,
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
integral_constant
<
bool
,
HasMainE0BlockLoop
>
{});
}
#endif
...
...
@@ -232,7 +234,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
}
__host__
__device__
static
constexpr
auto
MakeCK0K1NH
oWo
GridDescriptor
(
const
CGridDesc_K_N_Ho_Wo
&
c_k_n_ho_wo_grid_desc
)
MakeCK0K1NH
0H1H2W0W1W2
GridDescriptor
(
const
CGridDesc_K_N_Ho_Wo
&
c_k_n_ho_wo_grid_desc
)
{
const
auto
K
=
c_k_n_ho_wo_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_k_n_ho_wo_grid_desc
.
GetLength
(
I1
);
...
...
@@ -242,19 +244,27 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const
auto
K1
=
Number
<
KPerBlock
>
{};
const
auto
K0
=
K
/
K1
;
const
auto
c_k0_k1_n_ho_wo_grid_desc
=
transform_tensor_descriptor
(
const
auto
H2
=
Number
<
HoPerThread
>
{};
const
auto
H1
=
Number
<
HoPerBlock
/
HoPerThread
>
{};
const
auto
H0
=
Ho
/
(
H1
*
H2
);
const
auto
W2
=
Number
<
WoPerThread
>
{};
const
auto
W1
=
Number
<
WoPerBlock
/
WoPerThread
>
{};
const
auto
W0
=
Wo
/
(
W1
*
W2
);
const
auto
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
=
transform_tensor_descriptor
(
c_k_n_ho_wo_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
N
),
make_
pass_through_transform
(
Ho
),
make_
pass_through_transform
(
Wo
)),
make_
unmerge_transform
(
make_tuple
(
H0
,
H1
,
H2
)
),
make_
unmerge_transform
(
make_tuple
(
W0
,
W1
,
W2
)
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
,
4
,
5
>
{},
Sequence
<
6
,
7
,
8
>
{}));
return
c_k0_k1_n_h
o_wo
_grid_desc
;
return
c_k0_k1_n_h
0_h1_h2_w0_w1_w2
_grid_desc
;
}
__host__
__device__
static
constexpr
auto
MakeBE0E1NH0H1W0W1E2GridDescriptor
(
__host__
__device__
static
constexpr
auto
MakeBE0E1NH0H1
H2
W0W1
W2
E2GridDescriptor
(
const
BGridDesc_E0_E1_N_Ho_Wo_E2
&
b_e0_e1_n_ho_wo_e2_grid_desc
)
{
const
auto
E0
=
b_e0_e1_n_ho_wo_e2_grid_desc
.
GetLength
(
I0
);
...
...
@@ -264,19 +274,21 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const
auto
Wo
=
b_e0_e1_n_ho_wo_e2_grid_desc
.
GetLength
(
I4
);
// const auto E2 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I5);
const
auto
H1
=
Number
<
HoPerBlock
>
{};
const
auto
H0
=
Ho
/
H1
;
const
auto
H2
=
Number
<
HoPerThread
>
{};
const
auto
H1
=
Number
<
HoPerBlock
/
HoPerThread
>
{};
const
auto
H0
=
Ho
/
(
H1
*
H2
);
const
auto
W1
=
Number
<
WoPerBlock
>
{};
const
auto
W0
=
Wo
/
W1
;
const
auto
W2
=
Number
<
WoPerThread
>
{};
const
auto
W1
=
Number
<
WoPerBlock
/
WoPerThread
>
{};
const
auto
W0
=
Wo
/
(
W1
*
W2
);
const
auto
b_e0_e1_n_h0_h1_w0_w1_e2_grid_desc
=
const
auto
b_e0_e1_n_h0_h1_
h2_
w0_w1_
w2_
e2_grid_desc
=
transform_tensor_descriptor
(
b_e0_e1_n_ho_wo_e2_grid_desc
,
make_tuple
(
make_pass_through_transform
(
E0
),
make_pass_through_transform
(
E1
),
make_pass_through_transform
(
N
),
make_unmerge_transform
(
make_tuple
(
H0
,
H1
)),
make_unmerge_transform
(
make_tuple
(
W0
,
W1
)),
make_unmerge_transform
(
make_tuple
(
H0
,
H1
,
H2
)),
make_unmerge_transform
(
make_tuple
(
W0
,
W1
,
W2
)),
make_pass_through_transform
(
E2
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
...
...
@@ -287,11 +299,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{}));
Sequence
<
3
,
4
,
5
>
{},
Sequence
<
6
,
7
,
8
>
{},
Sequence
<
9
>
{}));
return
b_e0_e1_n_h0_h1_w0_w1_e2_grid_desc
;
return
b_e0_e1_n_h0_h1_
h2_
w0_w1_
w2_
e2_grid_desc
;
}
__host__
__device__
static
constexpr
auto
...
...
@@ -317,8 +329,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
using
AGridDesc_E0_E1_K0_K1_E2
=
decltype
(
MakeAE0E1K0K1E2GridDescriptor
(
AGridDesc_E0_E1_K_E2
{}));
using
CGridDesc_K0_K1_N_Ho_Wo
=
decltype
(
MakeCK0K1NHoWoGridDescriptor
(
CGridDesc_K_N_Ho_Wo
{}));
using
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
=
using
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
=
decltype
(
MakeCK0K1NH0H1H2W0W1W2GridDescriptor
(
CGridDesc_K_N_Ho_Wo
{}));
using
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
=
decltype
(
MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor
(
BGridDesc_E0_E1_N_Ho_Wo_E2
{}));
using
CBlockIdToBlockClusterAdaptor_K_N_H_W
=
decltype
(
MakeCBlockIdToKNHoWoBlockClusterAdaptor
(
CGridDesc_K_N_Ho_Wo
{}));
template
<
bool
HasMainE0BlockLoop
>
...
...
@@ -328,84 +343,70 @@ struct GridwiseGemmDlops_km_kn_mn_v3
FloatC
*
__restrict__
p_c_global
,
FloatAB
*
__restrict__
p_shared_block
,
const
AGridDesc_E0_E1_K0_K1_E2
&
a_e0_e1_k0_k1_e2_grid_desc
,
const
BGridDesc_E0_E1_N_H
o_Wo_E2
&
b_e0_e1_n_ho_wo
_e2_grid_desc
,
const
CGridDesc_K0_K1_N_H
o_Wo
&
c_k0_k1_n_ho_wo
_grid_desc
,
const
CBlockIdToBlockClusterAdaptor_K_N_H
o
_W
o
&
c_blockid_to_k_n_h
o
_w
o
_block_cluster_adaptor
,
const
BGridDesc_E0_E1_N_H
0_H1_H2_W0_W1_W2_E2
&
b_e0_e1_n_h0_h1_h2_w0_w1_w2
_e2_grid_desc
,
const
CGridDesc_K0_K1_N_H
0_H1_H2_W0_W1_W2
&
c_k0_k1_n_h0_h1_h2_w0_w1_w2
_grid_desc
,
const
CBlockIdToBlockClusterAdaptor_K_N_H_W
&
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
integral_constant
<
bool
,
HasMainE0BlockLoop
>
)
{
const
auto
a_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_a_global
,
a_e0_e1_k0_k1_e2_grid_desc
.
GetElementSpaceSize
());
const
auto
b_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_b_global
,
b_e0_e1_n_h
o_wo
_e2_grid_desc
.
GetElementSpaceSize
());
p_b_global
,
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc
.
GetElementSpaceSize
());
auto
c_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c_global
,
c_k0_k1_n_h
o_wo
_grid_desc
.
GetElementSpaceSize
());
p_c_global
,
c_k0_k1_n_h
0_h1_h2_w0_w1_w2
_grid_desc
.
GetElementSpaceSize
());
constexpr
auto
HasMainE1BlockLoop
=
CalculateHasMainE1BlockLoop
();
constexpr
auto
HasDoubleTailE1BlockLoop
=
CalculateHasDoubleTailE1BlockLoop
();
// const auto Ho = b_e0_e1_n_h
o_wo
_e2_grid_desc.GetLength(I3);
// const auto Wo = b_e0_e1_n_h
o_wo
_e2_grid_desc.GetLength(I4);
// const auto Ho = b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc.GetLength(I3);
// const auto Wo = b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc.GetLength(I4);
const
auto
c_k_n_h
o
_w
o
_block_cluster_idx
=
c_blockid_to_k_n_h
o
_w
o
_block_cluster_adaptor
.
CalculateBottomIndex
(
const
auto
c_k_n_h_w_block_cluster_idx
=
c_blockid_to_k_n_h_w_block_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
const
index_t
k_block_work_id
=
__builtin_amdgcn_readfirstlane
(
c_k_n_h
o
_w
o
_block_cluster_idx
[
I0
]);
__builtin_amdgcn_readfirstlane
(
c_k_n_h_w_block_cluster_idx
[
I0
]);
const
index_t
n_block_work_id
=
__builtin_amdgcn_readfirstlane
(
c_k_n_h
o
_w
o
_block_cluster_idx
[
I1
]);
__builtin_amdgcn_readfirstlane
(
c_k_n_h_w_block_cluster_idx
[
I1
]);
const
index_t
ho_block_work_id
=
__builtin_amdgcn_readfirstlane
(
c_k_n_h
o
_w
o
_block_cluster_idx
[
I2
]);
__builtin_amdgcn_readfirstlane
(
c_k_n_h_w_block_cluster_idx
[
I2
]);
const
index_t
wo_block_work_id
=
__builtin_amdgcn_readfirstlane
(
c_k_n_h
o
_w
o
_block_cluster_idx
[
I3
]);
__builtin_amdgcn_readfirstlane
(
c_k_n_h_w_block_cluster_idx
[
I3
]);
constexpr
auto
max_lds_align
=
Number
<
ABlockTransferDstScalarPerVector_E2
>
{};
// B matrix in thread, dst of blockwise copy
constexpr
auto
b_e1_n_ho_wo_e2_block_desc
=
constexpr
auto
a_e1_k1_e2_block_gemm_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
E1PerBlock
>
{},
Number
<
KPerBlock
>
{},
Number
<
E2
>
{}),
max_lds_align
);
constexpr
auto
b_e1_n_h_w_e2_block_gemm_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
E1PerBlock
>
{},
Number
<
1
>
{}
,
I1
,
Number
<
HoPerBlock
>
{},
Number
<
WoPerBlock
>
{},
Number
<
E2
>
{}));
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr
auto
c_k1_n_ho_wo_thread_gemm_desc
=
make_naive_tensor_descriptor_packed
(
constexpr
auto
c_k1_n_h2_w2_thread_gemm_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThread
>
{},
I1
,
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{}));
constexpr
auto
a_e1_k1_e2_block_gemm_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
E1PerBlock
>
{},
Number
<
KPerBlock
>
{},
Number
<
E2
>
{}),
max_lds_align
);
auto
blockwise_gemm
=
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
<
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
decltype
(
a_e1_k1_e2_block_gemm_desc
),
decltype
(
b_e1_n_h
o
_w
o
_e2_block_desc
),
decltype
(
c_k1_n_h
o
_w
o
_thread_gemm_desc
),
decltype
(
b_e1_n_h_w_e2_block_
gemm_
desc
),
decltype
(
c_k1_n_h
2
_w
2
_thread_gemm_desc
),
EPerThread
,
K2
>
{};
auto
c_thread_mtx_index
=
blockwise_gemm
.
GetBeginOfCThreadDesc_K_N_Ho_Wo
(
get_thread_local_1d_id
());
//
const auto k_thread_id = c_thread_mtx_index[I0];
const
auto
k_thread_id
=
c_thread_mtx_index
[
I0
];
const
auto
ho_thread_id
=
c_thread_mtx_index
[
I2
];
const
auto
wo_thread_id
=
c_thread_mtx_index
[
I3
];
// const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
// const index_t n_block_data_on_global = n_block_work_id * HoPerBlock;
const
index_t
ho_block_data_on_global
=
ho_block_work_id
*
HoPerBlock
;
const
index_t
wo_block_data_on_global
=
wo_block_work_id
*
WoPerBlock
;
const
index_t
n_thread_data_on_global
=
0
;
const
index_t
ho_thread_data_on_global
=
ho_block_data_on_global
+
ho_thread_id
*
HoPerThread
;
const
index_t
wo_thread_data_on_global
=
wo_block_data_on_global
+
wo_thread_id
*
WoPerThread
;
constexpr
auto
a_e0_e1_k0_k1_e2_block_copy_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
I1
>
{},
Number
<
E1
>
{},
I1
,
Number
<
KPerBlock
>
{},
Number
<
E2
>
{}),
max_lds_align
);
...
...
@@ -438,30 +439,38 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
I1
,
0
,
0
,
0
,
0
);
constexpr
auto
b_e0_e1_n_h
o
_w
o
_e2_thread_desc
=
constexpr
auto
b_e0_e1_n_h
0_h1_h2_w0_w1
_w
2
_e2_thread_
copy_
desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
E1PerBlock
>
{},
Number
<
1
>
{},
I1
,
I1
,
I1
,
Number
<
HoPerThread
>
{},
I1
,
I1
,
Number
<
WoPerThread
>
{},
Number
<
E2
>
{}));
auto
b_threadwise_transfer
=
ThreadwiseTensorSliceTransfer_v2
<
FloatAB
,
FloatAB
,
decltype
(
b_e0_e1_n_h
o_wo
_e2_grid_desc
),
decltype
(
b_e0_e1_n_h
o
_w
o
_e2_thread_desc
),
Sequence
<
I1
,
E1PerBlock
,
NPerBlock
,
HoPerThread
,
WoPerThread
,
E2
>
,
decltype
(
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc
),
decltype
(
b_e0_e1_n_h
0_h1_h2_w0_w1
_w
2
_e2_thread_
copy_
desc
),
Sequence
<
I1
,
E1PerBlock
,
I1
,
I1
,
I1
,
HoPerThread
,
I1
,
I1
,
WoPerThread
,
E2
>
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
b_e0_e1_n_h
o_wo
_e2_grid_desc
,
true
>
(
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc
,
make_multi_index
(
0
,
0
,
n_thread_data_on_global
,
ho_thread_data_on_global
,
wo_thread_data_on_global
,
n_block_work_id
,
ho_block_work_id
,
ho_thread_id
,
0
,
wo_block_work_id
,
wo_thread_id
,
0
,
0
));
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
...
...
@@ -470,35 +479,36 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// register allocation for output
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatAcc
,
c_k1_n_h
o
_w
o
_thread_gemm_desc
.
GetElementSpaceSize
(),
c_k1_n_h
2
_w
2
_thread_gemm_desc
.
GetElementSpaceSize
(),
true
>
c_thread_buf
;
// initialize output thread tensor
ThreadwiseTensorSliceSet_v1
<
FloatAcc
,
decltype
(
c_k1_n_h
o
_w
o
_thread_gemm_desc
),
decltype
(
c_k1_n_h
2
_w
2
_thread_gemm_desc
),
Sequence
<
KPerThread
,
NPerBlock
,
HoPerThread
,
WoPerThread
>>
{}
.
Run
(
c_k1_n_h
o
_w
o
_thread_gemm_desc
,
.
Run
(
c_k1_n_h
2
_w
2
_thread_gemm_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
FloatAcc
{
0
});
constexpr
auto
b_thread_slice_copy_step
=
make_multi_index
(
0
,
E1PerBlock
,
0
,
0
,
0
,
0
);
constexpr
auto
b_thread_slice_copy_step
=
make_multi_index
(
0
,
E1PerBlock
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr
auto
a_e0_e1_k_e2_global_step_hacks
=
AGlobalStepHacks
{};
constexpr
auto
b_e0_e1_n_h
o_wo
_e2_global_step_hacks
=
BGlobalStepHacks
{};
constexpr
auto
a_e0_e1_k_e2_global_step_hacks
=
AGlobalStepHacks
{};
constexpr
auto
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_global_step_hacks
=
BGlobalStepHacks
{};
// double register buffer for b
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatAB
,
b_e0_e1_n_h
o
_w
o
_e2_thread_desc
.
GetElementSpaceSize
(),
b_e0_e1_n_h
0_h1_h2_w0_w1
_w
2
_e2_thread_
copy_
desc
.
GetElementSpaceSize
(),
true
>
b_thread_even_buf
,
b_thread_odd_buf
;
if
constexpr
(
HasMainE0BlockLoop
)
{
const
auto
E0
=
b_e0_e1_n_h
o_wo
_e2_grid_desc
.
GetLength
(
I0
);
const
auto
E0
=
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc
.
GetLength
(
I0
);
index_t
e0_block_data_begin
=
0
;
...
...
@@ -509,12 +519,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3
a_blockwise_copy
.
RunRead
(
a_e0_e1_k0_k1_e2_grid_desc
,
a_global_buf
,
a_e0_e1_k_e2_global_step_hacks
);
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h
o_wo
_e2_grid_desc
,
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc
,
b_global_buf
,
b_e0_e1_n_h
o
_w
o
_e2_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_e0_e1_n_h
0_h1_h2_w0_w1
_w
2
_e2_thread_
copy_
desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
,
b_e0_e1_n_h
o_wo
_e2_global_step_hacks
);
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_global_step_hacks
);
a_blockwise_copy
.
RunWrite
(
a_e0_e1_k0_k1_e2_block_copy_desc
,
a_block_buf
);
}
...
...
@@ -530,32 +540,36 @@ struct GridwiseGemmDlops_km_kn_mn_v3
do
{
// even iteration
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_ho_wo_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
b_threadwise_transfer
.
Run
(
b_e0_e1_n_ho_wo_e2_grid_desc
,
b_global_buf
,
b_e0_e1_n_ho_wo_e2_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_odd_buf
,
b_e0_e1_n_ho_wo_e2_global_step_hacks
);
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_global_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_odd_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks
);
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
(
make_tuple
(
E1PerBlock
,
0
,
0
));
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_ho_wo_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
b_threadwise_transfer
.
Run
(
b_e0_e1_n_ho_wo_e2_grid_desc
,
b_global_buf
,
b_e0_e1_n_ho_wo_e2_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
,
b_e0_e1_n_ho_wo_e2_global_step_hacks
);
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_global_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks
);
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_odd_buf
,
c_thread_buf
);
...
...
@@ -570,16 +584,17 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// LDS double buffer: tail
if
constexpr
(
HasDoubleTailE1BlockLoop
)
// if there are 2 iterations left
{
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_ho_wo_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h
o_wo
_e2_grid_desc
,
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc
,
b_global_buf
,
b_e0_e1_n_h
o
_w
o
_e2_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_e0_e1_n_h
0_h1_h2_w0_w1
_w
2
_e2_thread_
copy_
desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_odd_buf
,
b_e0_e1_n_h
o_wo
_e2_global_step_hacks
);
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_global_step_hacks
);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
...
...
@@ -601,7 +616,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
blockwise_gemm
.
MoveABlockSliceWindow
(
make_tuple
(
-
(
E1
-
E1PerBlock
),
0
,
0
));
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_h
o_wo
_e2_grid_desc
,
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
...
...
@@ -616,12 +631,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3
a_blockwise_copy
.
RunRead
(
a_e0_e1_k0_k1_e2_grid_desc
,
a_global_buf
,
a_e0_e1_k_e2_global_step_hacks
);
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h
o_wo
_e2_grid_desc
,
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc
,
b_global_buf
,
b_e0_e1_n_h
o
_w
o
_e2_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_e0_e1_n_h
0_h1_h2_w0_w1
_w
2
_e2_thread_
copy_
desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
,
b_e0_e1_n_h
o_wo
_e2_global_step_hacks
);
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_global_step_hacks
);
a_blockwise_copy
.
RunWrite
(
a_e0_e1_k0_k1_e2_block_copy_desc
,
a_block_buf
);
}
...
...
@@ -637,32 +652,34 @@ struct GridwiseGemmDlops_km_kn_mn_v3
do
{
// even iteration
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_ho_wo_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h
o_wo
_e2_grid_desc
,
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc
,
b_global_buf
,
b_e0_e1_n_h
o
_w
o
_e2_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_e0_e1_n_h
0_h1_h2_w0_w1
_w
2
_e2_thread_
copy_
desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_odd_buf
,
b_e0_e1_n_h
o_wo
_e2_global_step_hacks
);
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_global_step_hacks
);
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
(
make_tuple
(
E1PerBlock
,
0
,
0
));
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_ho_wo_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h
o_wo
_e2_grid_desc
,
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc
,
b_global_buf
,
b_e0_e1_n_h
o
_w
o
_e2_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_e0_e1_n_h
0_h1_h2_w0_w1
_w
2
_e2_thread_
copy_
desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
,
b_e0_e1_n_h
o_wo
_e2_global_step_hacks
);
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_global_step_hacks
);
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_odd_buf
,
c_thread_buf
);
...
...
@@ -677,16 +694,16 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// LDS double buffer: tail
if
constexpr
(
HasDoubleTailE1BlockLoop
)
// if there are 2 iterations left
{
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_h
o_wo
_e2_grid_desc
,
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h
o_wo
_e2_grid_desc
,
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc
,
b_global_buf
,
b_e0_e1_n_h
o
_w
o
_e2_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_e0_e1_n_h
0_h1_h2_w0_w1
_w
2
_e2_thread_
copy_
desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_odd_buf
,
b_e0_e1_n_h
o_wo
_e2_global_step_hacks
);
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_global_step_hacks
);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
...
...
@@ -705,7 +722,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// activ
{
static_for
<
0
,
c_k1_n_h
o
_w
o
_thread_gemm_desc
.
GetElementSpaceSize
(),
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
c_k1_n_h
2
_w
2
_thread_gemm_desc
.
GetElementSpaceSize
(),
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
activ_type
==
1
)
{
c_thread_buf
(
i
)
=
c_thread_buf
[
i
]
>=
0
?
c_thread_buf
[
i
]
:
0.0
;
...
...
@@ -726,36 +743,50 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// output: register to global memory
{
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
constexpr
auto
c_k_n_ho_wo_global_tensor_step_hacks
=
CGlobalStepHacks
{};
constexpr
auto
c_k0_k1_n_ho_wo_thread_copy_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
KPerThread
>
{},
I1
,
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{}));
// hack to control index calculation when iterating over c_k_n_h0_h1_h2_w0_w1_w2_global
// tensor
constexpr
auto
c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks
=
CGlobalStepHacks
{};
constexpr
auto
c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
KPerThread
>
{},
I1
,
I1
,
I1
,
Number
<
HoPerThread
>
{},
I1
,
I1
,
Number
<
WoPerThread
>
{}));
const
index_t
k_thread_data_on_global
=
k_thread_id
*
KPerThread
;
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatC
,
decltype
(
c_k0_k1_n_h
o_wo
_thread_copy_desc
),
decltype
(
c_k0_k1_n_h
o_wo
_grid_desc
),
Sequence
<
I1
,
KPerThread
,
I1
,
HoPerThread
,
WoPerThread
>
,
decltype
(
c_k0_k1_n_h
0_h1_h2_w0_w1_w2
_thread_copy_desc
),
decltype
(
c_k0_k1_n_h
0_h1_h2_w0_w1_w2
_grid_desc
),
Sequence
<
I1
,
KPerThread
,
I1
,
I1
,
I1
,
HoPerThread
,
I1
,
I1
,
WoPerThread
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CGlobalMemoryDataOperation
,
1
,
true
>
(
c_k0_k1_n_h
o_wo
_grid_desc
,
true
>
(
c_k0_k1_n_h
0_h1_h2_w0_w1_w2
_grid_desc
,
make_multi_index
(
k_block_work_id
,
k_thread_data_on_global
,
n_block_work_id
,
ho_block_work_id
,
ho_thread_id
,
0
,
n_thread_data_on_global
,
h
o_thread_d
ata_on_global
,
wo_thread_data_on_global
))
.
Run
(
c_k0_k1_n_h
o_wo
_thread_copy_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
wo_block_work_id
,
w
o_thread_
i
d
,
0
))
.
Run
(
c_k0_k1_n_h
0_h1_h2_w0_w1_w2
_thread_copy_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
c_k0_k1_n_h
o_wo
_grid_desc
,
c_k0_k1_n_h
0_h1_h2_w0_w1_w2
_grid_desc
,
c_global_buf
,
c_k_n_h
o_wo
_global_tensor_step_hacks
);
c_k_n_h
0_h1_h2_w0_w1_w2
_global_tensor_step_hacks
);
}
}
};
...
...
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
View file @
17cd5c7d
...
...
@@ -119,14 +119,14 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
KPerBlock
=
16
;
constexpr
index_t
HoPerBlock
=
8
;
constexpr
index_t
HoPerBlock
=
4
;
constexpr
index_t
WoPerBlock
=
32
;
constexpr
index_t
E1
=
2
*
9
;
constexpr
index_t
E2
=
1
;
constexpr
index_t
E1PerBlock
=
2
;
constexpr
index_t
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
8
;
constexpr
index_t
HoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
EPerThread
=
1
;
...
...
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
View file @
17cd5c7d
...
...
@@ -197,35 +197,597 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
constexpr
auto
a_e0_e1_k_e2_global_move_slice_window_step_hack
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
constexpr
auto
b_e0_e1_n_ho_wo_e2_global_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
constexpr
auto
b_e0_e1_n_ho_wo_e2_global_move_slice_window_step_hack
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
constexpr
auto
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
constexpr
auto
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
constexpr
auto
c_k
_n_ho_wo
_global_tensor_step_hacks
=
constexpr
auto
c_k
0_k1_n_h0_h1_h2_w0_w1_w2
_global_tensor_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
static_assert
(
a_e0_e1_k_e2_grid_desc
.
IsKnownAtCompileTime
(),
""
);
...
...
@@ -260,29 +822,32 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
ABlockTransferSrcScalarPerVector_E2
,
ABlockTransferDstScalarPerVector_E2
,
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
// E0, E1, N, H0, H1, H2, W0, W1, W2, E2
9
,
BThreadTransferSrcScalarPerVector_E2
,
false
,
// don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
>
,
// K0, K1, N, H0, H1, H2, W0, W1, W2
1
,
CThreadTransferDstScalarPerVector_K
,
decltype
(
a_e0_e1_k_e2_global_step_hacks
),
decltype
(
b_e0_e1_n_h
o_wo
_e2_global_step_hacks
),
decltype
(
c_k
_n_ho_wo
_global_tensor_step_hacks
),
decltype
(
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_global_step_hacks
),
decltype
(
c_k
0_k1_n_h0_h1_h2_w0_w1_w2
_global_tensor_step_hacks
),
decltype
(
a_e0_e1_k_e2_global_move_slice_window_step_hack
),
decltype
(
b_e0_e1_n_h
o_wo
_e2_global_move_slice_window_step_hack
),
decltype
(
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_global_move_slice_window_step_hack
),
activ_type
>
;
const
auto
a_e0_e1_k0_k1_e2_grid_desc
=
GridwiseGemm
::
MakeAE0E1K0K1E2GridDescriptor
(
a_e0_e1_k_e2_grid_desc
);
const
auto
c_k0_k1_n_hop_wop_grid_desc
=
GridwiseGemm
::
MakeCK0K1NHoWoGridDescriptor
(
c_k_n_hop_wop_grid_desc
);
const
auto
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
=
GridwiseGemm
::
MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor
(
b_e0_e1_n_ho_wo_e2_grid_desc
);
const
auto
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
=
GridwiseGemm
::
MakeCK0K1NH0H1H2W0W1W2GridDescriptor
(
c_k_n_hop_wop_grid_desc
);
using
AGridDesc_E0_E1_K0_K1_E2
=
decltype
(
a_e0_e1_k0_k1_e2_grid_desc
);
using
BGridDesc_E0_E1_N_Ho_Wo_E2
=
decltype
(
b_e0_e1_n_ho_wo_e2_grid_desc
);
using
CGridDesc_K0_K1_N_Ho_Wo
=
decltype
(
c_k0_k1_n_hop_wop_grid_desc
);
using
AGridDesc_E0_E1_K0_K1_E2
=
decltype
(
a_e0_e1_k0_k1_e2_grid_desc
);
using
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
=
decltype
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
);
using
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
=
decltype
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
);
const
auto
grid_size
=
(
K
/
KPerBlock
)
*
(
Hop
/
HoPerBlock
)
*
(
Wop
/
WoPerBlock
)
*
N
;
...
...
@@ -290,11 +855,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
std
::
cerr
<<
"has_main_e0_block_loop = "
<<
has_main_e0_block_loop
<<
std
::
endl
;
const
auto
c_blockid_to_k_n_h
o
_w
o
_block_cluster_adaptor
=
const
auto
c_blockid_to_k_n_h_w_block_cluster_adaptor
=
GridwiseGemm
::
MakeCBlockIdToKNHoWoBlockClusterAdaptor
(
c_k_n_hop_wop_grid_desc
);
using
CBlockIdToBlockClusterAdaptor_K_N_H
o
_W
o
=
decltype
(
c_blockid_to_k_n_h
o
_w
o
_block_cluster_adaptor
);
using
CBlockIdToBlockClusterAdaptor_K_N_H_W
=
decltype
(
c_blockid_to_k_n_h_w_block_cluster_adaptor
);
float
ave_time
=
0
;
...
...
@@ -304,9 +869,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
FloatAB
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K0_K1_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_H
o_Wo
_E2
>
,
remove_reference_t
<
CGridDesc_K0_K1_N_H
o_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_H
o
_W
o
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_H
0_H1_H2_W0_W1_W2
_E2
>
,
remove_reference_t
<
CGridDesc_K0_K1_N_H
0_H1_H2_W0_W1_W2
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_H_W
>
,
has_main_e0_block_loop
,
has_main_e1_block_loop
,
has_double_tail_e1_block_loop
>
;
...
...
@@ -320,22 +885,26 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_b_grid
,
p_c_grid
,
a_e0_e1_k0_k1_e2_grid_desc
,
b_e0_e1_n_h
o_wo
_e2_grid_desc
,
c_k0_k1_n_h
op_wop
_grid_desc
,
c_blockid_to_k_n_h
o
_w
o
_block_cluster_adaptor
);
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc
,
c_k0_k1_n_h
0_h1_h2_w0_w1_w2
_grid_desc
,
c_blockid_to_k_n_h_w_block_cluster_adaptor
);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem
a_e0_e1_k0_k1_e2_grid_desc_dev_buf
(
sizeof
(
AGridDesc_E0_E1_K0_K1_E2
));
DeviceMem
b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf
(
sizeof
(
BGridDesc_E0_E1_N_Ho_Wo_E2
));
DeviceMem
c_k0_k1_n_hop_wop_grid_desc_dev_buf
(
sizeof
(
CGridDesc_K0_K1_N_Ho_Wo
));
DeviceMem
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf
(
sizeof
(
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
));
DeviceMem
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf
(
sizeof
(
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
));
DeviceMem
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf
(
sizeof
(
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
));
DeviceMem
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf
(
sizeof
(
CBlockIdToBlockClusterAdaptor_K_N_H_W
));
a_e0_e1_k0_k1_e2_grid_desc_dev_buf
.
ToDevice
(
&
a_e0_e1_k0_k1_e2_grid_desc
);
b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf
.
ToDevice
(
&
b_e0_e1_n_ho_wo_e2_grid_desc
);
c_k0_k1_n_hop_wop_grid_desc_dev_buf
.
ToDevice
(
&
c_k0_k1_n_hop_wop_grid_desc
);
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf
.
ToDevice
(
&
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
);
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf
.
ToDevice
(
&
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
);
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf
.
ToDevice
(
&
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
);
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf
.
ToDevice
(
&
c_blockid_to_k_n_h_w_block_cluster_adaptor
);
if
(
has_main_e0_block_loop
)
{
...
...
@@ -345,9 +914,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
FloatAB
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K0_K1_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_H
o_Wo
_E2
>
,
remove_reference_t
<
CGridDesc_K0_K1_N_H
o_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_H
o
_W
o
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_H
0_H1_H2_W0_W1_W2
_E2
>
,
remove_reference_t
<
CGridDesc_K0_K1_N_H
0_H1_H2_W0_W1_W2
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_H_W
>
,
true
>
;
ave_time
=
launch_and_time_kernel
(
...
...
@@ -362,11 +931,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
cast_pointer_to_constant_address_space
(
a_e0_e1_k0_k1_e2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_e0_e1_n_h
o_wo
_e2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_k0_k1_n_h
op_wop
_grid_desc_dev_buf
.
GetDeviceBuffer
()),
c_k0_k1_n_h
0_h1_h2_w0_w1_w2
_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_blockid_to_k_n_h
o
_w
o
_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
else
{
...
...
@@ -376,9 +945,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
FloatAB
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K0_K1_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_H
o_Wo
_E2
>
,
remove_reference_t
<
CGridDesc_K0_K1_N_H
o_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_H
o
_W
o
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_H
0_H1_H2_W0_W1_W2
_E2
>
,
remove_reference_t
<
CGridDesc_K0_K1_N_H
0_H1_H2_W0_W1_W2
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_H_W
>
,
false
>
;
ave_time
=
launch_and_time_kernel
(
...
...
@@ -393,11 +962,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
cast_pointer_to_constant_address_space
(
a_e0_e1_k0_k1_e2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_e0_e1_n_h
o_wo
_e2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
b_e0_e1_n_h
0_h1_h2_w0_w1_w2
_e2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_k0_k1_n_h
op_wop
_grid_desc_dev_buf
.
GetDeviceBuffer
()),
c_k0_k1_n_h
0_h1_h2_w0_w1_w2
_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_blockid_to_k_n_h
o
_w
o
_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
#endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment