Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
35a57947
Commit
35a57947
authored
Oct 14, 2021
by
Jing Zhang
Browse files
add conv_out
parent
3e298e42
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
157 additions
and
6 deletions
+157
-6
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2_add.hpp
...l/include/tensor_operation/gridwise_gemm_dlops_v2_add.hpp
+103
-2
host/driver_offline/include/device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
...ward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
+8
-2
host/driver_offline/include/driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
...ward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
+45
-2
host/driver_offline/src/conv_add_fwd_driver_offline_nchwc.cpp
.../driver_offline/src/conv_add_fwd_driver_offline_nchwc.cpp
+1
-0
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2_add.hpp
View file @
35a57947
...
@@ -17,6 +17,7 @@ template <typename GridwiseGemm,
...
@@ -17,6 +17,7 @@ template <typename GridwiseGemm,
typename
FloatC
,
typename
FloatC
,
typename
AGridDesc_E0_E1_K0_K1_E2
,
typename
AGridDesc_E0_E1_K0_K1_E2
,
typename
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
,
typename
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
,
typename
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
,
typename
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
,
typename
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
,
typename
CBlockIdToBlockClusterAdaptor_K_N_H_W
,
typename
CBlockIdToBlockClusterAdaptor_K_N_H_W
,
bool
HasMainE0BlockLoop
>
bool
HasMainE0BlockLoop
>
...
@@ -28,9 +29,11 @@ __global__ void
...
@@ -28,9 +29,11 @@ __global__ void
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatC
*
__restrict__
p_bias_grid
,
const
FloatC
*
__restrict__
p_bias_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_d_grid
,
FloatC
*
__restrict__
p_d_grid
,
const
AGridDesc_E0_E1_K0_K1_E2
a_e0_e1_k0_k1_e2_grid_desc
,
const
AGridDesc_E0_E1_K0_K1_E2
a_e0_e1_k0_k1_e2_grid_desc
,
const
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
const
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
const
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
const
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
const
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
const
CBlockIdToBlockClusterAdaptor_K_N_H_W
c_blockid_to_k_n_h_w_block_cluster_adaptor
)
const
CBlockIdToBlockClusterAdaptor_K_N_H_W
c_blockid_to_k_n_h_w_block_cluster_adaptor
)
{
{
...
@@ -42,10 +45,12 @@ __global__ void
...
@@ -42,10 +45,12 @@ __global__ void
GridwiseGemm
::
Run
(
p_a_grid
,
GridwiseGemm
::
Run
(
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_bias_grid
,
p_bias_grid
,
p_c_grid
,
p_d_grid
,
p_d_grid
,
p_shared_block
,
p_shared_block
,
a_e0_e1_k0_k1_e2_grid_desc
,
a_e0_e1_k0_k1_e2_grid_desc
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
integral_constant
<
bool
,
HasMainE0BlockLoop
>
{});
integral_constant
<
bool
,
HasMainE0BlockLoop
>
{});
...
@@ -59,6 +64,7 @@ template <typename GridwiseGemm,
...
@@ -59,6 +64,7 @@ template <typename GridwiseGemm,
typename
FloatC
,
typename
FloatC
,
typename
AGridDesc_E0_E1_K0_K1_E2
,
typename
AGridDesc_E0_E1_K0_K1_E2
,
typename
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
,
typename
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
,
typename
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
,
typename
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
,
typename
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
,
typename
CBlockIdToBlockClusterAdaptor_K_N_H_W
,
typename
CBlockIdToBlockClusterAdaptor_K_N_H_W
,
bool
HasMainE0BlockLoop
>
bool
HasMainE0BlockLoop
>
...
@@ -69,11 +75,12 @@ __global__ void
...
@@ -69,11 +75,12 @@ __global__ void
kernel_gemm_dlops_v2_add
(
const
FloatAB
*
__restrict__
p_a_grid
,
kernel_gemm_dlops_v2_add
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatC
*
__restrict__
p_bias_grid
,
const
FloatC
*
__restrict__
p_bias_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_d_grid
,
FloatC
*
__restrict__
p_d_grid
,
const
void
CONSTANT
*
p_a_e0_e1_k0_k1_e2_grid_desc
,
const
void
CONSTANT
*
p_a_e0_e1_k0_k1_e2_grid_desc
,
const
void
CONSTANT
*
p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
const
void
CONSTANT
*
p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
const
void
CONSTANT
*
p_d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
const
void
CONSTANT
*
p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
const
void
CONSTANT
*
p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
const
void
CONSTANT
*
p_d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
const
void
CONSTANT
*
p_c_blockid_to_k_n_h_w_block_cluster_adaptor
)
const
void
CONSTANT
*
p_c_blockid_to_k_n_h_w_block_cluster_adaptor
)
{
{
// first cast void CONSTANT void* to void*
// first cast void CONSTANT void* to void*
...
@@ -84,6 +91,9 @@ __global__ void
...
@@ -84,6 +91,9 @@ __global__ void
const
auto
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
=
const
auto
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
=
*
reinterpret_cast
<
const
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
*>
(
*
reinterpret_cast
<
const
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
*>
(
cast_pointer_to_generic_address_space
(
p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
));
cast_pointer_to_generic_address_space
(
p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
));
const
auto
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
=
*
reinterpret_cast
<
const
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
*>
(
cast_pointer_to_generic_address_space
(
p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
));
const
auto
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
=
const
auto
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
=
*
reinterpret_cast
<
const
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
*>
(
*
reinterpret_cast
<
const
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
*>
(
cast_pointer_to_generic_address_space
(
p_d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
));
cast_pointer_to_generic_address_space
(
p_d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
));
...
@@ -99,10 +109,12 @@ __global__ void
...
@@ -99,10 +109,12 @@ __global__ void
GridwiseGemm
::
Run
(
p_a_grid
,
GridwiseGemm
::
Run
(
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_bias_grid
,
p_bias_grid
,
p_c_grid
,
p_d_grid
,
p_d_grid
,
p_shared_block
,
p_shared_block
,
a_e0_e1_k0_k1_e2_grid_desc
,
a_e0_e1_k0_k1_e2_grid_desc
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
integral_constant
<
bool
,
HasMainE0BlockLoop
>
{});
integral_constant
<
bool
,
HasMainE0BlockLoop
>
{});
...
@@ -116,8 +128,8 @@ template <index_t BlockSize,
...
@@ -116,8 +128,8 @@ template <index_t BlockSize,
InMemoryDataOperationEnum_t
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum_t
CGlobalMemoryDataOperation
,
typename
AGridDesc_E0_E1_K_E2
,
typename
AGridDesc_E0_E1_K_E2
,
typename
BGridDesc_E0_E1_N_Ho_Wo_E2
,
typename
BGridDesc_E0_E1_N_Ho_Wo_E2
,
typename
DGridDesc_K_N_Hox2_Wox2
,
typename
CGridDesc_K_N_Ho_Wo
,
typename
CGridDesc_K_N_Ho_Wo
,
typename
DGridDesc_K_N_Hox2_Wox2
,
index_t
E1_
,
index_t
E1_
,
index_t
E2_
,
index_t
E2_
,
index_t
K2_
,
index_t
K2_
,
...
@@ -146,6 +158,7 @@ template <index_t BlockSize,
...
@@ -146,6 +158,7 @@ template <index_t BlockSize,
index_t
CThreadTransferDstScalarPerVector
,
index_t
CThreadTransferDstScalarPerVector
,
typename
AGlobalStepHacks
,
typename
AGlobalStepHacks
,
typename
BGlobalStepHacks
,
typename
BGlobalStepHacks
,
typename
CGlobalStepHacks
,
typename
DGlobalStepHacks
,
typename
DGlobalStepHacks
,
typename
AGlobalMoveSliceWindowStepHacks
,
typename
AGlobalMoveSliceWindowStepHacks
,
typename
BGlobalMoveSliceWindowStepHacks
,
typename
BGlobalMoveSliceWindowStepHacks
,
...
@@ -283,6 +296,37 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
...
@@ -283,6 +296,37 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
return
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
;
return
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
;
}
}
__host__
__device__
static
constexpr
auto
MakeCK0K1NH0H1H2W0W1W2GridDescriptor
(
const
CGridDesc_K_N_Ho_Wo
&
c_k_n_ho_wo_grid_desc
)
{
const
auto
K
=
c_k_n_ho_wo_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_k_n_ho_wo_grid_desc
.
GetLength
(
I1
);
const
auto
Ho
=
c_k_n_ho_wo_grid_desc
.
GetLength
(
I2
);
const
auto
Wo
=
c_k_n_ho_wo_grid_desc
.
GetLength
(
I3
);
const
auto
K1
=
Number
<
KPerBlock
>
{};
const
auto
K0
=
K
/
K1
;
const
auto
H2
=
Number
<
HoPerThread
>
{};
const
auto
H1
=
Number
<
HoPerBlock
/
HoPerThread
>
{};
const
auto
H0
=
Ho
/
(
H1
*
H2
);
const
auto
W2
=
Number
<
WoPerThread
>
{};
const
auto
W1
=
Number
<
WoPerBlock
/
WoPerThread
>
{};
const
auto
W0
=
Wo
/
(
W1
*
W2
);
const
auto
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
=
transform_tensor_descriptor
(
c_k_n_ho_wo_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
N
),
make_unmerge_transform
(
make_tuple
(
H0
,
H1
,
H2
)),
make_unmerge_transform
(
make_tuple
(
W0
,
W1
,
W2
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
,
4
,
5
>
{},
Sequence
<
6
,
7
,
8
>
{}));
return
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
;
}
__host__
__device__
static
constexpr
auto
MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor
(
__host__
__device__
static
constexpr
auto
MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor
(
const
DGridDesc_K_N_Hox2_Wox2
&
d_k_n_hox2_wox2_grid_desc
)
const
DGridDesc_K_N_Hox2_Wox2
&
d_k_n_hox2_wox2_grid_desc
)
{
{
...
@@ -339,8 +383,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
...
@@ -339,8 +383,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
decltype
(
MakeAE0E1K0K1E2GridDescriptor
(
AGridDesc_E0_E1_K_E2
{}));
decltype
(
MakeAE0E1K0K1E2GridDescriptor
(
AGridDesc_E0_E1_K_E2
{}));
using
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
=
using
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
=
decltype
(
MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor
(
BGridDesc_E0_E1_N_Ho_Wo_E2
{}));
decltype
(
MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor
(
BGridDesc_E0_E1_N_Ho_Wo_E2
{}));
using
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
=
decltype
(
MakeCK0K1NH0H1H2W0W1W2GridDescriptor
(
CGridDesc_K_N_Ho_Wo
{}));
using
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
=
using
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
=
decltype
(
MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor
(
DGridDesc_K_N_Hox2_Wox2
{}));
decltype
(
MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor
(
DGridDesc_K_N_Hox2_Wox2
{}));
using
CBlockIdToBlockClusterAdaptor_K_N_H_W
=
using
CBlockIdToBlockClusterAdaptor_K_N_H_W
=
decltype
(
MakeCBlockIdToKNHoWoBlockClusterAdaptor
(
CGridDesc_K_N_Ho_Wo
{}));
decltype
(
MakeCBlockIdToKNHoWoBlockClusterAdaptor
(
CGridDesc_K_N_Ho_Wo
{}));
...
@@ -358,10 +405,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
...
@@ -358,10 +405,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
Run
(
const
FloatAB
*
__restrict__
p_a_global
,
Run
(
const
FloatAB
*
__restrict__
p_a_global
,
const
FloatAB
*
__restrict__
p_b_global
,
const
FloatAB
*
__restrict__
p_b_global
,
const
FloatC
*
__restrict__
p_bias_global
,
const
FloatC
*
__restrict__
p_bias_global
,
FloatC
*
__restrict__
p_c_global
,
FloatC
*
__restrict__
p_d_global
,
FloatC
*
__restrict__
p_d_global
,
FloatAB
*
__restrict__
p_shared_block
,
FloatAB
*
__restrict__
p_shared_block
,
const
AGridDesc_E0_E1_K0_K1_E2
&
a_e0_e1_k0_k1_e2_grid_desc
,
const
AGridDesc_E0_E1_K0_K1_E2
&
a_e0_e1_k0_k1_e2_grid_desc
,
const
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
&
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
const
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
&
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
const
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
&
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
const
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
&
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
const
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
&
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
const
CBlockIdToBlockClusterAdaptor_K_N_H_W
&
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
const
CBlockIdToBlockClusterAdaptor_K_N_H_W
&
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
integral_constant
<
bool
,
HasMainE0BlockLoop
>
)
integral_constant
<
bool
,
HasMainE0BlockLoop
>
)
...
@@ -382,6 +431,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
...
@@ -382,6 +431,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
p_a_global
,
a_e0_e1_k0_k1_e2_grid_desc
.
GetElementSpaceSize
());
p_a_global
,
a_e0_e1_k0_k1_e2_grid_desc
.
GetElementSpaceSize
());
const
auto
b_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
const
auto
b_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_b_global
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
.
GetElementSpaceSize
());
p_b_global
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
.
GetElementSpaceSize
());
auto
c_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c_global
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
.
GetElementSpaceSize
());
auto
d_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
auto
d_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_d_global
,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
.
GetElementSpaceSize
());
p_d_global
,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
.
GetElementSpaceSize
());
auto
bias_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
auto
bias_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
...
@@ -826,6 +877,56 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
...
@@ -826,6 +877,56 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
#endif
#endif
}
}
#if 1
// Output
{
// hack to control index calculation when iterating over c_k_n_h0_h1_h2_w0_w1_w2_global
// tensor
constexpr
auto
c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks
=
CGlobalStepHacks
{};
constexpr
auto
c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
KPerThread
>
{},
I1
,
I1
,
I1
,
Number
<
HoPerThread
>
{},
I1
,
I1
,
Number
<
WoPerThread
>
{}));
const
index_t
k_thread_data_on_global
=
k_thread_id
*
KPerThread
;
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatC
,
decltype
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc
),
decltype
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
),
Sequence
<
I1
,
KPerThread
,
I1
,
I1
,
I1
,
HoPerThread
,
I1
,
I1
,
WoPerThread
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CGlobalMemoryDataOperation
,
1
,
true
>
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
make_multi_index
(
k_block_work_id
,
k_thread_data_on_global
,
n_block_work_id
,
ho_block_work_id
,
ho_thread_id
,
0
,
wo_block_work_id
,
wo_thread_id
,
0
))
.
Run
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
c_global_buf
,
c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks
);
}
#endif
// Resize_Add
// Resize_Add
{
{
constexpr
auto
HoPerThreadx2
=
HoPerThread
*
2
;
constexpr
auto
HoPerThreadx2
=
HoPerThread
*
2
;
...
...
host/driver_offline/include/device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
View file @
35a57947
...
@@ -27,6 +27,7 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
...
@@ -27,6 +27,7 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
const
Tensor
<
TInWei
>&
in_n_c0_hi_wi_c1
,
const
Tensor
<
TInWei
>&
in_n_c0_hi_wi_c1
,
const
Tensor
<
TInWei
>&
wei_k_c0_y_x_c1
,
const
Tensor
<
TInWei
>&
wei_k_c0_y_x_c1
,
const
Tensor
<
TOut
>&
bias_k0_k1
,
const
Tensor
<
TOut
>&
bias_k0_k1
,
Tensor
<
TOut
>&
out_n_k0_ho_wo_k1
,
const
Tensor
<
TOut
>&
add_n_k0_hox2_wox2_k1
,
const
Tensor
<
TOut
>&
add_n_k0_hox2_wox2_k1
,
Tensor
<
TOut
>&
add_n_k0_hox2_wox2_k1_out
,
Tensor
<
TOut
>&
add_n_k0_hox2_wox2_k1_out
,
ck
::
index_t
nrepeat
)
ck
::
index_t
nrepeat
)
...
@@ -63,6 +64,8 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
...
@@ -63,6 +64,8 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
in_n_c0_hi_wi_c1
.
mDesc
.
GetElementSpace
());
in_n_c0_hi_wi_c1
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_c0_y_x_c1_device_buf
(
sizeof
(
TInWei
)
*
wei_k_c0_y_x_c1
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_c0_y_x_c1_device_buf
(
sizeof
(
TInWei
)
*
wei_k_c0_y_x_c1
.
mDesc
.
GetElementSpace
());
DeviceMem
bias_k0_k1_device_buf
(
sizeof
(
TOut
)
*
bias_k0_k1
.
mDesc
.
GetElementSpace
());
DeviceMem
bias_k0_k1_device_buf
(
sizeof
(
TOut
)
*
bias_k0_k1
.
mDesc
.
GetElementSpace
());
DeviceMem
out_n_k0_ho_wo_k1_device_buf
(
sizeof
(
TOut
)
*
out_n_k0_ho_wo_k1
.
mDesc
.
GetElementSpace
());
DeviceMem
add_n_k0_hox2_wox2_k1_device_buf
(
sizeof
(
TOut
)
*
DeviceMem
add_n_k0_hox2_wox2_k1_device_buf
(
sizeof
(
TOut
)
*
add_n_k0_hox2_wox2_k1
.
mDesc
.
GetElementSpace
());
add_n_k0_hox2_wox2_k1
.
mDesc
.
GetElementSpace
());
...
@@ -177,8 +180,8 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
...
@@ -177,8 +180,8 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
const
auto
ave_time
=
const
auto
ave_time
=
conv_driver
.
Run
(
wei_k_c0_y_x_c1_desc
,
conv_driver
.
Run
(
wei_k_c0_y_x_c1_desc
,
in_n_c0_hi_wi_c1_desc
,
in_n_c0_hi_wi_c1_desc
,
add_n_k0_hox2_wox2_k1_desc
,
out_n_k0_ho_wo_k1_desc
,
out_n_k0_ho_wo_k1_desc
,
add_n_k0_hox2_wox2_k1_desc
,
conv_strides
,
conv_strides
,
conv_dilations
,
conv_dilations
,
in_left_pads
,
in_left_pads
,
...
@@ -188,6 +191,7 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
...
@@ -188,6 +191,7 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
in_n_c0_hi_wi_c1_device_buf
.
GetDeviceBuffer
()),
in_n_c0_hi_wi_c1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
bias_k0_k1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
bias_k0_k1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k0_ho_wo_k1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
add_n_k0_hox2_wox2_k1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
add_n_k0_hox2_wox2_k1_device_buf
.
GetDeviceBuffer
()),
nrepeat
);
nrepeat
);
...
@@ -204,8 +208,8 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
...
@@ -204,8 +208,8 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
conv_driver
.
Run
(
wei_k_c0_y_x_c1_desc
,
conv_driver
.
Run
(
wei_k_c0_y_x_c1_desc
,
in_n_c0_hi_wi_c1_desc
,
in_n_c0_hi_wi_c1_desc
,
add_n_k0_hox2_wox2_k1_desc
,
out_n_k0_ho_wo_k1_desc
,
out_n_k0_ho_wo_k1_desc
,
add_n_k0_hox2_wox2_k1_desc
,
conv_strides
,
conv_strides
,
conv_dilations
,
conv_dilations
,
in_left_pads
,
in_left_pads
,
...
@@ -215,8 +219,10 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
...
@@ -215,8 +219,10 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
in_n_c0_hi_wi_c1_device_buf
.
GetDeviceBuffer
()),
in_n_c0_hi_wi_c1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
bias_k0_k1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
bias_k0_k1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k0_ho_wo_k1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
add_n_k0_hox2_wox2_k1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
add_n_k0_hox2_wox2_k1_device_buf
.
GetDeviceBuffer
()),
0
);
0
);
out_n_k0_ho_wo_k1_device_buf
.
FromDevice
(
out_n_k0_ho_wo_k1
.
mData
.
data
());
add_n_k0_hox2_wox2_k1_device_buf
.
FromDevice
(
add_n_k0_hox2_wox2_k1_out
.
mData
.
data
());
add_n_k0_hox2_wox2_k1_device_buf
.
FromDevice
(
add_n_k0_hox2_wox2_k1_out
.
mData
.
data
());
}
}
host/driver_offline/include/driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
View file @
35a57947
...
@@ -40,8 +40,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
...
@@ -40,8 +40,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
typename
InRightPads
>
typename
InRightPads
>
__host__
float
Run
(
const
ck
::
TensorDescriptor
<
Wei
...
>&
wei_k_c0_y_x_c1_global_desc
,
__host__
float
Run
(
const
ck
::
TensorDescriptor
<
Wei
...
>&
wei_k_c0_y_x_c1_global_desc
,
const
ck
::
TensorDescriptor
<
In
...
>&
in_n_c0_hi_wi_c1_global_desc
,
const
ck
::
TensorDescriptor
<
In
...
>&
in_n_c0_hi_wi_c1_global_desc
,
const
ck
::
TensorDescriptor
<
Add
...
>&
add_n_k0_hox2_wox2_k1_global_desc
,
const
ck
::
TensorDescriptor
<
Out
...
>&
out_n_k0_ho_wo_k1_global_desc
,
const
ck
::
TensorDescriptor
<
Out
...
>&
out_n_k0_ho_wo_k1_global_desc
,
const
ck
::
TensorDescriptor
<
Add
...
>&
add_n_k0_hox2_wox2_k1_global_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
...
@@ -49,6 +49,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
...
@@ -49,6 +49,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatC
*
__restrict__
p_bias_grid
,
const
FloatC
*
__restrict__
p_bias_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_d_grid
,
FloatC
*
__restrict__
p_d_grid
,
const
int
nrepeat
)
const
const
int
nrepeat
)
const
{
{
...
@@ -247,6 +248,26 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
...
@@ -247,6 +248,26 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
constexpr
auto
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack
=
constexpr
auto
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
constexpr
auto
c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
constexpr
auto
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks
=
constexpr
auto
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
...
@@ -282,8 +303,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
...
@@ -282,8 +303,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
InMemoryDataOperationEnum_t
::
Add
,
InMemoryDataOperationEnum_t
::
Add
,
decltype
(
a_e0_e1_k_e2_grid_desc
),
decltype
(
a_e0_e1_k_e2_grid_desc
),
decltype
(
b_e0_e1_n_ho_wo_e2_grid_desc
),
decltype
(
b_e0_e1_n_ho_wo_e2_grid_desc
),
decltype
(
d_k_n_hopx2_wopx2_grid_desc
),
decltype
(
c_k_n_hop_wop_grid_desc
),
decltype
(
c_k_n_hop_wop_grid_desc
),
decltype
(
d_k_n_hopx2_wopx2_grid_desc
),
E1
,
E1
,
E2
,
E2
,
K2
,
K2
,
...
@@ -313,6 +334,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
...
@@ -313,6 +334,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
CThreadTransferDstScalarPerVector_K
,
CThreadTransferDstScalarPerVector_K
,
decltype
(
a_e0_e1_k_e2_global_step_hacks
),
decltype
(
a_e0_e1_k_e2_global_step_hacks
),
decltype
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks
),
decltype
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks
),
decltype
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks
),
decltype
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks
),
decltype
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks
),
decltype
(
a_e0_e1_k_e2_global_move_slice_window_step_hack
),
decltype
(
a_e0_e1_k_e2_global_move_slice_window_step_hack
),
decltype
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack
),
decltype
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack
),
...
@@ -322,12 +344,15 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
...
@@ -322,12 +344,15 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
GridwiseGemm
::
MakeAE0E1K0K1E2GridDescriptor
(
a_e0_e1_k_e2_grid_desc
);
GridwiseGemm
::
MakeAE0E1K0K1E2GridDescriptor
(
a_e0_e1_k_e2_grid_desc
);
const
auto
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
=
const
auto
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
=
GridwiseGemm
::
MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor
(
b_e0_e1_n_ho_wo_e2_grid_desc
);
GridwiseGemm
::
MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor
(
b_e0_e1_n_ho_wo_e2_grid_desc
);
const
auto
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
=
GridwiseGemm
::
MakeCK0K1NH0H1H2W0W1W2GridDescriptor
(
c_k_n_hop_wop_grid_desc
);
const
auto
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
=
const
auto
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
=
GridwiseGemm
::
MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor
(
d_k_n_hopx2_wopx2_grid_desc
);
GridwiseGemm
::
MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor
(
d_k_n_hopx2_wopx2_grid_desc
);
using
AGridDesc_E0_E1_K0_K1_E2
=
decltype
(
a_e0_e1_k0_k1_e2_grid_desc
);
using
AGridDesc_E0_E1_K0_K1_E2
=
decltype
(
a_e0_e1_k0_k1_e2_grid_desc
);
using
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
=
using
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
=
decltype
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
);
decltype
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
);
using
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
=
decltype
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
);
using
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
=
using
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
=
decltype
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
);
decltype
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
);
...
@@ -355,6 +380,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
...
@@ -355,6 +380,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
FloatC
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K0_K1_E2
>
,
remove_reference_t
<
AGridDesc_E0_E1_K0_K1_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
>
,
remove_reference_t
<
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
>
,
remove_reference_t
<
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
>
,
remove_reference_t
<
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_H_W
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_H_W
>
,
true
>
;
true
>
;
...
@@ -367,9 +393,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
...
@@ -367,9 +393,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_bias_grid
,
p_bias_grid
,
p_c_grid
,
p_d_grid
,
p_d_grid
,
a_e0_e1_k0_k1_e2_grid_desc
,
a_e0_e1_k0_k1_e2_grid_desc
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
c_blockid_to_k_n_h_w_block_cluster_adaptor
);
c_blockid_to_k_n_h_w_block_cluster_adaptor
);
}
}
...
@@ -381,6 +409,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
...
@@ -381,6 +409,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
FloatC
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K0_K1_E2
>
,
remove_reference_t
<
AGridDesc_E0_E1_K0_K1_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
>
,
remove_reference_t
<
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
>
,
remove_reference_t
<
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
>
,
remove_reference_t
<
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_H_W
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_H_W
>
,
false
>
;
false
>
;
...
@@ -393,9 +422,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
...
@@ -393,9 +422,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_bias_grid
,
p_bias_grid
,
p_c_grid
,
p_d_grid
,
p_d_grid
,
a_e0_e1_k0_k1_e2_grid_desc
,
a_e0_e1_k0_k1_e2_grid_desc
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
c_blockid_to_k_n_h_w_block_cluster_adaptor
);
c_blockid_to_k_n_h_w_block_cluster_adaptor
);
}
}
...
@@ -404,6 +435,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
...
@@ -404,6 +435,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
DeviceMem
a_e0_e1_k0_k1_e2_grid_desc_dev_buf
(
sizeof
(
AGridDesc_E0_E1_K0_K1_E2
));
DeviceMem
a_e0_e1_k0_k1_e2_grid_desc_dev_buf
(
sizeof
(
AGridDesc_E0_E1_K0_K1_E2
));
DeviceMem
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf
(
DeviceMem
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf
(
sizeof
(
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
));
sizeof
(
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
));
DeviceMem
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf
(
sizeof
(
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
));
DeviceMem
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf
(
DeviceMem
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf
(
sizeof
(
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
));
sizeof
(
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
));
DeviceMem
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf
(
DeviceMem
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf
(
...
@@ -412,6 +445,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
...
@@ -412,6 +445,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
a_e0_e1_k0_k1_e2_grid_desc_dev_buf
.
ToDevice
(
&
a_e0_e1_k0_k1_e2_grid_desc
);
a_e0_e1_k0_k1_e2_grid_desc_dev_buf
.
ToDevice
(
&
a_e0_e1_k0_k1_e2_grid_desc
);
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf
.
ToDevice
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf
.
ToDevice
(
&
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
);
&
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
);
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf
.
ToDevice
(
&
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
);
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf
.
ToDevice
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf
.
ToDevice
(
&
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
);
&
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
);
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf
.
ToDevice
(
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf
.
ToDevice
(
...
@@ -426,6 +461,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
...
@@ -426,6 +461,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
FloatC
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K0_K1_E2
>
,
remove_reference_t
<
AGridDesc_E0_E1_K0_K1_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
>
,
remove_reference_t
<
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
>
,
remove_reference_t
<
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
>
,
remove_reference_t
<
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_H_W
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_H_W
>
,
true
>
;
true
>
;
...
@@ -439,11 +475,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
...
@@ -439,11 +475,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_bias_grid
,
p_bias_grid
,
p_c_grid
,
p_d_grid
,
p_d_grid
,
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
a_e0_e1_k0_k1_e2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
a_e0_e1_k0_k1_e2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
...
@@ -458,6 +497,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
...
@@ -458,6 +497,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
FloatC
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K0_K1_E2
>
,
remove_reference_t
<
AGridDesc_E0_E1_K0_K1_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
>
,
remove_reference_t
<
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
>
,
remove_reference_t
<
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
>
,
remove_reference_t
<
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_H_W
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_H_W
>
,
false
>
;
false
>
;
...
@@ -471,11 +511,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
...
@@ -471,11 +511,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_bias_grid
,
p_bias_grid
,
p_c_grid
,
p_d_grid
,
p_d_grid
,
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
a_e0_e1_k0_k1_e2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
a_e0_e1_k0_k1_e2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
...
...
host/driver_offline/src/conv_add_fwd_driver_offline_nchwc.cpp
View file @
35a57947
...
@@ -308,6 +308,7 @@ int main(int argc, char* argv[])
...
@@ -308,6 +308,7 @@ int main(int argc, char* argv[])
in
,
in
,
wei
,
wei
,
bias
,
bias
,
out_device
,
add
,
add
,
add_device
,
add_device
,
nrepeat
);
nrepeat
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment