Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1b79fce9
Commit
1b79fce9
authored
Oct 29, 2021
by
Jing Zhang
Browse files
create seperate fusion fun
parent
8e897da7
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
319 additions
and
139 deletions
+319
-139
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v3.hpp
...ernel/include/tensor_operation/gridwise_gemm_dlops_v3.hpp
+311
-117
host/driver_offline/include/driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
...ward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
+3
-7
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
...ward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
+1
-6
host/driver_offline/include/driver_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
...ward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
+2
-7
host/driver_offline/src/conv_fwd_driver_offline_nchwc.cpp
host/driver_offline/src/conv_fwd_driver_offline_nchwc.cpp
+2
-2
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v3.hpp
View file @
1b79fce9
...
...
@@ -152,10 +152,9 @@ __global__ void
constexpr
auto
c_blockid_to_k_n_h_w_block_cluster_adaptor
=
CBlockIdToBlockClusterAdaptor_K_N_H_W
{};
GridwiseGemm
::
Run
(
p_a_grid
,
GridwiseGemm
::
ConvBiasActivResizeAdd
Run
(
p_a_grid
,
p_b_grid
,
p_bias_grid
,
nullptr
,
p_d_grid
,
p_shared_block
,
a_e0_e1_k0_k1_e2_grid_desc
,
...
...
@@ -198,7 +197,7 @@ __global__ void
constexpr
auto
c_blockid_to_k_n_h_w_block_cluster_adaptor
=
CBlockIdToBlockClusterAdaptor_K_N_H_W
{};
GridwiseGemm
::
Run
(
p_a_grid
,
GridwiseGemm
::
ConvBiasActivMaxpool
Run
(
p_a_grid
,
p_b_grid
,
p_bias_grid
,
p_c_grid
,
...
...
@@ -241,16 +240,14 @@ __global__ void
constexpr
auto
c_blockid_to_k_n_h_w_block_cluster_adaptor
=
CBlockIdToBlockClusterAdaptor_K_N_H_W
{};
GridwiseGemm
::
Run
(
p_a_grid
,
GridwiseGemm
::
ConvBiasActiv
(
p_a_grid
,
p_b_grid
,
p_bias_grid
,
p_c_grid
,
nullptr
,
p_shared_block
,
a_e0_e1_k0_k1_e2_grid_desc
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
integral_constant
<
bool
,
HasMainE0BlockLoop
>
{});
}
...
...
@@ -296,13 +293,10 @@ template <index_t BlockSize,
typename
CGlobalStepHacks
,
typename
DGlobalStepHacks
,
typename
AGlobalMoveSliceWindowStepHacks
,
typename
BGlobalMoveSliceWindowStepHacks
,
index_t
activ_type
=
0
,
index_t
bias_type
=
0
,
index_t
out_type
=
1
,
index_t
add_type
=
0
>
typename
BGlobalMoveSliceWindowStepHacks
>
struct
GridwiseGemmDlops_km_kn_mn_v3
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
...
...
@@ -318,6 +312,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
static
constexpr
FloatAcc
alpha
=
0.3
;
static
constexpr
auto
activ_type
=
I1
;
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
constexpr
auto
max_lds_align
=
Number
<
ABlockTransferDstScalarPerVector_E2
>
{};
...
...
@@ -539,23 +535,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3
return
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc
;
}
__host__
__device__
static
constexpr
auto
MakeDK0K1NH0H1HxW0W1WxGridDescriptor
(
const
DGridDesc_K_N_Hx_Wx
&
d_k_n_hx_wx_grid_desc
)
{
if
constexpr
(
add_type
==
1
)
{
return
MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd
(
d_k_n_hx_wx_grid_desc
);
}
else
if
constexpr
(
add_type
==
2
)
{
return
MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool
(
d_k_n_hx_wx_grid_desc
);
}
else
{
return
MakeCK0K1NH0H1H2W0W1W2GridDescriptor
(
d_k_n_hx_wx_grid_desc
);
}
}
__host__
__device__
static
constexpr
auto
MakeCBlockIdToKNHoWoBlockClusterAdaptor
(
const
CGridDesc_K_N_Ho_Wo
&
c_k_n_ho_wo_grid_desc
)
{
...
...
@@ -584,18 +563,19 @@ struct GridwiseGemmDlops_km_kn_mn_v3
return
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
;
}
using
AGridDesc_E0_E1_K0_K1_E2
=
decltype
(
MakeAE0E1K0K1E2GridDescriptor
(
AGridDesc_E0_E1_K_E2
{}));
using
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
=
decltype
(
MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor
(
BGridDesc_E0_E1_N_Ho_Wo_E2
{}));
using
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
=
decltype
(
MakeCK0K1NH0H1H2W0W1W2GridDescriptor
(
CGridDesc_K_N_Ho_Wo
{}));
using
DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx
=
decltype
(
MakeDK0K1NH0H1HxW0W1WxGridDescriptor
(
DGridDesc_K_N_Hx_Wx
{}));
//
using AGridDesc_E0_E1_K0_K1_E2 =
//
decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{}));
//
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
//
decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{}));
//
using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 =
//
decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{}));
//
using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx =
//
decltype(MakeDK0K1NH0H1HxW0W1WxGridDescriptor(DGridDesc_K_N_Hx_Wx{}));
using
CBlockIdToBlockClusterAdaptor_K_N_H_W
=
decltype
(
MakeCBlockIdToKNHoWoBlockClusterAdaptor
(
CGridDesc_K_N_Ho_Wo
{}));
template
<
typename
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
>
__host__
__device__
static
constexpr
auto
MakeBiasK0K1GridDescriptor
(
const
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
&
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
)
{
...
...
@@ -728,18 +708,18 @@ struct GridwiseGemmDlops_km_kn_mn_v3
});
}
template
<
typename
CThreadBuff
,
typename
CThreadDesc_K1_N_H2_W2
>
__device__
static
void
Activation
(
CThreadBuff
&
c_thread_buf
,
const
CThreadDesc_K1_N_H2_W2
&
)
template
<
typename
CThreadBuff
,
typename
CThreadDesc_K1_N_H2_W2
,
index_t
activ_type_
>
__device__
static
void
Activation
(
CThreadBuff
&
c_thread_buf
,
const
CThreadDesc_K1_N_H2_W2
&
,
Number
<
activ_type_
>
)
{
constexpr
auto
c_k1_n_h2_w2_thread_gemm_desc
=
CThreadDesc_K1_N_H2_W2
{};
static_for
<
0
,
c_k1_n_h2_w2_thread_gemm_desc
.
GetElementSpaceSize
(),
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
activ_type
==
1
)
if
constexpr
(
activ_type
_
==
1
)
{
c_thread_buf
(
i
)
=
c_thread_buf
[
i
]
>=
0
?
c_thread_buf
[
i
]
:
alpha
*
c_thread_buf
[
i
];
c_thread_buf
(
i
)
=
c_thread_buf
[
i
]
>=
0
?
c_thread_buf
[
i
]
:
alpha
*
c_thread_buf
[
i
];
}
else
if
constexpr
(
activ_type
==
2
)
else
if
constexpr
(
activ_type
_
==
2
)
{
FloatAcc
x
=
1.0
+
exp
(
-
c_thread_buf
[
i
]);
...
...
@@ -1024,6 +1004,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
typename
CThreadBuff
,
typename
CBlockIndex
,
typename
CThreadIndex
,
typename
AGridDesc_E0_E1_K0_K1_E2
,
typename
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
,
typename
CThreadDesc_K1_N_H2_W2
,
bool
HasMainE0BlockLoop
>
__device__
static
void
...
...
@@ -1394,9 +1376,14 @@ struct GridwiseGemmDlops_km_kn_mn_v3
}
}
template
<
bool
HasMainE0BlockLoop
>
template
<
typename
AGridDesc_E0_E1_K0_K1_E2
,
typename
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
,
typename
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
,
typename
DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx
,
typename
CBlockIdToBlockClusterAdaptor_K_N_H_W
,
bool
HasMainE0BlockLoop
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_global
,
Conv
(
const
FloatAB
*
__restrict__
p_a_global
,
const
FloatAB
*
__restrict__
p_b_global
,
const
FloatC
*
__restrict__
p_bias_global
,
FloatC
*
__restrict__
p_c_global
,
...
...
@@ -1437,6 +1424,69 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const
auto
c_thread_mtx_index
=
GetCThreadIndex
();
// GemmOp
GemmOp
(
a_global_buf
,
b_global_buf
,
c_thread_buf
,
p_shared_block
,
c_k_n_h_w_block_cluster_idx
,
c_thread_mtx_index
,
a_e0_e1_k0_k1_e2_grid_desc
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
c_k1_n_h2_w2_thread_gemm_desc
,
integral_constant
<
bool
,
HasMainE0BlockLoop
>
{});
// Output
WriteOut
(
c_thread_buf
,
c_global_buf
,
c_k_n_h_w_block_cluster_idx
,
c_thread_mtx_index
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
);
}
template
<
typename
AGridDesc_E0_E1_K0_K1_E2
,
typename
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
,
typename
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
,
typename
CBlockIdToBlockClusterAdaptor_K_N_H_W
,
bool
HasMainE0BlockLoop
>
__device__
static
void
ConvBiasActiv
(
const
FloatAB
*
__restrict__
p_a_global
,
const
FloatAB
*
__restrict__
p_b_global
,
const
FloatC
*
__restrict__
p_bias_global
,
FloatC
*
__restrict__
p_c_global
,
FloatAB
*
__restrict__
p_shared_block
,
const
AGridDesc_E0_E1_K0_K1_E2
&
a_e0_e1_k0_k1_e2_grid_desc
,
const
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
&
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
const
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
&
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
const
CBlockIdToBlockClusterAdaptor_K_N_H_W
&
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
integral_constant
<
bool
,
HasMainE0BlockLoop
>
)
{
const
auto
bias_k0_k1_grid_desc
=
MakeBiasK0K1GridDescriptor
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
);
const
auto
a_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_a_global
,
a_e0_e1_k0_k1_e2_grid_desc
.
GetElementSpaceSize
());
const
auto
b_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_b_global
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
.
GetElementSpaceSize
());
auto
c_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c_global
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
.
GetElementSpaceSize
());
auto
bias_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_bias_global
,
bias_k0_k1_grid_desc
.
GetElementSpaceSize
());
constexpr
auto
c_k1_n_h2_w2_thread_gemm_desc
=
MakeCK1NH2W2ThreadDescriptor
();
// register allocation for output
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatAcc
,
c_k1_n_h2_w2_thread_gemm_desc
.
GetElementSpaceSize
(),
true
>
c_thread_buf
;
const
auto
c_k_n_h_w_block_cluster_idx
=
GetCBlockIndex
(
c_blockid_to_k_n_h_w_block_cluster_adaptor
);
const
auto
c_thread_mtx_index
=
GetCThreadIndex
();
// GemmOp
GemmOp
(
a_global_buf
,
b_global_buf
,
...
...
@@ -1450,7 +1500,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3
integral_constant
<
bool
,
HasMainE0BlockLoop
>
{});
// Bias
if
constexpr
(
bias_type
>
0
)
BiasOp
(
bias_global_buf
,
c_thread_buf
,
c_k_n_h_w_block_cluster_idx
,
...
...
@@ -1459,26 +1508,94 @@ struct GridwiseGemmDlops_km_kn_mn_v3
c_k1_n_h2_w2_thread_gemm_desc
);
// Activ
if
constexpr
(
activ_type
>
0
)
Activation
(
c_thread_buf
,
c_k1_n_h2_w2_thread_gemm_desc
);
Activation
(
c_thread_buf
,
c_k1_n_h2_w2_thread_gemm_desc
,
activ_type
);
// Output
if
constexpr
(
out_type
>
0
)
WriteOut
(
c_thread_buf
,
c_global_buf
,
c_k_n_h_w_block_cluster_idx
,
c_thread_mtx_index
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
);
}
if
constexpr
(
add_type
==
1
)
// Resize_Add
ResizeAdd
(
c_thread_buf
,
d_global_buf
,
template
<
typename
AGridDesc_E0_E1_K0_K1_E2
,
typename
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
,
typename
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
,
typename
DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx
,
typename
CBlockIdToBlockClusterAdaptor_K_N_H_W
,
bool
HasMainE0BlockLoop
>
__device__
static
void
ConvBiasActivMaxpoolRun
(
const
FloatAB
*
__restrict__
p_a_global
,
const
FloatAB
*
__restrict__
p_b_global
,
const
FloatC
*
__restrict__
p_bias_global
,
FloatC
*
__restrict__
p_c_global
,
FloatC
*
__restrict__
p_d_global
,
FloatAB
*
__restrict__
p_shared_block
,
const
AGridDesc_E0_E1_K0_K1_E2
&
a_e0_e1_k0_k1_e2_grid_desc
,
const
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
&
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
const
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
&
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
const
DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx
&
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc
,
const
CBlockIdToBlockClusterAdaptor_K_N_H_W
&
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
integral_constant
<
bool
,
HasMainE0BlockLoop
>
)
{
const
auto
bias_k0_k1_grid_desc
=
MakeBiasK0K1GridDescriptor
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
);
const
auto
a_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_a_global
,
a_e0_e1_k0_k1_e2_grid_desc
.
GetElementSpaceSize
());
const
auto
b_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_b_global
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
.
GetElementSpaceSize
());
auto
c_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c_global
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
.
GetElementSpaceSize
());
auto
d_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_d_global
,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc
.
GetElementSpaceSize
());
auto
bias_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_bias_global
,
bias_k0_k1_grid_desc
.
GetElementSpaceSize
());
constexpr
auto
c_k1_n_h2_w2_thread_gemm_desc
=
MakeCK1NH2W2ThreadDescriptor
();
// register allocation for output
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatAcc
,
c_k1_n_h2_w2_thread_gemm_desc
.
GetElementSpaceSize
(),
true
>
c_thread_buf
;
const
auto
c_k_n_h_w_block_cluster_idx
=
GetCBlockIndex
(
c_blockid_to_k_n_h_w_block_cluster_adaptor
);
const
auto
c_thread_mtx_index
=
GetCThreadIndex
();
// GemmOp
GemmOp
(
a_global_buf
,
b_global_buf
,
c_thread_buf
,
p_shared_block
,
c_k_n_h_w_block_cluster_idx
,
c_thread_mtx_index
,
a_e0_e1_k0_k1_e2_grid_desc
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
c_k1_n_h2_w2_thread_gemm_desc
,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc
);
else
if
constexpr
(
add_type
==
2
)
integral_constant
<
bool
,
HasMainE0BlockLoop
>
{});
// Bias
BiasOp
(
bias_global_buf
,
c_thread_buf
,
c_k_n_h_w_block_cluster_idx
,
c_thread_mtx_index
,
bias_k0_k1_grid_desc
,
c_k1_n_h2_w2_thread_gemm_desc
);
// Activ
Activation
(
c_thread_buf
,
c_k1_n_h2_w2_thread_gemm_desc
,
activ_type
);
// Output
WriteOut
(
c_thread_buf
,
c_global_buf
,
c_k_n_h_w_block_cluster_idx
,
c_thread_mtx_index
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
);
// MaxPool
MaxPool
(
c_thread_buf
,
d_global_buf
,
...
...
@@ -1487,6 +1604,83 @@ struct GridwiseGemmDlops_km_kn_mn_v3
c_k1_n_h2_w2_thread_gemm_desc
,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc
);
}
template
<
typename
AGridDesc_E0_E1_K0_K1_E2
,
typename
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
,
typename
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
,
typename
DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx
,
typename
CBlockIdToBlockClusterAdaptor_K_N_H_W
,
bool
HasMainE0BlockLoop
>
__device__
static
void
ConvBiasActivResizeAddRun
(
const
FloatAB
*
__restrict__
p_a_global
,
const
FloatAB
*
__restrict__
p_b_global
,
const
FloatC
*
__restrict__
p_bias_global
,
FloatC
*
__restrict__
p_d_global
,
FloatAB
*
__restrict__
p_shared_block
,
const
AGridDesc_E0_E1_K0_K1_E2
&
a_e0_e1_k0_k1_e2_grid_desc
,
const
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
&
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
const
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
&
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
const
DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx
&
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc
,
const
CBlockIdToBlockClusterAdaptor_K_N_H_W
&
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
integral_constant
<
bool
,
HasMainE0BlockLoop
>
)
{
const
auto
bias_k0_k1_grid_desc
=
MakeBiasK0K1GridDescriptor
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
);
const
auto
a_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_a_global
,
a_e0_e1_k0_k1_e2_grid_desc
.
GetElementSpaceSize
());
const
auto
b_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_b_global
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
.
GetElementSpaceSize
());
auto
d_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_d_global
,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc
.
GetElementSpaceSize
());
auto
bias_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_bias_global
,
bias_k0_k1_grid_desc
.
GetElementSpaceSize
());
constexpr
auto
c_k1_n_h2_w2_thread_gemm_desc
=
MakeCK1NH2W2ThreadDescriptor
();
// register allocation for output
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatAcc
,
c_k1_n_h2_w2_thread_gemm_desc
.
GetElementSpaceSize
(),
true
>
c_thread_buf
;
const
auto
c_k_n_h_w_block_cluster_idx
=
GetCBlockIndex
(
c_blockid_to_k_n_h_w_block_cluster_adaptor
);
const
auto
c_thread_mtx_index
=
GetCThreadIndex
();
// GemmOp
GemmOp
(
a_global_buf
,
b_global_buf
,
c_thread_buf
,
p_shared_block
,
c_k_n_h_w_block_cluster_idx
,
c_thread_mtx_index
,
a_e0_e1_k0_k1_e2_grid_desc
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
c_k1_n_h2_w2_thread_gemm_desc
,
integral_constant
<
bool
,
HasMainE0BlockLoop
>
{});
// Bias
BiasOp
(
bias_global_buf
,
c_thread_buf
,
c_k_n_h_w_block_cluster_idx
,
c_thread_mtx_index
,
bias_k0_k1_grid_desc
,
c_k1_n_h2_w2_thread_gemm_desc
);
// Activ
Activation
(
c_thread_buf
,
c_k1_n_h2_w2_thread_gemm_desc
,
activ_type
);
// Resize_Add
ResizeAdd
(
c_thread_buf
,
d_global_buf
,
c_k_n_h_w_block_cluster_idx
,
c_thread_mtx_index
,
c_k1_n_h2_w2_thread_gemm_desc
,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc
);
}
};
}
// namespace ck
...
...
host/driver_offline/include/driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
View file @
1b79fce9
...
...
@@ -336,12 +336,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
decltype
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks
),
decltype
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks
),
decltype
(
a_e0_e1_k_e2_global_move_slice_window_step_hack
),
decltype
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack
),
activ_type
,
1
,
// bias_type
0
,
// out_type
1
// add_type
>
;
decltype
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack
)
>
;
const
auto
a_e0_e1_k0_k1_e2_grid_desc
=
GridwiseGemm
::
MakeAE0E1K0K1E2GridDescriptor
(
a_e0_e1_k_e2_grid_desc
);
...
...
@@ -350,7 +345,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
const
auto
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
=
GridwiseGemm
::
MakeCK0K1NH0H1H2W0W1W2GridDescriptor
(
c_k_n_hop_wop_grid_desc
);
const
auto
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
=
GridwiseGemm
::
MakeDK0K1NH0H1HxW0W1WxGridDescriptor
(
d_k_n_hopx2_wopx2_grid_desc
);
GridwiseGemm
::
MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd
(
d_k_n_hopx2_wopx2_grid_desc
);
using
AGridDesc_E0_E1_K0_K1_E2
=
decltype
(
a_e0_e1_k0_k1_e2_grid_desc
);
using
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
=
...
...
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
View file @
1b79fce9
...
...
@@ -301,12 +301,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
decltype
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks
),
decltype
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks
),
decltype
(
a_e0_e1_k_e2_global_move_slice_window_step_hack
),
decltype
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack
),
activ_type
,
1
,
// bias_type
1
,
// out_type
0
// add_type
>
;
decltype
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack
)
>
;
const
auto
a_e0_e1_k0_k1_e2_grid_desc
=
GridwiseGemm
::
MakeAE0E1K0K1E2GridDescriptor
(
a_e0_e1_k_e2_grid_desc
);
...
...
host/driver_offline/include/driver_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
View file @
1b79fce9
...
...
@@ -340,12 +340,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
decltype
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks
),
decltype
(
d_k0_k1_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks
),
decltype
(
a_e0_e1_k_e2_global_move_slice_window_step_hack
),
decltype
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack
),
activ_type
,
1
,
// bias_type
1
,
// out_type
2
// add_type
>
;
decltype
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack
)
>
;
const
auto
a_e0_e1_k0_k1_e2_grid_desc
=
GridwiseGemm
::
MakeAE0E1K0K1E2GridDescriptor
(
a_e0_e1_k_e2_grid_desc
);
...
...
@@ -354,7 +349,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
const
auto
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
=
GridwiseGemm
::
MakeCK0K1NH0H1H2W0W1W2GridDescriptor
(
c_k_n_hop_wop_grid_desc
);
const
auto
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc
=
GridwiseGemm
::
MakeDK0K1NH0H1HxW0W1WxGridDescriptor
(
d_k_n_hx_wx_grid_desc
);
GridwiseGemm
::
MakeDK0K1NH0H1HxW0W1WxGridDescriptor
MaxPool
(
d_k_n_hx_wx_grid_desc
);
using
AGridDesc_E0_E1_K0_K1_E2
=
decltype
(
a_e0_e1_k0_k1_e2_grid_desc
);
using
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
=
...
...
host/driver_offline/src/conv_fwd_driver_offline_nchwc.cpp
View file @
1b79fce9
...
...
@@ -92,8 +92,8 @@ int main(int argc, char* argv[])
const
bool
do_log
=
std
::
stoi
(
argv
[
4
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
5
]);
constexpr
ck
::
ActivTypeEnum_t
activ_type
=
ActivTypeEnum_t
::
Sigmoid
;
//
constexpr ck::ActivTypeEnum_t activ_type = ActivTypeEnum_t::LeakyRelu;
//
constexpr ck::ActivTypeEnum_t activ_type = ActivTypeEnum_t::Sigmoid;
constexpr
ck
::
ActivTypeEnum_t
activ_type
=
ActivTypeEnum_t
::
LeakyRelu
;
#if 0
constexpr auto N = Number<1>{};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment