gaoqiong / composable_kernel

Commit 059e1c96, authored Sep 21, 2021 by Chao Liu
parent fe1a31b0

tweak

Showing 4 changed files with 101 additions and 141 deletions:
- composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp (+25 −22)
- composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp (+13 −55)
- host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp (+44 −42)
- host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp (+19 −22)
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp

```diff
@@ -20,7 +20,8 @@ template <typename... In,
           typename ConvDilations,
           typename InLeftPads,
           typename InRightPads,
-          index_t GemmK1Value>
+          index_t GemmK1Value,
+          typename GemmKBatchType>
 __host__ __device__ constexpr auto
 transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
     const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
@@ -30,7 +31,8 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
     const ConvDilations& conv_dilations,
     const InLeftPads& in_left_pads,
     const InRightPads& in_right_pads,
-    Number<GemmK1Value>)
+    Number<GemmK1Value>,
+    GemmKBatchType GemmKBatch)
 {
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
@@ -64,10 +66,11 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
     const auto InRightPadH = in_right_pads[I0];
     const auto InRightPadW = in_right_pads[I1];
 
-    const auto GemmM  = Y * X * C;
-    const auto GemmN  = K;
-    const auto GemmK  = N * Ho * Wo;
-    const auto GemmK0 = GemmK / GemmK1;
+    const auto GemmM      = Y * X * C;
+    const auto GemmN      = K;
+    const auto GemmKTotal = N * Ho * Wo;
+    const auto GemmK      = GemmKTotal / GemmKBatch;
+    const auto GemmK0     = GemmK / GemmK1;
 
     // A: input tensor
     const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
@@ -88,30 +91,30 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
         make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
         make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
 
-    const auto in_gemmk_gemmm_grid_desc =
+    const auto in_gemmktotal_gemmm_grid_desc =
         transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
                                     make_tuple(make_merge_transform(make_tuple(Y, X, C)),
                                                make_merge_transform(make_tuple(N, Ho, Wo))),
                                     make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
                                     make_tuple(Sequence<1>{}, Sequence<0>{}));
 
-    const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-        in_gemmk_gemmm_grid_desc,
-        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
-                   make_pass_through_transform(GemmM)),
-        make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
+        in_gemmktotal_gemmm_grid_desc,
+        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
+                   make_pass_through_transform(GemmM)),
+        make_tuple(Sequence<0>{}, Sequence<1>{}),
+        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
 
     // B: output tensor
-    const auto out_gemmk_gemmn_grid_desc =
+    const auto out_gemmktotal_gemmn_grid_desc =
         make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
 
-    const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-        out_gemmk_gemmn_grid_desc,
-        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
-                   make_pass_through_transform(GemmN)),
-        make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
+        out_gemmktotal_gemmn_grid_desc,
+        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
+                   make_pass_through_transform(GemmN)),
+        make_tuple(Sequence<0>{}, Sequence<1>{}),
+        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
 
     // C: weight tensor
     const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
@@ -120,8 +123,8 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
         make_tuple(Sequence<0>{}, Sequence<1>{}),
         make_tuple(Sequence<1>{}, Sequence<0>{}));
 
-    return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
-                      out_gemmk0_gemmn_gemmk1_grid_desc,
+    return make_tuple(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
+                      out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                       wei_gemmm_gemmn_grid_desc);
 }
```
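Taken together, the four diffs thread a split-K ("GemmKBatch") dimension through the backward-weight convolution GEMM stack. In this first file, the transform now splits the full reduction length GemmKTotal = N * Ho * Wo into (GemmKBatch, GemmK0, GemmK1) in a single unmerge, so each K-batch reduces over a GemmK = GemmKTotal / GemmKBatch slice. Below is a minimal sketch of that index arithmetic in plain C++ (not the CK API; all sizes are hypothetical), decomposing a flat reduction index the way make_unmerge_transform does:

```cpp
#include <cassert>
#include <cstdio>

int main()
{
    const int N = 2, Ho = 8, Wo = 8;      // hypothetical convolution output sizes
    const int GemmKBatch = 4, GemmK1 = 8; // hypothetical tuning parameters

    const int GemmKTotal = N * Ho * Wo;        // 128, as in the transform
    assert(GemmKTotal % GemmKBatch == 0);
    const int GemmK = GemmKTotal / GemmKBatch; // 32, reduction length per K-batch
    assert(GemmK % GemmK1 == 0);
    const int GemmK0 = GemmK / GemmK1;         // 4

    // Unmerge-style decomposition: flat index -> (kbatch, k0, k1),
    // with kbatch the slowest-varying index and k1 the fastest.
    const int flat   = 77;
    const int k1     = flat % GemmK1;
    const int k0     = (flat / GemmK1) % GemmK0;
    const int kbatch = flat / (GemmK1 * GemmK0);

    std::printf("flat %d -> (kbatch %d, k0 %d, k1 %d)\n", flat, kbatch, k0, k1);
    return 0;
}
```

The same three-way unmerge is applied to the B (output-tensor) descriptor, which is why both returned descriptors gain a leading gemmkbatch dimension.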
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp

```diff
@@ -97,8 +97,8 @@ template <index_t BlockSize,
           typename FloatAcc,
           typename FloatC,
           InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
-          typename AK0MK1GridDesc,
-          typename BK0NK1GridDesc,
+          typename ABK0MK1GridDesc,
+          typename BBK0NK1GridDesc,
           typename CMNGridDesc,
           index_t MPerBlock,
           index_t NPerBlock,
@@ -171,33 +171,29 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
     }
 
     __host__ __device__ static constexpr bool
-    CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
-                  const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
-                  const CMNGridDesc& c_m_n_grid_desc,
-                  index_t KBatch)
+    CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
+                  const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
+                  const CMNGridDesc& c_m_n_grid_desc)
     {
         // TODO: turn on this
         static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                       "wrong! K1 need to be known at compile-time");
 
-        const auto M  = a_k0_m_k1_grid_desc.GetLength(I1);
-        const auto N  = b_k0_n_k1_grid_desc.GetLength(I1);
-        const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
+        const auto M      = a_b_k0_m_k1_grid_desc.GetLength(I2);
+        const auto N      = b_b_k0_n_k1_grid_desc.GetLength(I2);
+        const auto K0     = a_b_k0_m_k1_grid_desc.GetLength(I1);
+        const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
 
         static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
                           (NPerBlock % (NRepeat * NPerXDL)) == 0,
                       "Invalid tuning param!");
 
-        if(K0 % (KBatch * KPerBlock) != 0)
-        {
-            return false;
-        }
-
         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
-                K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
-                K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
-                K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
+                K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
+                K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
+                K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
+                KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)) &&
                (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0);
     }
@@ -212,42 +208,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
         return grid_size;
     }
 
-    __host__ __device__ static constexpr auto
-    MakeABK0MK1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, index_t KBatch)
-    {
-        const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
-        const auto M  = a_k0_m_k1_grid_desc.GetLength(I1);
-
-        assert(K0 % KBatch == 0);
-
-        const auto a_b_k0_m_k1_grid_desc = transform_tensor_descriptor(
-            a_k0_m_k1_grid_desc,
-            make_tuple(make_unmerge_transform(make_tuple(KBatch, K0 / KBatch)),
-                       make_pass_through_transform(M),
-                       make_pass_through_transform(K1Value)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
-
-        return a_b_k0_m_k1_grid_desc;
-    }
-
-    __host__ __device__ static constexpr auto
-    MakeBBK0NK1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc, index_t KBatch)
-    {
-        const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
-        const auto N  = b_k0_n_k1_grid_desc.GetLength(I1);
-
-        assert(K0 % KBatch == 0);
-
-        const auto b_b_k0_n_k1_grid_desc = transform_tensor_descriptor(
-            b_k0_n_k1_grid_desc,
-            make_tuple(make_unmerge_transform(make_tuple(KBatch, K0 / KBatch)),
-                       make_pass_through_transform(N),
-                       make_pass_through_transform(K1Value)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
-
-        return b_b_k0_n_k1_grid_desc;
-    }
-
     __host__ __device__ static constexpr auto
     MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
     {
@@ -300,8 +260,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
         return c_blockid_to_m0_n0_block_cluster_adaptor;
     }
 
-    using ABK0MK1GridDesc = decltype(MakeABK0MK1GridDescriptor(AK0MK1GridDesc{}, 1));
-    using BBK0NK1GridDesc = decltype(MakeBBK0NK1GridDescriptor(BK0NK1GridDesc{}, 1));
     using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
     using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1));
```
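Since the batched descriptors are now built upstream, CheckValidity no longer takes KBatch as a runtime argument; it reads it from the leading length of each descriptor and requires that A and B agree on it. A minimal sketch of the shape-consistency conditions it now enforces (plain integers rather than CK descriptors; names and sizes are illustrative):

```cpp
#include <cstdint>

// Lengths of batched A [B, K0, M, K1] and B [B, K0, N, K1] must agree with
// each other, with C [M, N], and with the block tiling parameters.
bool check_validity_sketch(int64_t a_b, int64_t a_k0, int64_t a_m, int64_t a_k1,
                           int64_t b_b, int64_t b_k0, int64_t b_n, int64_t b_k1,
                           int64_t c_m, int64_t c_n,
                           int64_t K1, int64_t MPerBlock, int64_t NPerBlock,
                           int64_t KPerBlock)
{
    return a_m == c_m && b_n == c_n &&  // M and N consistent with C
           a_k0 == b_k0 &&              // K0 consistent between A and B
           a_k1 == K1 && b_k1 == K1 &&  // K1 matches the compile-time value
           a_b == b_b &&                // KBatch consistent between A and B
           c_m % MPerBlock == 0 && c_n % NPerBlock == 0 && a_k0 % KPerBlock == 0;
}

int main()
{
    // Hypothetical values: B=4, K0=16, M=256, K1=8 against a 256x512 C
    // with 128x128 tiles and KPerBlock=4.
    return check_validity_sketch(4, 16, 256, 8, 4, 16, 512, 8,
                                 256, 512, 8, 128, 128, 4) ? 0 : 1;
}
```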
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp

```diff
@@ -13,7 +13,8 @@ template <typename TInWei,
           typename ConvStrides,
           typename ConvDilations,
           typename InLeftPads,
-          typename InRightPads>
+          typename InRightPads,
+          typename GemmKBatchType>
 void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk(
     const InLengths& in_n_hi_wi_c_lengths,
     const WeiLengths& wei_k_y_x_c_lengths,
@@ -25,7 +26,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
     const Tensor<TInWei>& in_n_hi_wi_c,
     Tensor<TInWei>& wei_k_y_x_c,
     const Tensor<TOut>& out_n_ho_wo_k,
-    ck::index_t KBatch,
+    GemmKBatchType GemmKBatch,
     ck::index_t nrepeat)
 {
     using namespace ck;
@@ -115,32 +116,33 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
         conv_dilations,
         in_left_pads,
         in_right_pads,
-        Number<GemmK1>{});
+        Number<GemmK1>{},
+        GemmKBatch);
 
-    const auto in_gemmk0_gemmm_gemmk1_grid_desc  = descs[I0];
-    const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
-    const auto wei_gemmm_gemmn_grid_desc         = descs[I2];
+    const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc  = descs[I0];
+    const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
+    const auto wei_gemmm_gemmn_grid_desc                    = descs[I2];
 
     // HACK: hacks that control index calculation when iterating over A, B, C matrix
-    constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
-        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 0+: GemmK0
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 0+: GemmK0
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},   // 1+: GemmM
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}),  // 2+: GemmK1
-        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 0-: GemmK0
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 0-: GemmK0
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},   // 1-: GemmM
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
-
-    constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0>{},   // 0+: GemmK0
-                              Sequence<0, 0, 0, 0, 0, 0>{},   // 0+: GemmK0
-                              Sequence<0, 0, 0, 0, 0, 0>{},   // 1+: GemmN
-                              Sequence<0, 0, 0, 0, 0, 0>{}),  // 2+: GemmK1
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 0>{},   // 0+: GemmK0
-                              Sequence<0, 0, 0, 0, 0, 0>{},   // 0-: GemmK0
-                              Sequence<0, 0, 0, 0, 0, 0>{},   // 1-: GemmN
-                              Sequence<0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
+    constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
+        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},   // 0+: GemmK0
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},   // 0+: GemmK0
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},   // 1+: GemmM
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),  // 2+: GemmK1
+        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},   // 0-: GemmK0
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},   // 0-: GemmK0
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},   // 1-: GemmM
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1
+
+    constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: GemmK0
+                              Sequence<0, 0, 0>{},   // 0+: GemmK0
+                              Sequence<0, 0, 0>{},   // 1+: GemmN
+                              Sequence<0, 0, 0>{}),  // 2+: GemmK1
+                   make_tuple(Sequence<0, 0, 0>{},   // 0+: GemmK0
+                              Sequence<0, 0, 0>{},   // 0-: GemmK0
+                              Sequence<0, 0, 0>{},   // 1-: GemmN
+                              Sequence<0, 0, 0>{})); // 2-: GemmK1
 
     constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
         make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
@@ -160,15 +162,16 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},    // 6-: M4
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));  // 7-: N2
 
-    constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
-        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0>{};
+    constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
+        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};
 
-    constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
-        Sequence<0, 0, 0, 0, 0, 0>{};
+    constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
+        Sequence<0, 0, 0>{};
 
     std::function<void()> clear_weight = [&wei_k_y_x_c_device_buf, &wei_k_y_x_c]() {
         wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
     };
 
     for(index_t i = 0; i < 5; ++i)
     {
         float ave_time = driver_gemm_xdlops_v2r4<
@@ -177,8 +180,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
             TAcc,
             TOut,
             InMemoryDataOperationEnum_t::AtomicAdd,
-            decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
-            decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
+            decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
+            decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
             decltype(wei_gemmm_gemmn_grid_desc),
             GemmMPerBlock,
             GemmNPerBlock,
@@ -207,24 +210,23 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
             Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
             6,
             GemmCThreadTransferDstScalarPerVector,
-            decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
-            decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
+            decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks),
+            decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks),
             decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
-            decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
-            decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
+            decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
+            decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
             false // CAccessOrderMRepeatNRepeat
             >(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
               static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
               static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
-              in_gemmk0_gemmm_gemmk1_grid_desc,
-              out_gemmk0_gemmn_gemmk1_grid_desc,
+              in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
+              out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
              wei_gemmm_gemmn_grid_desc,
-              KBatch,
-              in_gemmk0_gemmm_gemmk1_grid_step_hacks,
-              out_gemmk0_gemmn_gemmk1_grid_step_hacks,
+              in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
+              out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
              wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
-              in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
-              out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
+              in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
+              out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
              nrepeat,
              &clear_weight);
```
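Because the kernel is launched with InMemoryDataOperationEnum_t::AtomicAdd, each of the GemmKBatch grids adds its partial reduction into the same weight buffer, and the clear_weight callback re-uploads (effectively re-zeros) that buffer before each timed repetition so partial sums do not accumulate across runs. A minimal serial sketch of the split-K accumulation idea (not the CK kernel; sizes are illustrative):

```cpp
#include <cstdio>
#include <vector>

int main()
{
    const int K = 12, KBatch = 3;  // illustrative reduction length and split factor
    std::vector<float> a(K, 1.0f); // stand-in for one (m, n) dot product's operands
    float c = 0.0f;                // output must start cleared before accumulation

    for(int b = 0; b < KBatch; ++b) // in CK these batches run concurrently,
    {                               // hence the atomic add on the output
        float partial = 0.0f;
        for(int k = b * (K / KBatch); k < (b + 1) * (K / KBatch); ++k)
            partial += a[k];
        c += partial;               // the kernel performs this step with AtomicAdd
    }
    std::printf("c = %f (expect %d)\n", c, K);
    return 0;
}
```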
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp

```diff
@@ -11,8 +11,8 @@ template <ck::index_t BlockSize,
           typename FloatAcc,
           typename FloatC,
           ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
-          typename AK0MK1GridDesc,
-          typename BK0NK1GridDesc,
+          typename ABK0MK1GridDesc,
+          typename BBK0NK1GridDesc,
           typename CMNGridDesc,
           ck::index_t MPerBlock,
           ck::index_t NPerBlock,
@@ -50,10 +50,9 @@ template <ck::index_t BlockSize,
 __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
                                        const FloatAB* p_b_grid,
                                        FloatC* p_c_grid,
-                                       const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
-                                       const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
+                                       const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
+                                       const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
                                        const CMNGridDesc& c_m_n_grid_desc,
-                                       ck::index_t KBatch,
                                        AGridStepHacks,
                                        BGridStepHacks,
                                        CGridStepHacks,
@@ -68,6 +67,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
     constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
 
     using GridwiseGemm =
         GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4<BlockSize,
@@ -75,8 +75,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
             FloatAcc,
             FloatC,
             CGlobalMemoryDataOperation,
-            AK0MK1GridDesc,
-            BK0NK1GridDesc,
+            ABK0MK1GridDesc,
+            BBK0NK1GridDesc,
             CMNGridDesc,
             MPerBlock,
             NPerBlock,
@@ -113,25 +113,21 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
             CAccessOrderMRepeatNRepeat>;
 
     {
-        std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
-                  << a_k0_m_k1_grid_desc.GetLength(I1) << ", "
-                  << a_k0_m_k1_grid_desc.GetLength(I2) << "}" << std::endl;
+        std::cout << "a_b_k0_m_k1_grid_desc{" << a_b_k0_m_k1_grid_desc.GetLength(I0) << ", "
+                  << a_b_k0_m_k1_grid_desc.GetLength(I1) << ", "
+                  << a_b_k0_m_k1_grid_desc.GetLength(I2) << ", "
+                  << a_b_k0_m_k1_grid_desc.GetLength(I3) << "}" << std::endl;
 
-        std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
-                  << b_k0_n_k1_grid_desc.GetLength(I1) << ", "
-                  << b_k0_n_k1_grid_desc.GetLength(I2) << "}" << std::endl;
+        std::cout << "b_b_k0_n_k1_grid_desc{" << b_b_k0_n_k1_grid_desc.GetLength(I0) << ", "
+                  << b_b_k0_n_k1_grid_desc.GetLength(I1) << ", "
+                  << b_b_k0_n_k1_grid_desc.GetLength(I2) << ", "
+                  << b_b_k0_n_k1_grid_desc.GetLength(I3) << "}" << std::endl;
 
         std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
                   << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
     }
 
-    const auto a_b_k0_m_k1_grid_desc =
-        GridwiseGemm::MakeABK0MK1GridDescriptor(a_k0_m_k1_grid_desc, KBatch);
-    const auto b_b_k0_n_k1_grid_desc =
-        GridwiseGemm::MakeBBK0NK1GridDescriptor(b_k0_n_k1_grid_desc, KBatch);
-
-    if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, KBatch))
+    if(!GridwiseGemm::CheckValidity(a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m_n_grid_desc))
     {
         throw std::runtime_error(
             "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
@@ -140,10 +136,10 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
     const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
         GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
 
-    using ABK0MK1GridDesc = decltype(a_b_k0_m_k1_grid_desc);
-    using BBK0NK1GridDesc = decltype(b_b_k0_n_k1_grid_desc);
     using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
 
+    const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
+
     const auto c_block_cluster_adaptor =
         GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, KBatch);
@@ -153,6 +149,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
     {
         std::cout << "gridSize : " << grid_size << std::endl;
     }
 
     const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
                                                 FloatAB,
                                                 FloatC,
```
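The driver now receives the already-batched four-dimensional descriptors from its caller and recovers KBatch from their leading length instead of taking it as a parameter (and forwarding it to MakeCBlockClusterAdaptor as before). A minimal sketch of that recovery, using a stand-in struct rather than the CK descriptor type, with hypothetical lengths:

```cpp
#include <array>
#include <cstdio>

// Stand-in for a 4-d [B, K0, M, K1] grid descriptor; only GetLength is modeled.
struct DescSketch
{
    std::array<int, 4> lengths;
    int GetLength(int i) const { return lengths[i]; }
};

int main()
{
    const DescSketch a_b_k0_m_k1{{4, 16, 256, 8}}; // B=4, K0=16, M=256, K1=8

    // What was previously a ck::index_t KBatch argument is now read back
    // from the descriptor, as in the GetLength(I0) call in the hunk above.
    const int KBatch = a_b_k0_m_k1.GetLength(0);
    std::printf("KBatch recovered from descriptor: %d\n", KBatch);
    return 0;
}
```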