Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0f276ac2
Commit
0f276ac2
authored
Oct 14, 2021
by
Jing Zhang
Browse files
add configurable makeddesc
parent
35a57947
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
124 additions
and
78 deletions
+124
-78
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2_add.hpp
...l/include/tensor_operation/gridwise_gemm_dlops_v2_add.hpp
+123
-77
host/driver_offline/include/driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
...ward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
+1
-1
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2_add.hpp
View file @
0f276ac2
...
...
@@ -18,7 +18,7 @@ template <typename GridwiseGemm,
typename
AGridDesc_E0_E1_K0_K1_E2
,
typename
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
,
typename
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
,
typename
DGridDesc_K0_K1_N_H0_H1_H
2x2
_W0_W1_W
2x2
,
typename
DGridDesc_K0_K1_N_H0_H1_H
x
_W0_W1_W
x
,
typename
CBlockIdToBlockClusterAdaptor_K_N_H_W
,
bool
HasMainE0BlockLoop
>
__global__
void
...
...
@@ -34,7 +34,7 @@ __global__ void
const
AGridDesc_E0_E1_K0_K1_E2
a_e0_e1_k0_k1_e2_grid_desc
,
const
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
const
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
const
DGridDesc_K0_K1_N_H0_H1_H
2x2
_W0_W1_W
2x2
d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_grid_desc
,
const
DGridDesc_K0_K1_N_H0_H1_H
x
_W0_W1_W
x
d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_grid_desc
,
const
CBlockIdToBlockClusterAdaptor_K_N_H_W
c_blockid_to_k_n_h_w_block_cluster_adaptor
)
{
constexpr
index_t
shared_block_size
=
...
...
@@ -51,7 +51,7 @@ __global__ void
a_e0_e1_k0_k1_e2_grid_desc
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_grid_desc
,
d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_grid_desc
,
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
integral_constant
<
bool
,
HasMainE0BlockLoop
>
{});
}
...
...
@@ -65,7 +65,7 @@ template <typename GridwiseGemm,
typename
AGridDesc_E0_E1_K0_K1_E2
,
typename
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
,
typename
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
,
typename
DGridDesc_K0_K1_N_H0_H1_H
2x2
_W0_W1_W
2x2
,
typename
DGridDesc_K0_K1_N_H0_H1_H
x
_W0_W1_W
x
,
typename
CBlockIdToBlockClusterAdaptor_K_N_H_W
,
bool
HasMainE0BlockLoop
>
__global__
void
...
...
@@ -80,7 +80,7 @@ __global__ void
const
void
CONSTANT
*
p_a_e0_e1_k0_k1_e2_grid_desc
,
const
void
CONSTANT
*
p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
const
void
CONSTANT
*
p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
const
void
CONSTANT
*
p_d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_grid_desc
,
const
void
CONSTANT
*
p_d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_grid_desc
,
const
void
CONSTANT
*
p_c_blockid_to_k_n_h_w_block_cluster_adaptor
)
{
// first cast void CONSTANT void* to void*
...
...
@@ -94,9 +94,9 @@ __global__ void
const
auto
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
=
*
reinterpret_cast
<
const
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
*>
(
cast_pointer_to_generic_address_space
(
p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
));
const
auto
d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_grid_desc
=
*
reinterpret_cast
<
const
DGridDesc_K0_K1_N_H0_H1_H
2x2
_W0_W1_W
2x2
*>
(
cast_pointer_to_generic_address_space
(
p_d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_grid_desc
));
const
auto
d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_grid_desc
=
*
reinterpret_cast
<
const
DGridDesc_K0_K1_N_H0_H1_H
x
_W0_W1_W
x
*>
(
cast_pointer_to_generic_address_space
(
p_d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_grid_desc
));
const
auto
c_blockid_to_k_n_h_w_block_cluster_adaptor
=
*
reinterpret_cast
<
const
CBlockIdToBlockClusterAdaptor_K_N_H_W
*>
(
cast_pointer_to_generic_address_space
(
p_c_blockid_to_k_n_h_w_block_cluster_adaptor
));
...
...
@@ -115,7 +115,7 @@ __global__ void
a_e0_e1_k0_k1_e2_grid_desc
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_grid_desc
,
d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_grid_desc
,
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
integral_constant
<
bool
,
HasMainE0BlockLoop
>
{});
}
...
...
@@ -129,7 +129,7 @@ template <index_t BlockSize,
typename
AGridDesc_E0_E1_K_E2
,
typename
BGridDesc_E0_E1_N_Ho_Wo_E2
,
typename
CGridDesc_K_N_Ho_Wo
,
typename
DGridDesc_K_N_H
ox2_Wox2
,
typename
DGridDesc_K_N_H
x_Wx
,
index_t
E1_
,
index_t
E2_
,
index_t
K2_
,
...
...
@@ -162,7 +162,8 @@ template <index_t BlockSize,
typename
DGlobalStepHacks
,
typename
AGlobalMoveSliceWindowStepHacks
,
typename
BGlobalMoveSliceWindowStepHacks
,
index_t
activ_type
=
0
>
index_t
activ_type
=
0
,
index_t
add_type
=
0
>
struct
GridwiseGemmDlops_km_kn_mn_v3_add
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -327,27 +328,58 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
return
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
;
}
__host__
__device__
static
constexpr
auto
MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor
(
const
DGridDesc_K_N_H
ox2_Wox2
&
d_k_n_h
ox2_wox2
_grid_desc
)
__host__
__device__
static
constexpr
auto
MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool
(
const
DGridDesc_K_N_H
x_Wx
&
d_k_n_h
x_wx
_grid_desc
)
{
const
auto
K
=
d_k_n_hox2_wox2_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
d_k_n_hox2_wox2_grid_desc
.
GetLength
(
I1
);
const
auto
Hox2
=
d_k_n_hox2_wox2_grid_desc
.
GetLength
(
I2
);
const
auto
Wox2
=
d_k_n_hox2_wox2_grid_desc
.
GetLength
(
I3
);
const
auto
K
=
d_k_n_hx_wx_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
d_k_n_hx_wx_grid_desc
.
GetLength
(
I1
);
const
auto
Hx
=
d_k_n_hx_wx_grid_desc
.
GetLength
(
I2
);
const
auto
Wx
=
d_k_n_hx_wx_grid_desc
.
GetLength
(
I3
);
const
auto
K1
=
Number
<
KPerBlock
>
{};
const
auto
K0
=
K
/
K1
;
const
auto
H2
=
HoPerThread
/
2
;
const
auto
H1
=
Number
<
HoPerBlock
/
HoPerThread
>
{};
const
auto
H0
=
Hx
/
(
H1
*
H2
);
const
auto
W2
=
WoPerThread
/
2
;
const
auto
W1
=
Number
<
WoPerBlock
/
WoPerThread
>
{};
const
auto
W0
=
Wx
/
(
W1
*
W2
);
const
auto
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc
=
transform_tensor_descriptor
(
d_k_n_hx_wx_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
N
),
make_unmerge_transform
(
make_tuple
(
H0
,
H1
,
H2
)),
make_unmerge_transform
(
make_tuple
(
W0
,
W1
,
W2
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
,
4
,
5
>
{},
Sequence
<
6
,
7
,
8
>
{}));
return
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc
;
}
__host__
__device__
static
constexpr
auto
MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd
(
const
DGridDesc_K_N_Hx_Wx
&
d_k_n_hx_wx_grid_desc
)
{
const
auto
K
=
d_k_n_hx_wx_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
d_k_n_hx_wx_grid_desc
.
GetLength
(
I1
);
const
auto
Hx
=
d_k_n_hx_wx_grid_desc
.
GetLength
(
I2
);
const
auto
Wx
=
d_k_n_hx_wx_grid_desc
.
GetLength
(
I3
);
const
auto
K1
=
Number
<
KPerBlock
>
{};
const
auto
K0
=
K
/
K1
;
const
auto
H2
=
HoPerThread
*
2
;
const
auto
H1
=
Number
<
HoPerBlock
/
HoPerThread
>
{};
const
auto
H0
=
H
ox2
/
(
H1
*
H2
);
const
auto
H0
=
H
x
/
(
H1
*
H2
);
const
auto
W2
=
WoPerThread
*
2
;
const
auto
W1
=
Number
<
WoPerBlock
/
WoPerThread
>
{};
const
auto
W0
=
W
ox2
/
(
W1
*
W2
);
const
auto
W0
=
W
x
/
(
W1
*
W2
);
const
auto
d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_grid_desc
=
transform_tensor_descriptor
(
d_k_n_h
ox2_wox2
_grid_desc
,
const
auto
d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_grid_desc
=
transform_tensor_descriptor
(
d_k_n_h
x_wx
_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
N
),
make_unmerge_transform
(
make_tuple
(
H0
,
H1
,
H2
)),
...
...
@@ -355,7 +387,24 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
,
4
,
5
>
{},
Sequence
<
6
,
7
,
8
>
{}));
return
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
;
return
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc
;
}
__host__
__device__
static
constexpr
auto
MakeDK0K1NH0H1HxW0W1WxGridDescriptor
(
const
DGridDesc_K_N_Hx_Wx
&
d_k_n_hx_wx_grid_desc
)
{
if
constexpr
(
add_type
==
0
)
{
return
MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd
(
d_k_n_hx_wx_grid_desc
);
}
else
if
constexpr
(
add_type
==
1
)
{
return
MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool
(
d_k_n_hx_wx_grid_desc
);
}
else
{
return
MakeCK0K1NH0H1H2W0W1W2GridDescriptor
(
d_k_n_hx_wx_grid_desc
);
}
}
__host__
__device__
static
constexpr
auto
...
...
@@ -385,17 +434,17 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
decltype
(
MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor
(
BGridDesc_E0_E1_N_Ho_Wo_E2
{}));
using
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
=
decltype
(
MakeCK0K1NH0H1H2W0W1W2GridDescriptor
(
CGridDesc_K_N_Ho_Wo
{}));
using
DGridDesc_K0_K1_N_H0_H1_H
2x2
_W0_W1_W
2x2
=
decltype
(
MakeDK0K1NH0H1H
2x2
W0W1W
2x2
GridDescriptor
(
DGridDesc_K_N_H
ox2_Wox2
{}));
using
DGridDesc_K0_K1_N_H0_H1_H
x
_W0_W1_W
x
=
decltype
(
MakeDK0K1NH0H1H
x
W0W1W
x
GridDescriptor
(
DGridDesc_K_N_H
x_Wx
{}));
using
CBlockIdToBlockClusterAdaptor_K_N_H_W
=
decltype
(
MakeCBlockIdToKNHoWoBlockClusterAdaptor
(
CGridDesc_K_N_Ho_Wo
{}));
__host__
__device__
static
constexpr
auto
MakeBiasK0K1GridDescriptor
(
const
DGridDesc_K0_K1_N_H0_H1_H
2x2
_W0_W1_W
2x2
&
d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_grid_desc
)
const
DGridDesc_K0_K1_N_H0_H1_H
x
_W0_W1_W
x
&
d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_grid_desc
)
{
const
auto
K0
=
d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_grid_desc
.
GetLength
(
I0
);
const
auto
K1
=
d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_grid_desc
.
GetLength
(
I1
);
const
auto
K0
=
d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_grid_desc
.
GetLength
(
I0
);
const
auto
K1
=
d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_grid_desc
.
GetLength
(
I1
);
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
K0
,
K1
));
}
...
...
@@ -411,7 +460,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
const
AGridDesc_E0_E1_K0_K1_E2
&
a_e0_e1_k0_k1_e2_grid_desc
,
const
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
&
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
const
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
&
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
const
DGridDesc_K0_K1_N_H0_H1_H
2x2
_W0_W1_W
2x2
&
d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_grid_desc
,
const
DGridDesc_K0_K1_N_H0_H1_H
x
_W0_W1_W
x
&
d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_grid_desc
,
const
CBlockIdToBlockClusterAdaptor_K_N_H_W
&
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
integral_constant
<
bool
,
HasMainE0BlockLoop
>
)
{
...
...
@@ -419,13 +468,13 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
// constexpr auto a_e0_e1_k0_k1_e2_grid_desc = AGridDesc_E0_E1_K0_K1_E2{};
// constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
// BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2{};
// constexpr auto d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_grid_desc =
// DGridDesc_K0_K1_N_H0_H1_H
2x2
_W0_W1_W
2x2
{};
// constexpr auto d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_grid_desc =
// DGridDesc_K0_K1_N_H0_H1_H
x
_W0_W1_W
x
{};
// constexpr auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
// CBlockIdToBlockClusterAdaptor_K_N_H_W{};
const
auto
bias_k0_k1_grid_desc
=
MakeBiasK0K1GridDescriptor
(
d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_grid_desc
);
MakeBiasK0K1GridDescriptor
(
d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_grid_desc
);
const
auto
a_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_a_global
,
a_e0_e1_k0_k1_e2_grid_desc
.
GetElementSpaceSize
());
...
...
@@ -434,7 +483,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
auto
c_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c_global
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
.
GetElementSpaceSize
());
auto
d_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_d_global
,
d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_grid_desc
.
GetElementSpaceSize
());
p_d_global
,
d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_grid_desc
.
GetElementSpaceSize
());
auto
bias_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_bias_global
,
bias_k0_k1_grid_desc
.
GetElementSpaceSize
());
...
...
@@ -933,7 +982,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
constexpr
auto
WoPerThreadx2
=
WoPerThread
*
2
;
#if 1
constexpr
auto
d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_thread_desc
=
constexpr
auto
d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_thread_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
KPerThread
>
{},
I1
,
...
...
@@ -946,16 +995,15 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatC
,
d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_thread_desc
.
GetElementSpaceSize
(),
d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_thread_desc
.
GetElementSpaceSize
(),
true
>
d_thread_buf
;
static_for
<
0
,
KPerThread
,
1
>
{}([
&
](
auto
k_i
)
{
static_for
<
0
,
HoPerThreadx2
,
1
>
{}([
&
](
auto
h_i
)
{
static_for
<
0
,
WoPerThreadx2
,
1
>
{}([
&
](
auto
w_i
)
{
d_thread_buf
(
Number
<
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
k_i
,
0
,
0
,
0
,
h_i
,
0
,
0
,
w_i
))
>
{})
=
d_thread_buf
(
Number
<
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
k_i
,
0
,
0
,
0
,
h_i
,
0
,
0
,
w_i
))
>
{})
=
c_thread_buf
[
Number
<
c_k1_n_h2_w2_thread_gemm_desc
.
CalculateOffset
(
make_tuple
(
k_i
,
0
,
h_i
/
2
,
w_i
/
2
))
>
{}];
});
...
...
@@ -974,58 +1022,56 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
I1
,
Number
<
WoPerThread
>
{}));
constexpr
auto
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc
=
transform_tensor_descriptor
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_desc
,
make_tuple
(
make_pass_through_transform
(
I1
),
make_pass_through_transform
(
Number
<
KPerThread
>
{}),
make_pass_through_transform
(
I1
),
make_pass_through_transform
(
I1
),
make_pass_through_transform
(
I1
),
make_embed_transform
(
make_tuple
(
I2
,
Number
<
HoPerThread
>
{}),
make_tuple
(
I0
,
I1
)),
make_pass_through_transform
(
I1
),
make_pass_through_transform
(
I1
),
make_embed_transform
(
make_tuple
(
I2
,
Number
<
WoPerThread
>
{}),
make_tuple
(
I0
,
I1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{},
Sequence
<
7
>
{},
Sequence
<
8
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{},
Sequence
<
8
>
{},
Sequence
<
9
,
10
>
{}));
constexpr
auto
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc
=
transform_tensor_descriptor
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_desc
,
make_tuple
(
make_pass_through_transform
(
I1
),
make_pass_through_transform
(
Number
<
KPerThread
>
{}),
make_pass_through_transform
(
I1
),
make_pass_through_transform
(
I1
),
make_pass_through_transform
(
I1
),
make_embed_transform
(
make_tuple
(
I2
,
Number
<
HoPerThread
>
{}),
make_tuple
(
I0
,
I1
)),
make_pass_through_transform
(
I1
),
make_pass_through_transform
(
I1
),
make_embed_transform
(
make_tuple
(
I2
,
Number
<
WoPerThread
>
{}),
make_tuple
(
I0
,
I1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{},
Sequence
<
7
>
{},
Sequence
<
8
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{},
Sequence
<
8
>
{},
Sequence
<
9
,
10
>
{}));
#endif
// hack to control index calculation when iterating over d_k_n_ho_wo_global tensor
constexpr
auto
d_k_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks
=
DGlobalStepHacks
{};
constexpr
auto
d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks
=
DGlobalStepHacks
{};
const
index_t
k_thread_data_on_global
=
k_thread_id
*
KPerThread
;
ThreadwiseTensorSliceTransfer_v1r3
<
FloatC
,
FloatC
,
decltype
(
d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_thread_desc
),
decltype
(
d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_grid_desc
),
decltype
(
d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_thread_desc
),
decltype
(
d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_grid_desc
),
Sequence
<
I1
,
KPerThread
,
I1
,
I1
,
I1
,
HoPerThreadx2
,
I1
,
I1
,
WoPerThreadx2
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CGlobalMemoryDataOperation
,
1
,
true
>
(
d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_grid_desc
,
true
>
(
d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_grid_desc
,
make_multi_index
(
k_block_work_id
,
k_thread_data_on_global
,
n_block_work_id
,
...
...
@@ -1035,12 +1081,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
wo_block_work_id
,
wo_thread_id
,
0
))
.
Run
(
d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_thread_desc
,
.
Run
(
d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
d_thread_buf
,
d_k0_k1_n_h0_h1_h
2x2
_w0_w1_w
2x2
_grid_desc
,
d_k0_k1_n_h0_h1_h
x
_w0_w1_w
x
_grid_desc
,
d_global_buf
,
d_k_n_h0_h1_h
2x2
_w0_w1_w
2x2
_global_tensor_step_hacks
);
d_k_n_h0_h1_h
x
_w0_w1_w
x
_global_tensor_step_hacks
);
}
}
};
...
...
host/driver_offline/include/driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
View file @
0f276ac2
...
...
@@ -347,7 +347,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
const
auto
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
=
GridwiseGemm
::
MakeCK0K1NH0H1H2W0W1W2GridDescriptor
(
c_k_n_hop_wop_grid_desc
);
const
auto
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
=
GridwiseGemm
::
MakeDK0K1NH0H1H
2x2
W0W1W
2x2
GridDescriptor
(
d_k_n_hopx2_wopx2_grid_desc
);
GridwiseGemm
::
MakeDK0K1NH0H1H
x
W0W1W
x
GridDescriptor
(
d_k_n_hopx2_wopx2_grid_desc
);
using
AGridDesc_E0_E1_K0_K1_E2
=
decltype
(
a_e0_e1_k0_k1_e2_grid_desc
);
using
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
=
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment