Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
fe1a31b0
"tests/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "0b5def5257f53526ef81e1a83d462ef08af628d5"
Commit
fe1a31b0
authored
Sep 21, 2021
by
Chao Liu
Browse files
tweak
parent
6c97007c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
95 additions
and
59 deletions
+95
-59
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
...rd_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
+2
-5
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
+15
-10
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp
...ght_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp
+58
-30
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
+13
-10
host/driver_offline/src/conv_wrw_driver_offline.cpp
host/driver_offline/src/conv_wrw_driver_offline.cpp
+7
-4
No files found.
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
View file @
fe1a31b0
...
@@ -103,11 +103,8 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
...
@@ -103,11 +103,8 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// B: output tensor
// B: output tensor
const
auto
out_gemmk_gemmn_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_gemmk_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
));
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmk0_gemmn_gemmk1_grid_desc
=
const
auto
out_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmk_gemmn_grid_desc
,
transform_tensor_descriptor
(
out_gemmk_gemmn_grid_desc
,
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
View file @
fe1a31b0
...
@@ -132,8 +132,7 @@ template <index_t BlockSize,
...
@@ -132,8 +132,7 @@ template <index_t BlockSize,
typename
CGridStepHacks
,
typename
CGridStepHacks
,
typename
AGridMoveSliceWindowStepHacks
,
typename
AGridMoveSliceWindowStepHacks
,
typename
BGridMoveSliceWindowStepHacks
,
typename
BGridMoveSliceWindowStepHacks
,
bool
CAccessOrderMRepeatNRepeat
,
bool
CAccessOrderMRepeatNRepeat
>
index_t
KBatch
>
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -174,7 +173,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -174,7 +173,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
__host__
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
,
CheckValidity
(
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
,
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
,
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
)
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
KBatch
)
{
{
// TODO: turn on this
// TODO: turn on this
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
...
@@ -188,6 +188,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -188,6 +188,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
(
NPerBlock
%
(
NRepeat
*
NPerXDL
))
==
0
,
(
NPerBlock
%
(
NRepeat
*
NPerXDL
))
==
0
,
"Invalid tuning param!"
);
"Invalid tuning param!"
);
if
(
K0
%
(
KBatch
*
KPerBlock
)
!=
0
)
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
(
M
==
c_m_n_grid_desc
.
GetLength
(
I0
)
&&
N
==
c_m_n_grid_desc
.
GetLength
(
I1
)
&&
return
(
M
==
c_m_n_grid_desc
.
GetLength
(
I0
)
&&
N
==
c_m_n_grid_desc
.
GetLength
(
I1
)
&&
K0
==
b_k0_n_k1_grid_desc
.
GetLength
(
I0
)
&&
K0
==
b_k0_n_k1_grid_desc
.
GetLength
(
I0
)
&&
...
@@ -197,7 +202,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -197,7 +202,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
}
}
__host__
__device__
static
constexpr
index_t
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
CalculateGridSize
(
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
KBatch
)
{
{
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
...
@@ -208,7 +213,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -208,7 +213,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
}
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeABK0MK1GridDescriptor
(
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
)
MakeABK0MK1GridDescriptor
(
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
,
index_t
KBatch
)
{
{
const
auto
K0
=
a_k0_m_k1_grid_desc
.
GetLength
(
I0
);
const
auto
K0
=
a_k0_m_k1_grid_desc
.
GetLength
(
I0
);
const
auto
M
=
a_k0_m_k1_grid_desc
.
GetLength
(
I1
);
const
auto
M
=
a_k0_m_k1_grid_desc
.
GetLength
(
I1
);
...
@@ -226,7 +231,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -226,7 +231,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
}
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeBBK0NK1GridDescriptor
(
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
)
MakeBBK0NK1GridDescriptor
(
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
,
index_t
KBatch
)
{
{
const
auto
K0
=
b_k0_n_k1_grid_desc
.
GetLength
(
I0
);
const
auto
K0
=
b_k0_n_k1_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
b_k0_n_k1_grid_desc
.
GetLength
(
I1
);
const
auto
N
=
b_k0_n_k1_grid_desc
.
GetLength
(
I1
);
...
@@ -269,7 +274,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -269,7 +274,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
}
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeCBlockClusterAdaptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
MakeCBlockClusterAdaptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
KBatch
)
{
{
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
...
@@ -295,10 +300,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -295,10 +300,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
return
c_blockid_to_m0_n0_block_cluster_adaptor
;
return
c_blockid_to_m0_n0_block_cluster_adaptor
;
}
}
using
ABK0MK1GridDesc
=
decltype
(
MakeABK0MK1GridDescriptor
(
AK0MK1GridDesc
{}));
using
ABK0MK1GridDesc
=
decltype
(
MakeABK0MK1GridDescriptor
(
AK0MK1GridDesc
{}
,
1
));
using
BBK0NK1GridDesc
=
decltype
(
MakeBBK0NK1GridDescriptor
(
BK0NK1GridDesc
{}));
using
BBK0NK1GridDesc
=
decltype
(
MakeBBK0NK1GridDescriptor
(
BK0NK1GridDesc
{}
,
1
));
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
CMNGridDesc
{}));
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
CMNGridDesc
{}));
using
CBlockClusterAdaptor
=
decltype
(
MakeCBlockClusterAdaptor
(
CMNGridDesc
{}));
using
CBlockClusterAdaptor
=
decltype
(
MakeCBlockClusterAdaptor
(
CMNGridDesc
{}
,
1
));
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
...
...
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp
View file @
fe1a31b0
...
@@ -25,6 +25,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
...
@@ -25,6 +25,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
const
Tensor
<
TInWei
>&
in_n_hi_wi_c
,
const
Tensor
<
TInWei
>&
in_n_hi_wi_c
,
Tensor
<
TInWei
>&
wei_k_y_x_c
,
Tensor
<
TInWei
>&
wei_k_y_x_c
,
const
Tensor
<
TOut
>&
out_n_ho_wo_k
,
const
Tensor
<
TOut
>&
out_n_ho_wo_k
,
ck
::
index_t
KBatch
,
ck
::
index_t
nrepeat
)
ck
::
index_t
nrepeat
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -49,6 +50,34 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
...
@@ -49,6 +50,34 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
const
auto
out_n_ho_wo_k_desc
=
make_naive_tensor_descriptor_packed
(
out_n_ho_wo_k_lengths
);
const
auto
out_n_ho_wo_k_desc
=
make_naive_tensor_descriptor_packed
(
out_n_ho_wo_k_lengths
);
#if 1
#if 1
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
256
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerXDL
=
32
;
constexpr
index_t
GemmNPerXDL
=
32
;
constexpr
index_t
GemmK1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
1
,
4
,
2
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
4
,
32
,
2
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmM
=
4
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmK1
=
2
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
1
,
8
,
2
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
4
,
32
,
2
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
8
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
2
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
4
;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -76,8 +105,6 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
...
@@ -76,8 +105,6 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
2
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
2
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
KBatch
=
32
;
#endif
#endif
const
auto
descs
=
transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad
(
const
auto
descs
=
transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad
(
...
@@ -106,14 +133,14 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
...
@@ -106,14 +133,14 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: GemmK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: GemmK1
constexpr
auto
out_gemmk0_gemmn_gemmk1_grid_step_hacks
=
constexpr
auto
out_gemmk0_gemmn_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: GemmK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: GemmK1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: GemmK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: GemmK1
constexpr
auto
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
constexpr
auto
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
...
@@ -137,7 +164,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
...
@@ -137,7 +164,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
1
,
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
1
,
0
,
0
,
0
,
0
,
0
>
{};
constexpr
auto
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
=
constexpr
auto
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
>
{};
std
::
function
<
void
()
>
clear_weight
=
[
&
wei_k_y_x_c_device_buf
,
&
wei_k_y_x_c
]()
{
std
::
function
<
void
()
>
clear_weight
=
[
&
wei_k_y_x_c_device_buf
,
&
wei_k_y_x_c
]()
{
wei_k_y_x_c_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
wei_k_y_x_c_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
...
@@ -163,42 +190,43 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
...
@@ -163,42 +190,43 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
NRepeat
,
NRepeat
,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
,
Sequence
<
0
,
1
,
3
,
2
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
0
,
1
,
3
,
2
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
2
,
2
,
GemmABlockTransferSrcScalarPerVector_GemmM
,
GemmABlockTransferSrcScalarPerVector_GemmM
,
GemmABlockTransferDstScalarPerVector_GemmK1
,
GemmABlockTransferDstScalarPerVector_GemmK1
,
false
,
// don't move back src coordinate after threadwise copy
false
,
// don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
,
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
,
Sequence
<
0
,
1
,
3
,
2
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
0
,
1
,
3
,
2
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
2
,
2
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_GemmK1
,
GemmBBlockTransferDstScalarPerVector_GemmK1
,
false
,
// don't move back src coordinate after threadwise copy
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
2
,
3
,
0
,
1
,
7
,
5
,
4
,
6
>
,
Sequence
<
2
,
3
,
0
,
1
,
7
,
5
,
4
,
6
>
,
7
,
6
,
GemmCThreadTransferDstScalarPerVector
,
GemmCThreadTransferDstScalarPerVector
,
decltype
(
in_gemmk0_gemmm_gemmk1_grid_step_hacks
),
decltype
(
in_gemmk0_gemmm_gemmk1_grid_step_hacks
),
decltype
(
out_gemmk0_gemmn_gemmk1_grid_step_hacks
),
decltype
(
out_gemmk0_gemmn_gemmk1_grid_step_hacks
),
decltype
(
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
),
false
,
// CAccessOrderMRepeatNRepeat
false
// CAccessOrderMRepeatNRepeat
KBatch
>
(
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
>
(
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
in_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmk0_gemmm_gemmk1_grid_desc
,
out_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
,
wei_gemmm_gemmn_grid_desc
,
in_gemmk0_gemmm_gemmk1_grid_step_hacks
,
KBatch
,
out_gemmk0_gemmn_gemmk1_grid_step_hacks
,
in_gemmk0_gemmm_gemmk1_grid_step_hacks
,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
out_gemmk0_gemmn_gemmk1_grid_step_hacks
,
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
,
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
,
nrepeat
,
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
,
&
clear_weight
);
nrepeat
,
&
clear_weight
);
{
{
const
auto
N
=
out_n_ho_wo_k_lengths
[
I0
];
const
auto
N
=
out_n_ho_wo_k_lengths
[
I0
];
...
...
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
View file @
fe1a31b0
...
@@ -46,14 +46,14 @@ template <ck::index_t BlockSize,
...
@@ -46,14 +46,14 @@ template <ck::index_t BlockSize,
typename
CGridStepHacks
,
typename
CGridStepHacks
,
typename
AGridMoveSliceWindowStepHacks
,
typename
AGridMoveSliceWindowStepHacks
,
typename
BGridMoveSliceWindowStepHacks
,
typename
BGridMoveSliceWindowStepHacks
,
bool
CAccessOrderMRepeatNRepeat
,
bool
CAccessOrderMRepeatNRepeat
>
ck
::
index_t
KBatch
>
__host__
float
driver_gemm_xdlops_v2r4
(
const
FloatAB
*
p_a_grid
,
__host__
float
driver_gemm_xdlops_v2r4
(
const
FloatAB
*
p_a_grid
,
const
FloatAB
*
p_b_grid
,
const
FloatAB
*
p_b_grid
,
FloatC
*
p_c_grid
,
FloatC
*
p_c_grid
,
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
,
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
,
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
,
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
ck
::
index_t
KBatch
,
AGridStepHacks
,
AGridStepHacks
,
BGridStepHacks
,
BGridStepHacks
,
CGridStepHacks
,
CGridStepHacks
,
...
@@ -110,8 +110,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
...
@@ -110,8 +110,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
CGridStepHacks
,
CGridStepHacks
,
AGridMoveSliceWindowStepHacks
,
AGridMoveSliceWindowStepHacks
,
BGridMoveSliceWindowStepHacks
,
BGridMoveSliceWindowStepHacks
,
CAccessOrderMRepeatNRepeat
,
CAccessOrderMRepeatNRepeat
>
;
KBatch
>
;
{
{
std
::
cout
<<
"a_k0_m_k1_grid_desc{"
<<
a_k0_m_k1_grid_desc
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"a_k0_m_k1_grid_desc{"
<<
a_k0_m_k1_grid_desc
.
GetLength
(
I0
)
<<
", "
...
@@ -125,11 +124,14 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
...
@@ -125,11 +124,14 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
std
::
cout
<<
"c_m_n_grid_desc{ "
<<
c_m_n_grid_desc
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"c_m_n_grid_desc{ "
<<
c_m_n_grid_desc
.
GetLength
(
I0
)
<<
", "
<<
c_m_n_grid_desc
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
<<
c_m_n_grid_desc
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
}
// const auto kbatch = GridwiseGemm::CalculateKBatch(c_m_n_grid_desc, b_k0_n_k1_grid_desc);
const
auto
a_b_k0_m_k1_grid_desc
=
GridwiseGemm
::
MakeABK0MK1GridDescriptor
(
a_k0_m_k1_grid_desc
);
const
auto
b_b_k0_n_k1_grid_desc
=
GridwiseGemm
::
MakeBBK0NK1GridDescriptor
(
b_k0_n_k1_grid_desc
);
if
(
!
GridwiseGemm
::
CheckValidity
(
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
))
const
auto
a_b_k0_m_k1_grid_desc
=
GridwiseGemm
::
MakeABK0MK1GridDescriptor
(
a_k0_m_k1_grid_desc
,
KBatch
);
const
auto
b_b_k0_n_k1_grid_desc
=
GridwiseGemm
::
MakeBBK0NK1GridDescriptor
(
b_k0_n_k1_grid_desc
,
KBatch
);
if
(
!
GridwiseGemm
::
CheckValidity
(
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
KBatch
))
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"
);
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"
);
...
@@ -142,11 +144,12 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
...
@@ -142,11 +144,12 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
using
BBK0NK1GridDesc
=
decltype
(
b_b_k0_n_k1_grid_desc
);
using
BBK0NK1GridDesc
=
decltype
(
b_b_k0_n_k1_grid_desc
);
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
const
auto
c_block_cluster_adaptor
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_m_n_grid_desc
);
const
auto
c_block_cluster_adaptor
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_m_n_grid_desc
,
KBatch
);
using
CBlockClusterAdaptor
=
decltype
(
c_block_cluster_adaptor
);
using
CBlockClusterAdaptor
=
decltype
(
c_block_cluster_adaptor
);
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
c_m_n_grid_desc
);
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
c_m_n_grid_desc
,
KBatch
);
{
{
std
::
cout
<<
"gridSize : "
<<
grid_size
<<
std
::
endl
;
std
::
cout
<<
"gridSize : "
<<
grid_size
<<
std
::
endl
;
}
}
...
...
host/driver_offline/src/conv_wrw_driver_offline.cpp
View file @
fe1a31b0
...
@@ -18,9 +18,9 @@
...
@@ -18,9 +18,9 @@
#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1
#define USE_DYNAMIC_MODE 1
#define USE_CONV_WRW_V4R4R2_XDL_NCHW
1
#define USE_CONV_WRW_V4R4R2_XDL_NCHW
0
#define USE_CONV_WRW_V4R4R4_XDL_NHWC
1
#define USE_CONV_WRW_V4R4R4_XDL_NHWC
0
#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW
1
#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW
0
#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 1
#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 1
enum
ConvBackwardWeightAlgo
enum
ConvBackwardWeightAlgo
...
@@ -45,7 +45,7 @@ int main(int argc, char* argv[])
...
@@ -45,7 +45,7 @@ int main(int argc, char* argv[])
#if USE_DYNAMIC_MODE
#if USE_DYNAMIC_MODE
// dynamic mode
// dynamic mode
if
(
argc
!=
2
2
)
if
(
argc
!=
2
3
)
{
{
printf
(
"arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat
\n
"
);
printf
(
"arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat
\n
"
);
printf
(
"rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx
\n
"
);
printf
(
"rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx
\n
"
);
...
@@ -76,6 +76,8 @@ int main(int argc, char* argv[])
...
@@ -76,6 +76,8 @@ int main(int argc, char* argv[])
const
index_t
in_right_pad_h
=
std
::
stoi
(
argv
[
20
]);
const
index_t
in_right_pad_h
=
std
::
stoi
(
argv
[
20
]);
const
index_t
in_right_pad_w
=
std
::
stoi
(
argv
[
21
]);
const
index_t
in_right_pad_w
=
std
::
stoi
(
argv
[
21
]);
const
index_t
k_batch
=
std
::
stoi
(
argv
[
22
]);
const
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
const
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
...
@@ -363,6 +365,7 @@ int main(int argc, char* argv[])
...
@@ -363,6 +365,7 @@ int main(int argc, char* argv[])
in
,
in
,
wei_device
,
wei_device
,
out
,
out
,
k_batch
,
nrepeat
);
nrepeat
);
}
}
#endif
#endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment