Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
149296c0
Commit
149296c0
authored
Sep 13, 2021
by
ltqin
Browse files
add MakeCGM0N0M1N1M2M3M4N2GridDescriptor
parent
973978aa
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
181 additions
and
122 deletions
+181
-122
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk.hpp
...orward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk.hpp
+11
-10
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+16
-0
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+37
-0
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk.hpp
...forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk.hpp
+16
-16
host/driver_offline/include/driver_gemm_xdlops_v3r1.hpp
host/driver_offline/include/driver_gemm_xdlops_v3r1.hpp
+89
-84
host/driver_offline/src/conv_fwd_driver_offline.cpp
host/driver_offline/src/conv_fwd_driver_offline.cpp
+12
-12
No files found.
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk.hpp
View file @
149296c0
...
@@ -90,15 +90,16 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad(
...
@@ -90,15 +90,16 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad(
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{},
Sequence
<
6
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{},
Sequence
<
6
>
{}));
const
auto
in_gemmg_gemmk_gemmm_grid_desc
=
const
auto
in_gemmg_gemmk_gemmm_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
in_g_n_y_ho_x_wo_c_grid_desc
,
in_g_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
G
),
make_tuple
(
make_pass_through_transform
(
G
),
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
in_gemmg_gemmk0_gemmm_gemmk1_grid_desc
=
const
auto
in_gemmg_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmg_gemmk_gemmm_grid_desc
,
transform_tensor_descriptor
(
in_gemmg_gemmk_gemmm_grid_desc
,
...
@@ -112,7 +113,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad(
...
@@ -112,7 +113,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad(
const
auto
wei_gemmg_gemmk_gemmn_grid_desc
=
transform_tensor_descriptor
(
const
auto
wei_gemmg_gemmk_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
G
,
K
,
Y
*
X
*
C
)),
make_naive_tensor_descriptor_packed
(
make_tuple
(
G
,
K
,
Y
*
X
*
C
)),
make_tuple
(
make_pass_through_transform
(
G
),
make_tuple
(
make_pass_through_transform
(
G
),
make_pass_through_transform
(
K
),
make_pass_through_transform
(
K
),
make_pass_through_transform
(
Y
*
X
*
C
)),
make_pass_through_transform
(
Y
*
X
*
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
2
>
{},
Sequence
<
1
>
{}));
...
@@ -129,7 +130,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad(
...
@@ -129,7 +130,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad(
const
auto
out_gemmg_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_gemmg_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
G
,
K
)),
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
G
,
K
)),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_pass_through_transform
(
G
),
make_pass_through_transform
(
G
),
make_pass_through_transform
(
K
)),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{},
Sequence
<
2
>
{}));
...
...
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
149296c0
...
@@ -158,6 +158,22 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -158,6 +158,22 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return
xdlops_gemm
.
MakeCM0N0M1N1M2M3M4N2Descriptor
(
c_m0_n0_m1_n1_m2_n2_grid_desc
);
return
xdlops_gemm
.
MakeCM0N0M1N1M2M3M4N2Descriptor
(
c_m0_n0_m1_n1_m2_n2_grid_desc
);
}
}
template
<
typename
CGMNGridDesc
>
__host__
__device__
static
constexpr
auto
MakeCGM0N0M1N1M2M3M4N2GridDescriptor
(
const
CGMNGridDesc
&
c_g_m_n_grid_desc
)
{
const
auto
G
=
c_g_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
c_g_m0_n0_m1_n1_m2_n2_grid_desc
=
transform_tensor_descriptor
(
c_g_m_n_grid_desc
,
make_tuple
(
make_pass_through_transform
(
G
),
make_unmerge_transform
(
make_tuple
(
MRepeat
,
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
NRepeat
,
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
return
xdlops_gemm
.
MakeCGM0N0M1N1M2M3M4N2Descriptor
(
c_g_m0_n0_m1_n1_m2_n2_grid_desc
);
}
__host__
__device__
static
constexpr
auto
MakeAK0M0M1M2K1BlockDescriptor
()
__host__
__device__
static
constexpr
auto
MakeAK0M0M1M2K1BlockDescriptor
()
{
{
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
149296c0
...
@@ -727,6 +727,43 @@ struct XdlopsGemm
...
@@ -727,6 +727,43 @@ struct XdlopsGemm
Sequence
<
7
>
{}));
Sequence
<
7
>
{}));
}
}
template
<
typename
CGM0N0M1N1M2N2Desc
>
__host__
__device__
static
constexpr
auto
MakeCGM0N0M1N1M2M3M4N2Descriptor
(
const
CGM0N0M1N1M2N2Desc
&
c_g_m0_n0_m1_n1_m2_n2_desc
)
{
const
auto
G
=
c_g_m0_n0_m1_n1_m2_n2_desc
.
GetLength
(
I0
);
const
auto
M0
=
c_g_m0_n0_m1_n1_m2_n2_desc
.
GetLength
(
I1
);
const
auto
N0
=
c_g_m0_n0_m1_n1_m2_n2_desc
.
GetLength
(
I2
);
const
auto
M1
=
c_g_m0_n0_m1_n1_m2_n2_desc
.
GetLength
(
I3
);
const
auto
N1
=
c_g_m0_n0_m1_n1_m2_n2_desc
.
GetLength
(
I4
);
return
transform_tensor_descriptor
(
c_g_m0_n0_m1_n1_m2_n2_desc
,
make_tuple
(
make_pass_through_transform
(
G
),
make_pass_through_transform
(
M0
),
make_pass_through_transform
(
N0
),
make_pass_through_transform
(
M1
),
make_pass_through_transform
(
N1
),
make_unmerge_transform
(
make_tuple
(
mfma_instr
.
num_groups_per_blk
,
mfma_instr
.
num_input_blks
,
mfma_instr
.
group_size
)),
make_pass_through_transform
(
mfma_instr
.
num_threads_per_blk
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
,
6
,
7
>
{},
Sequence
<
8
>
{}));
}
__device__
static
constexpr
index_t
GetRegSizePerXdlops
()
__device__
static
constexpr
index_t
GetRegSizePerXdlops
()
{
{
return
MPerXdlops
*
NPerXdlops
/
mfma_instr
.
wave_size
;
return
MPerXdlops
*
NPerXdlops
/
mfma_instr
.
wave_size
;
...
...
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk.hpp
View file @
149296c0
...
@@ -221,28 +221,28 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk(
...
@@ -221,28 +221,28 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk(
const
auto
descs
=
const
auto
descs
=
transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad
(
in_n_hi_wi_g_c_desc
,
transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad
(
in_n_hi_wi_g_c_desc
,
wei_g_k_y_x_c_desc
,
wei_g_k_y_x_c_desc
,
out_n_ho_wo_g_k_desc
,
out_n_ho_wo_g_k_desc
,
conv_strides
,
conv_strides
,
conv_dilations
,
conv_dilations
,
in_left_pads
,
in_left_pads
,
in_right_pads
,
in_right_pads
,
Number
<
GemmK1
>
{});
Number
<
GemmK1
>
{});
const
auto
in_gemmg_gemmk0_gemmm_gemmk1_grid_desc
=
descs
[
I0
];
const
auto
in_gemmg_gemmk0_gemmm_gemmk1_grid_desc
=
descs
[
I0
];
const
auto
wei_gemmg_gemmk0_gemmn_gemmk1_grid_desc
=
descs
[
I1
];
const
auto
wei_gemmg_gemmk0_gemmn_gemmk1_grid_desc
=
descs
[
I1
];
const
auto
out_gemmg_gemmm_gemmn_grid_desc
=
descs
[
I2
];
const
auto
out_gemmg_gemmm_gemmn_grid_desc
=
descs
[
I2
];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
in_gemmg_gemmk0_gemmm_gemmk1_grid_step_hacks
=
constexpr
auto
in_gemmg_gemmk0_gemmm_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmG
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmG
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GemmM
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GemmM
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 3+: GemmK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 3+: GemmK1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: GemmG
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: GemmG
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: GemmM
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: GemmM
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 3-: GemmK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 3-: GemmK1
constexpr
auto
wei_gemmg_gemmk0_gemmn_gemmk1_grid_step_hacks
=
constexpr
auto
wei_gemmg_gemmk0_gemmn_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmG
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmG
...
...
host/driver_offline/include/driver_gemm_xdlops_v3r1.hpp
View file @
149296c0
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
#include "common_header.hpp"
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v
2r3
.hpp"
#include "gridwise_gemm_xdlops_v
3r1
.hpp"
template
<
ck
::
index_t
BlockSize
,
template
<
ck
::
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAB
,
...
@@ -22,16 +22,16 @@ template <ck::index_t BlockSize,
...
@@ -22,16 +22,16 @@ template <ck::index_t BlockSize,
ck
::
index_t
K1
,
ck
::
index_t
K1
,
ck
::
index_t
MRepeat
,
ck
::
index_t
MRepeat
,
ck
::
index_t
NRepeat
,
ck
::
index_t
NRepeat
,
typename
ABlockTransferThreadSliceLengths_K0_M_K1
,
typename
ABlockTransferThreadSliceLengths_
G_
K0_M_K1
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterLengths_
G_
K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
typename
BBlockTransferThreadSliceLengths_K0_N_K1
,
typename
BBlockTransferThreadSliceLengths_
G_
K0_N_K1
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterLengths_
G_
K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
typename
BBlockTransferSrcAccessOrder
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
...
@@ -50,9 +50,9 @@ template <ck::index_t BlockSize,
...
@@ -50,9 +50,9 @@ template <ck::index_t BlockSize,
__host__
float
driver_gemm_xdlops_v3r1
(
const
FloatAB
*
p_a_grid
,
__host__
float
driver_gemm_xdlops_v3r1
(
const
FloatAB
*
p_a_grid
,
const
FloatAB
*
p_b_grid
,
const
FloatAB
*
p_b_grid
,
FloatC
*
p_c_grid
,
FloatC
*
p_c_grid
,
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
,
const
AK0MK1GridDesc
&
a_
g_
k0_m_k1_grid_desc
,
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
,
const
BK0NK1GridDesc
&
b_
g_
k0_n_k1_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
const
CMNGridDesc
&
c_
g_
m_n_grid_desc
,
AGridStepHacks
,
AGridStepHacks
,
BGridStepHacks
,
BGridStepHacks
,
CGridStepHacks
,
CGridStepHacks
,
...
@@ -66,9 +66,10 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
...
@@ -66,9 +66,10 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
/*
using GridwiseGemm =
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v
2r3
<BlockSize,
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v
3r1
<
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
FloatAcc
,
FloatC
,
FloatC
,
...
@@ -84,16 +85,16 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
...
@@ -84,16 +85,16 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
K1
,
K1
,
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadSliceLengths_
G_
K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_
G_
K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
ABlockTransferDstScalarPerVector_K1
,
AThreadTransferSrcResetCoordinateAfterRun
,
AThreadTransferSrcResetCoordinateAfterRun
,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadSliceLengths_
G_
K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_
G_
K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcVectorDim
,
...
@@ -111,84 +112,88 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
...
@@ -111,84 +112,88 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
CAccessOrderMRepeatNRepeat
>
;
CAccessOrderMRepeatNRepeat
>
;
{
{
std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
std
::
cout
<<
"a_g_k0_m_k1_grid_desc{"
<<
a_g_k0_m_k1_grid_desc
.
GetLength
(
I0
)
<<
", "
<< a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2)
<<
a_g_k0_m_k1_grid_desc
.
GetLength
(
I1
)
<<
", "
<<
a_g_k0_m_k1_grid_desc
.
GetLength
(
I2
)
<<
", "
<<
a_g_k0_m_k1_grid_desc
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"b_k0_n_k1_grid_desc{"
<<
b_g_k0_n_k1_grid_desc
.
GetLength
(
I0
)
<<
", "
<<
b_g_k0_n_k1_grid_desc
.
GetLength
(
I1
)
<<
", "
<<
b_g_k0_n_k1_grid_desc
.
GetLength
(
I2
)
<<
", "
<<
b_g_k0_n_k1_grid_desc
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"c_m_n_grid_desc{ "
<<
c_g_m_n_grid_desc
.
GetLength
(
I0
)
<<
", "
<<
c_g_m_n_grid_desc
.
GetLength
(
I1
)
<<
", "
<<
c_g_m_n_grid_desc
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
<<
"}"
<<
std
::
endl
;
std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
<< b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2)
<< "}" << std::endl;
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
}
}
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
if
(
!
GridwiseGemm
::
CheckValidity
(
a_g_k0_m_k1_grid_desc
,
b_g_k0_n_k1_grid_desc
,
c_g_m_n_grid_desc
))
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"
);
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"
);
}
}
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
const
auto
c_
gemmg_
m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
=
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
GridwiseGemm
::
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
c_
g_
m_n_grid_desc
);
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
/*
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
FloatAB,
FloatAB,
FloatC,
FloatC,
remove_reference_t<AK0MK1GridDesc>,
remove_reference_t<AK0MK1GridDesc>,
remove_reference_t<BK0NK1GridDesc>,
remove_reference_t<BK0NK1GridDesc>,
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
remove_reference_t<CBlockClusterAdaptor>>;
remove_reference_t<CBlockClusterAdaptor>>;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = launch_and_time_kernel(kernel,
float ave_time = launch_and_time_kernel(kernel,
nrepeat,
nrepeat,
dim3(grid_size),
dim3(grid_size),
dim3(BlockSize),
dim3(BlockSize),
0,
0,
p_a_grid,
p_a_grid,
p_b_grid,
p_b_grid,
p_c_grid,
p_c_grid,
a_k0_m_k1_grid_desc,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
c_block_cluster_adaptor);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
float ave_time = launch_and_time_kernel(
float ave_time = launch_and_time_kernel(
kernel,
kernel,
nrepeat,
nrepeat,
dim3(grid_size),
dim3(grid_size),
dim3(BlockSize),
dim3(BlockSize),
0,
0,
p_a_grid,
p_a_grid,
p_b_grid,
p_b_grid,
p_c_grid,
p_c_grid,
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
cast_pointer_to_constant_address_space(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
#endif
#endif
return ave_time;*/
return ave_time;*/
return
0.0
;
return
0.0
;
}
}
#endif
#endif
host/driver_offline/src/conv_fwd_driver_offline.cpp
View file @
149296c0
...
@@ -36,8 +36,8 @@ enum ConvForwardAlgo
...
@@ -36,8 +36,8 @@ enum ConvForwardAlgo
V6R1NCHW
,
// 2
V6R1NCHW
,
// 2
V5R1NCHW
,
// 3
V5R1NCHW
,
// 3
V4R4R2XDLNCHW
,
// 4
V4R4R2XDLNCHW
,
// 4
V4R4R4XDLNHWC
,
// 5
V4R4R4XDLNHWC
,
// 5
V4R4R4XDLNHWGC
// 6
V4R4R4XDLNHWGC
// 6
};
};
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
...
@@ -68,10 +68,10 @@ int main(int argc, char* argv[])
...
@@ -68,10 +68,10 @@ int main(int argc, char* argv[])
const
bool
do_log
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
5
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
6
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
6
]);
index_t
G
=
1
;
index_t
G
=
1
;
const
index_t
N
=
std
::
stoi
(
argv
[
7
]);
const
index_t
N
=
std
::
stoi
(
argv
[
7
]);
index_t
K
=
std
::
stoi
(
argv
[
8
]);
index_t
K
=
std
::
stoi
(
argv
[
8
]);
index_t
C
=
std
::
stoi
(
argv
[
9
]);
index_t
C
=
std
::
stoi
(
argv
[
9
]);
const
index_t
Y
=
std
::
stoi
(
argv
[
10
]);
const
index_t
Y
=
std
::
stoi
(
argv
[
10
]);
const
index_t
X
=
std
::
stoi
(
argv
[
11
]);
const
index_t
X
=
std
::
stoi
(
argv
[
11
]);
const
index_t
Hi
=
std
::
stoi
(
argv
[
12
]);
const
index_t
Hi
=
std
::
stoi
(
argv
[
12
]);
...
@@ -85,12 +85,12 @@ int main(int argc, char* argv[])
...
@@ -85,12 +85,12 @@ int main(int argc, char* argv[])
const
index_t
in_left_pad_w
=
std
::
stoi
(
argv
[
19
]);
const
index_t
in_left_pad_w
=
std
::
stoi
(
argv
[
19
]);
const
index_t
in_right_pad_h
=
std
::
stoi
(
argv
[
20
]);
const
index_t
in_right_pad_h
=
std
::
stoi
(
argv
[
20
]);
const
index_t
in_right_pad_w
=
std
::
stoi
(
argv
[
21
]);
const
index_t
in_right_pad_w
=
std
::
stoi
(
argv
[
21
]);
if
(
argc
==
23
){
if
(
argc
==
23
)
G
=
std
::
stoi
(
argv
[
22
]);
{
K
=
K
/
G
;
G
=
std
::
stoi
(
argv
[
22
]);
C
=
C
/
G
;
K
=
K
/
G
;
C
=
C
/
G
;
}
}
const
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
const
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
...
@@ -480,8 +480,8 @@ int main(int argc, char* argv[])
...
@@ -480,8 +480,8 @@ int main(int argc, char* argv[])
const
auto
tmp
=
f_make_for_device_nhwgc
();
const
auto
tmp
=
f_make_for_device_nhwgc
();
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk
<
in_data_t
,
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk
<
in_data_t
,
acc_data_t
,
acc_data_t
,
out_data_t
>
(
out_data_t
>
(
tmp
[
I0
],
tmp
[
I0
],
tmp
[
I1
],
tmp
[
I1
],
tmp
[
I2
],
tmp
[
I2
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment