Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
149296c0
Commit
149296c0
authored
Sep 13, 2021
by
ltqin
Browse files
add MakeCGM0N0M1N1M2M3M4N2GridDescriptor
parent
973978aa
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
181 additions
and
122 deletions
+181
-122
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk.hpp
...orward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk.hpp
+11
-10
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+16
-0
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+37
-0
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk.hpp
...forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk.hpp
+16
-16
host/driver_offline/include/driver_gemm_xdlops_v3r1.hpp
host/driver_offline/include/driver_gemm_xdlops_v3r1.hpp
+89
-84
host/driver_offline/src/conv_fwd_driver_offline.cpp
host/driver_offline/src/conv_fwd_driver_offline.cpp
+12
-12
No files found.
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk.hpp
View file @
149296c0
...
...
@@ -90,10 +90,11 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad(
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{},
Sequence
<
6
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{},
Sequence
<
6
>
{}));
const
auto
in_gemmg_gemmk_gemmm_grid_desc
=
transform_tensor_descriptor
(
in_g_n_y_ho_x_wo_c_grid_desc
,
const
auto
in_gemmg_gemmk_gemmm_grid_desc
=
transform_tensor_descriptor
(
in_g_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
G
),
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
...
...
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
149296c0
...
...
@@ -158,6 +158,22 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return
xdlops_gemm
.
MakeCM0N0M1N1M2M3M4N2Descriptor
(
c_m0_n0_m1_n1_m2_n2_grid_desc
);
}
template
<
typename
CGMNGridDesc
>
__host__
__device__
static
constexpr
auto
MakeCGM0N0M1N1M2M3M4N2GridDescriptor
(
const
CGMNGridDesc
&
c_g_m_n_grid_desc
)
{
const
auto
G
=
c_g_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
c_g_m0_n0_m1_n1_m2_n2_grid_desc
=
transform_tensor_descriptor
(
c_g_m_n_grid_desc
,
make_tuple
(
make_pass_through_transform
(
G
),
make_unmerge_transform
(
make_tuple
(
MRepeat
,
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
NRepeat
,
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
return
xdlops_gemm
.
MakeCGM0N0M1N1M2M3M4N2Descriptor
(
c_g_m0_n0_m1_n1_m2_n2_grid_desc
);
}
__host__
__device__
static
constexpr
auto
MakeAK0M0M1M2K1BlockDescriptor
()
{
return
transform_tensor_descriptor
(
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
149296c0
...
...
@@ -727,6 +727,43 @@ struct XdlopsGemm
Sequence
<
7
>
{}));
}
template
<
typename
CGM0N0M1N1M2N2Desc
>
__host__
__device__
static
constexpr
auto
MakeCGM0N0M1N1M2M3M4N2Descriptor
(
const
CGM0N0M1N1M2N2Desc
&
c_g_m0_n0_m1_n1_m2_n2_desc
)
{
const
auto
G
=
c_g_m0_n0_m1_n1_m2_n2_desc
.
GetLength
(
I0
);
const
auto
M0
=
c_g_m0_n0_m1_n1_m2_n2_desc
.
GetLength
(
I1
);
const
auto
N0
=
c_g_m0_n0_m1_n1_m2_n2_desc
.
GetLength
(
I2
);
const
auto
M1
=
c_g_m0_n0_m1_n1_m2_n2_desc
.
GetLength
(
I3
);
const
auto
N1
=
c_g_m0_n0_m1_n1_m2_n2_desc
.
GetLength
(
I4
);
return
transform_tensor_descriptor
(
c_g_m0_n0_m1_n1_m2_n2_desc
,
make_tuple
(
make_pass_through_transform
(
G
),
make_pass_through_transform
(
M0
),
make_pass_through_transform
(
N0
),
make_pass_through_transform
(
M1
),
make_pass_through_transform
(
N1
),
make_unmerge_transform
(
make_tuple
(
mfma_instr
.
num_groups_per_blk
,
mfma_instr
.
num_input_blks
,
mfma_instr
.
group_size
)),
make_pass_through_transform
(
mfma_instr
.
num_threads_per_blk
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
,
6
,
7
>
{},
Sequence
<
8
>
{}));
}
__device__
static
constexpr
index_t
GetRegSizePerXdlops
()
{
return
MPerXdlops
*
NPerXdlops
/
mfma_instr
.
wave_size
;
...
...
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk.hpp
View file @
149296c0
...
...
@@ -234,8 +234,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk(
const
auto
out_gemmg_gemmm_gemmn_grid_desc
=
descs
[
I2
];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
in_gemmg_gemmk0_gemmm_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmG
constexpr
auto
in_gemmg_gemmk0_gemmm_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmG
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GemmM
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 3+: GemmK1
...
...
host/driver_offline/include/driver_gemm_xdlops_v3r1.hpp
View file @
149296c0
...
...
@@ -4,7 +4,7 @@
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v
2r3
.hpp"
#include "gridwise_gemm_xdlops_v
3r1
.hpp"
template
<
ck
::
index_t
BlockSize
,
typename
FloatAB
,
...
...
@@ -22,16 +22,16 @@ template <ck::index_t BlockSize,
ck
::
index_t
K1
,
ck
::
index_t
MRepeat
,
ck
::
index_t
NRepeat
,
typename
ABlockTransferThreadSliceLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadSliceLengths_
G_
K0_M_K1
,
typename
ABlockTransferThreadClusterLengths_
G_
K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
typename
BBlockTransferThreadSliceLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadSliceLengths_
G_
K0_N_K1
,
typename
BBlockTransferThreadClusterLengths_
G_
K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
...
...
@@ -50,9 +50,9 @@ template <ck::index_t BlockSize,
__host__
float
driver_gemm_xdlops_v3r1
(
const
FloatAB
*
p_a_grid
,
const
FloatAB
*
p_b_grid
,
FloatC
*
p_c_grid
,
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
,
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
const
AK0MK1GridDesc
&
a_
g_
k0_m_k1_grid_desc
,
const
BK0NK1GridDesc
&
b_
g_
k0_n_k1_grid_desc
,
const
CMNGridDesc
&
c_
g_
m_n_grid_desc
,
AGridStepHacks
,
BGridStepHacks
,
CGridStepHacks
,
...
...
@@ -66,9 +66,10 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
/*
using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v
2r3
<BlockSize,
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v
3r1
<
BlockSize
,
FloatAB
,
FloatAcc
,
FloatC
,
...
...
@@ -84,16 +85,16 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
K1
,
MRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadSliceLengths_
G_
K0_M_K1
,
ABlockTransferThreadClusterLengths_
G_
K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
AThreadTransferSrcResetCoordinateAfterRun
,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadSliceLengths_
G_
K0_N_K1
,
BBlockTransferThreadClusterLengths_
G_
K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
...
...
@@ -111,28 +112,32 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
CAccessOrderMRepeatNRepeat
>
;
{
std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
<< a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2)
std
::
cout
<<
"a_g_k0_m_k1_grid_desc{"
<<
a_g_k0_m_k1_grid_desc
.
GetLength
(
I0
)
<<
", "
<<
a_g_k0_m_k1_grid_desc
.
GetLength
(
I1
)
<<
", "
<<
a_g_k0_m_k1_grid_desc
.
GetLength
(
I2
)
<<
", "
<<
a_g_k0_m_k1_grid_desc
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"b_k0_n_k1_grid_desc{"
<<
b_g_k0_n_k1_grid_desc
.
GetLength
(
I0
)
<<
", "
<<
b_g_k0_n_k1_grid_desc
.
GetLength
(
I1
)
<<
", "
<<
b_g_k0_n_k1_grid_desc
.
GetLength
(
I2
)
<<
", "
<<
b_g_k0_n_k1_grid_desc
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"c_m_n_grid_desc{ "
<<
c_g_m_n_grid_desc
.
GetLength
(
I0
)
<<
", "
<<
c_g_m_n_grid_desc
.
GetLength
(
I1
)
<<
", "
<<
c_g_m_n_grid_desc
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
<< b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2)
<< "}" << std::endl;
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
if
(
!
GridwiseGemm
::
CheckValidity
(
a_g_k0_m_k1_grid_desc
,
b_g_k0_n_k1_grid_desc
,
c_g_m_n_grid_desc
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"
);
}
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
const
auto
c_
gemmg_
m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
=
GridwiseGemm
::
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
c_
g_
m_n_grid_desc
);
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
/*
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
...
...
@@ -148,7 +153,7 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
remove_reference_t<CBlockClusterAdaptor>>;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
...
...
@@ -162,7 +167,7 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
...
...
@@ -187,7 +192,7 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
cast_pointer_to_constant_address_space(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
#endif
#endif
return ave_time;*/
return
0.0
;
}
...
...
host/driver_offline/src/conv_fwd_driver_offline.cpp
View file @
149296c0
...
...
@@ -85,13 +85,13 @@ int main(int argc, char* argv[])
const
index_t
in_left_pad_w
=
std
::
stoi
(
argv
[
19
]);
const
index_t
in_right_pad_h
=
std
::
stoi
(
argv
[
20
]);
const
index_t
in_right_pad_w
=
std
::
stoi
(
argv
[
21
]);
if
(
argc
==
23
){
if
(
argc
==
23
)
{
G
=
std
::
stoi
(
argv
[
22
]);
K
=
K
/
G
;
C
=
C
/
G
;
}
const
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment