Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
4e57b30a
"tests/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "bb25d249f1761668fda52e1e73cc2a4b178e9e87"
Commit
4e57b30a
authored
Aug 11, 2021
by
Chao Liu
Browse files
rename
parent
c03045ce
Changes
30
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
574 additions
and
606 deletions
+574
-606
composable_kernel/include/tensor_description/tensor_descriptor.hpp
...e_kernel/include/tensor_description/tensor_descriptor.hpp
+21
-22
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp
...lude/tensor_operation/blockwise_tensor_slice_transfer.hpp
+8
-9
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp
...e/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp
+8
-9
composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp
...lude/tensor_operation/gridwise_contraction_dlops_v1r2.hpp
+20
-20
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp
...nel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp
+34
-38
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp
...nel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp
+22
-28
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp
...ernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp
+17
-18
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
+23
-29
composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp
.../include/tensor_operation/threadwise_tensor_slice_set.hpp
+1
-1
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
...ude/tensor_operation/threadwise_tensor_slice_transfer.hpp
+104
-110
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
.../tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
+58
-64
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp
...ution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp
+59
-59
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp
...tion_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp
+52
-52
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp
...tion_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp
+52
-52
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp
...ution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp
+20
-20
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp
...ackward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp
+15
-15
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp
...kward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp
+15
-15
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp
...ution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp
+15
-15
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp
...ion_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp
+15
-15
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
...on_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
+15
-15
No files found.
composable_kernel/include/tensor_description/tensor_descriptor.hpp
View file @
4e57b30a
...
@@ -10,7 +10,7 @@ template <index_t NDimHidden, typename VisibleDimensionIds>
...
@@ -10,7 +10,7 @@ template <index_t NDimHidden, typename VisibleDimensionIds>
struct
TensorCoordinate
;
struct
TensorCoordinate
;
template
<
index_t
NTransform
,
index_t
NDimVisible
,
typename
UpdateLowerIndexHack
>
template
<
index_t
NTransform
,
index_t
NDimVisible
,
typename
UpdateLowerIndexHack
>
struct
TensorCoordinate
I
te
rator
;
struct
TensorCoordinate
S
te
p
;
// Transforms: Tuple<transforms...>
// Transforms: Tuple<transforms...>
// LowerDimensionIdss : Tuple<Sequence<...>, ...>
// LowerDimensionIdss : Tuple<Sequence<...>, ...>
...
@@ -252,17 +252,16 @@ struct TensorCoordinate
...
@@ -252,17 +252,16 @@ struct TensorCoordinate
};
};
template
<
index_t
NTransform
,
index_t
NDimVisible
,
typename
UpdateLowerIndexHack
>
template
<
index_t
NTransform
,
index_t
NDimVisible
,
typename
UpdateLowerIndexHack
>
struct
TensorCoordinate
I
te
rator
struct
TensorCoordinate
S
te
p
{
{
// TODO make these private
// TODO make these private
using
VisibleIndex
=
MultiIndex
<
NDimVisible
>
;
using
VisibleIndex
=
MultiIndex
<
NDimVisible
>
;
public:
public:
__host__
__device__
constexpr
TensorCoordinate
I
te
rator
()
=
default
;
__host__
__device__
constexpr
TensorCoordinate
S
te
p
()
=
default
;
__host__
__host__
__device__
constexpr
TensorCoordinateStep
(
const
VisibleIndex
&
idx_diff_visible
,
__device__
constexpr
TensorCoordinateIterator
(
const
VisibleIndex
&
idx_diff_visible
,
const
MultiIndex
<
NTransform
>&
do_transforms
)
const
MultiIndex
<
NTransform
>&
do_transforms
)
:
idx_diff_visible_
{
idx_diff_visible
},
do_transforms_
{
do_transforms
}
:
idx_diff_visible_
{
idx_diff_visible
},
do_transforms_
{
do_transforms
}
{
{
}
}
...
@@ -423,8 +422,9 @@ __host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tens
...
@@ -423,8 +422,9 @@ __host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tens
// UpdateLowerIndexHack: Sequence<...>
// UpdateLowerIndexHack: Sequence<...>
// HACK: control UpdateLowerIndex
// HACK: control UpdateLowerIndex
template
<
typename
TensorDesc
,
typename
VisibleIndex
,
typename
UpdateLowerIndexHack
>
template
<
typename
TensorDesc
,
typename
VisibleIndex
,
typename
UpdateLowerIndexHack
>
__host__
__device__
constexpr
auto
make_tensor_coordinate_iterator
(
__host__
__device__
constexpr
auto
make_tensor_coordinate_step
(
const
TensorDesc
&
,
const
TensorDesc
&
,
const
VisibleIndex
&
idx_diff_visible
,
UpdateLowerIndexHack
)
const
VisibleIndex
&
idx_diff_visible
,
UpdateLowerIndexHack
)
{
{
static_assert
(
TensorDesc
::
GetNumOfDimension
()
==
VisibleIndex
::
Size
(),
static_assert
(
TensorDesc
::
GetNumOfDimension
()
==
VisibleIndex
::
Size
(),
"wrong! # of dimension inconsistent"
);
"wrong! # of dimension inconsistent"
);
...
@@ -471,24 +471,24 @@ __host__ __device__ constexpr auto make_tensor_coordinate_iterator(
...
@@ -471,24 +471,24 @@ __host__ __device__ constexpr auto make_tensor_coordinate_iterator(
set_container_subset
(
is_non_zero_diff
,
dims_low
,
non_zero_diff_pick_low
);
set_container_subset
(
is_non_zero_diff
,
dims_low
,
non_zero_diff_pick_low
);
});
});
return
TensorCoordinate
I
te
rator
<
ntransform
,
ndim_visible
,
UpdateLowerIndexHack
>
{
return
TensorCoordinate
S
te
p
<
ntransform
,
ndim_visible
,
UpdateLowerIndexHack
>
{
idx_diff_visible
,
idx_diff_visible
,
do_transforms
};
do_transforms
};
}
}
template
<
typename
TensorDesc
,
typename
VisibleIndex
>
template
<
typename
TensorDesc
,
typename
VisibleIndex
>
__host__
__device__
constexpr
auto
__host__
__device__
constexpr
auto
make_tensor_coordinate_step
(
const
TensorDesc
&
,
make_tensor_coordinate_iterator
(
const
TensorDesc
&
,
const
VisibleIndex
&
idx_diff_visible
)
const
VisibleIndex
&
idx_diff_visible
)
{
{
constexpr
index_t
ntransform
=
TensorDesc
::
GetNumOfTransform
();
constexpr
index_t
ntransform
=
TensorDesc
::
GetNumOfTransform
();
return
make_tensor_coordinate_
i
te
rator
(
return
make_tensor_coordinate_
s
te
p
(
TensorDesc
{},
idx_diff_visible
,
typename
uniform_sequence_gen
<
ntransform
,
0
>::
type
{});
TensorDesc
{},
idx_diff_visible
,
typename
uniform_sequence_gen
<
ntransform
,
0
>::
type
{});
}
}
template
<
typename
TensorDesc
,
typename
TensorCoord
,
typename
TensorCoord
I
te
rator
>
template
<
typename
TensorDesc
,
typename
TensorCoord
,
typename
TensorCoord
S
te
p
>
__host__
__device__
constexpr
void
move_tensor_coordinate
(
const
TensorDesc
&
tensor_desc
,
__host__
__device__
constexpr
void
move_tensor_coordinate
(
const
TensorDesc
&
tensor_desc
,
TensorCoord
&
coord
,
TensorCoord
&
coord
,
const
TensorCoord
I
te
rator
&
coord_
i
te
rator
)
const
TensorCoord
S
te
p
&
coord_
s
te
p
)
{
{
constexpr
index_t
ndim_hidden
=
TensorDesc
::
GetNumOfHiddenDimension
();
constexpr
index_t
ndim_hidden
=
TensorDesc
::
GetNumOfHiddenDimension
();
constexpr
index_t
ntransform
=
TensorDesc
::
GetNumOfTransform
();
constexpr
index_t
ntransform
=
TensorDesc
::
GetNumOfTransform
();
...
@@ -497,9 +497,8 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens
...
@@ -497,9 +497,8 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens
auto
idx_diff_hidden
=
make_zero_multi_index
<
ndim_hidden
>
();
auto
idx_diff_hidden
=
make_zero_multi_index
<
ndim_hidden
>
();
// initialize visible index diff
// initialize visible index diff
set_container_subset
(
idx_diff_hidden
,
set_container_subset
(
TensorDesc
::
GetVisibleDimensionIds
(),
idx_diff_hidden
,
TensorDesc
::
GetVisibleDimensionIds
(),
coord_step
.
GetVisibleIndexDiff
());
coord_iterator
.
GetVisibleIndexDiff
());
// this is what needs to be updated
// this is what needs to be updated
auto
&
idx_hidden
=
coord
.
GetHiddenIndex
();
auto
&
idx_hidden
=
coord
.
GetHiddenIndex
();
...
@@ -508,13 +507,13 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens
...
@@ -508,13 +507,13 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens
auto
idx_hidden_pick_visible
=
auto
idx_hidden_pick_visible
=
get_container_subset
(
idx_hidden
,
TensorDesc
::
GetVisibleDimensionIds
());
get_container_subset
(
idx_hidden
,
TensorDesc
::
GetVisibleDimensionIds
());
idx_hidden_pick_visible
+=
coord_
i
te
rator
.
GetIndexDiff
();
idx_hidden_pick_visible
+=
coord_
s
te
p
.
GetIndexDiff
();
set_container_subset
(
idx_hidden
,
TensorDesc
::
GetVisibleDimensionIds
(),
idx_hidden_pick_visible
);
set_container_subset
(
idx_hidden
,
TensorDesc
::
GetVisibleDimensionIds
(),
idx_hidden_pick_visible
);
// update rest of hidden index
// update rest of hidden index
static_for
<
ntransform
-
1
,
-
1
,
-
1
>
{}([
&
](
auto
itran
)
{
static_for
<
ntransform
-
1
,
-
1
,
-
1
>
{}([
&
](
auto
itran
)
{
if
(
coord_
i
te
rator
.
do_transforms_
[
itran
])
if
(
coord_
s
te
p
.
do_transforms_
[
itran
])
{
{
const
auto
&
tran
=
tensor_desc
.
GetTransforms
().
At
(
itran
);
const
auto
&
tran
=
tensor_desc
.
GetTransforms
().
At
(
itran
);
constexpr
auto
dims_low
=
TensorDesc
::
GetLowerDimensionIdss
().
At
(
itran
);
constexpr
auto
dims_low
=
TensorDesc
::
GetLowerDimensionIdss
().
At
(
itran
);
...
@@ -527,7 +526,7 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens
...
@@ -527,7 +526,7 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens
MultiIndex
<
dims_low
.
Size
()
>
idx_diff_low
;
MultiIndex
<
dims_low
.
Size
()
>
idx_diff_low
;
// HACK: control UpdateLowerIndex for Merge using hack
// HACK: control UpdateLowerIndex for Merge using hack
constexpr
index_t
Hack
=
decltype
(
coord_
i
te
rator
.
update_lower_index_hack_
)
::
At
(
itran
);
constexpr
index_t
Hack
=
decltype
(
coord_
s
te
p
.
update_lower_index_hack_
)
::
At
(
itran
);
tran
.
UpdateLowerIndex
(
idx_diff_low
,
idx_diff_up
,
idx_low
,
idx_up_new
,
Number
<
Hack
>
{});
tran
.
UpdateLowerIndex
(
idx_diff_low
,
idx_diff_up
,
idx_low
,
idx_up_new
,
Number
<
Hack
>
{});
...
@@ -591,7 +590,7 @@ using TensorCoordinate_t = decltype(make_tensor_coordinate(
...
@@ -591,7 +590,7 @@ using TensorCoordinate_t = decltype(make_tensor_coordinate(
TensorDesc
{},
MultiIndex
<
remove_cv_t
<
remove_reference_t
<
TensorDesc
>>::
GetNumOfDimension
()
>
{}));
TensorDesc
{},
MultiIndex
<
remove_cv_t
<
remove_reference_t
<
TensorDesc
>>::
GetNumOfDimension
()
>
{}));
template
<
typename
TensorDesc
>
template
<
typename
TensorDesc
>
using
TensorCoordinate
I
te
rator
_t
=
decltype
(
make_tensor_coordinate_
i
te
rator
(
using
TensorCoordinate
S
te
p
_t
=
decltype
(
make_tensor_coordinate_
s
te
p
(
TensorDesc
{},
MultiIndex
<
remove_cv_t
<
remove_reference_t
<
TensorDesc
>>::
GetNumOfDimension
()
>
{}));
TensorDesc
{},
MultiIndex
<
remove_cv_t
<
remove_reference_t
<
TensorDesc
>>::
GetNumOfDimension
()
>
{}));
}
// namespace ck
}
// namespace ck
...
...
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp
View file @
4e57b30a
...
@@ -77,15 +77,14 @@ struct BlockwiseTensorSliceTransfer_v4
...
@@ -77,15 +77,14 @@ struct BlockwiseTensorSliceTransfer_v4
}
}
}
}
template
<
typename
SrcBuffer
,
typename
SrcIteratorHacks
>
template
<
typename
SrcBuffer
,
typename
SrcStepHacks
>
__device__
void
RunRead
(
const
SrcDesc
&
src_desc
,
__device__
void
const
SrcBuffer
&
src_buf
,
RunRead
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
const
SrcStepHacks
&
src_step_hacks
)
const
SrcIteratorHacks
&
src_iterator_hacks
)
{
{
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
thread_cluster_desc_
.
GetElementSize
())
get_thread_local_1d_id
()
<
thread_cluster_desc_
.
GetElementSize
())
{
{
threadwise_transfer_
.
RunRead
(
src_desc
,
src_buf
,
src_
i
te
rator
_hacks
);
threadwise_transfer_
.
RunRead
(
src_desc
,
src_buf
,
src_
s
te
p
_hacks
);
}
}
}
}
...
@@ -118,18 +117,18 @@ struct BlockwiseTensorSliceTransfer_v4
...
@@ -118,18 +117,18 @@ struct BlockwiseTensorSliceTransfer_v4
}
}
}
}
// SrcMoveSliceWindow
I
te
rator
Hack to control index calculation move slice window
// SrcMoveSliceWindow
S
te
p
Hack to control index calculation move slice window
template
<
typename
SrcMoveSliceWindow
I
te
rator
Hack
>
template
<
typename
SrcMoveSliceWindow
S
te
p
Hack
>
__device__
void
__device__
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
step
,
const
Index
&
step
,
const
SrcMoveSliceWindow
I
te
rator
Hack
&
src_move_slice_window_
i
te
rator
_hack
)
const
SrcMoveSliceWindow
S
te
p
Hack
&
src_move_slice_window_
s
te
p
_hack
)
{
{
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
thread_cluster_desc_
.
GetElementSize
())
get_thread_local_1d_id
()
<
thread_cluster_desc_
.
GetElementSize
())
{
{
threadwise_transfer_
.
MoveSrcSliceWindow
(
threadwise_transfer_
.
MoveSrcSliceWindow
(
src_desc
,
step
,
src_move_slice_window_
i
te
rator
_hack
);
src_desc
,
step
,
src_move_slice_window_
s
te
p
_hack
);
}
}
}
}
...
...
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp
View file @
4e57b30a
...
@@ -75,15 +75,14 @@ struct BlockwiseTensorSliceTransfer_v4r1
...
@@ -75,15 +75,14 @@ struct BlockwiseTensorSliceTransfer_v4r1
}
}
}
}
template
<
typename
SrcBuffer
,
typename
SrcIteratorHacks
>
template
<
typename
SrcBuffer
,
typename
SrcStepHacks
>
__device__
void
RunRead
(
const
SrcDesc
&
src_desc
,
__device__
void
const
SrcBuffer
&
src_buf
,
RunRead
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
const
SrcStepHacks
&
src_step_hacks
)
const
SrcIteratorHacks
&
src_iterator_hacks
)
{
{
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
thread_cluster_desc_
.
GetElementSize
())
get_thread_local_1d_id
()
<
thread_cluster_desc_
.
GetElementSize
())
{
{
threadwise_transfer_
.
RunRead
(
src_desc
,
src_buf
,
src_
i
te
rator
_hacks
);
threadwise_transfer_
.
RunRead
(
src_desc
,
src_buf
,
src_
s
te
p
_hacks
);
}
}
}
}
...
@@ -106,18 +105,18 @@ struct BlockwiseTensorSliceTransfer_v4r1
...
@@ -106,18 +105,18 @@ struct BlockwiseTensorSliceTransfer_v4r1
}
}
}
}
// SrcMoveSliceWindow
I
te
rator
Hack to control index calculation move slice window
// SrcMoveSliceWindow
S
te
p
Hack to control index calculation move slice window
template
<
typename
SrcMoveSliceWindow
I
te
rator
Hack
>
template
<
typename
SrcMoveSliceWindow
S
te
p
Hack
>
__device__
void
__device__
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
step
,
const
Index
&
step
,
const
SrcMoveSliceWindow
I
te
rator
Hack
&
src_move_slice_window_
i
te
rator
_hack
)
const
SrcMoveSliceWindow
S
te
p
Hack
&
src_move_slice_window_
s
te
p
_hack
)
{
{
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
thread_cluster_desc_
.
GetElementSize
())
get_thread_local_1d_id
()
<
thread_cluster_desc_
.
GetElementSize
())
{
{
threadwise_transfer_
.
MoveSrcSliceWindow
(
threadwise_transfer_
.
MoveSrcSliceWindow
(
src_desc
,
step
,
src_move_slice_window_
i
te
rator
_hack
);
src_desc
,
step
,
src_move_slice_window_
s
te
p
_hack
);
}
}
}
}
...
...
composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp
View file @
4e57b30a
...
@@ -84,11 +84,11 @@ template <index_t BlockSize,
...
@@ -84,11 +84,11 @@ template <index_t BlockSize,
typename
CThreadTransferSrcDstAccessOrder
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
,
index_t
CThreadTransferDstScalarPerVector
,
typename
AGrid
I
te
rator
Hacks
,
typename
AGrid
S
te
p
Hacks
,
typename
BGrid
I
te
rator
Hacks
,
typename
BGrid
S
te
p
Hacks
,
typename
CGrid
I
te
rator
Hacks
,
typename
CGrid
S
te
p
Hacks
,
typename
AGridMoveSliceWindow
I
te
rator
Hacks
,
typename
AGridMoveSliceWindow
S
te
p
Hacks
,
typename
BGridMoveSliceWindow
I
te
rator
Hacks
>
typename
BGridMoveSliceWindow
S
te
p
Hacks
>
struct
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
struct
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -496,9 +496,9 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
...
@@ -496,9 +496,9 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
// LDS double buffer: preload data into LDS
// LDS double buffer: preload data into LDS
{
{
a_blockwise_copy
.
RunRead
(
a_blockwise_copy
.
RunRead
(
a_grid_desc_gk0_gm0_gm10_gm11_gk1
,
a_global_buf
,
AGrid
I
te
rator
Hacks
{});
a_grid_desc_gk0_gm0_gm10_gm11_gk1
,
a_global_buf
,
AGrid
S
te
p
Hacks
{});
b_blockwise_copy
.
RunRead
(
b_blockwise_copy
.
RunRead
(
b_grid_desc_gk0_gn0_gn10_gn11_gk1
,
b_global_buf
,
BGrid
I
te
rator
Hacks
{});
b_grid_desc_gk0_gn0_gn10_gn11_gk1
,
b_global_buf
,
BGrid
S
te
p
Hacks
{});
a_blockwise_copy
.
RunWrite
(
a_block_desc_gk0_gm0_gm10_gm11_gk1
,
a_block_even_buf
);
a_blockwise_copy
.
RunWrite
(
a_block_desc_gk0_gm0_gm10_gm11_gk1
,
a_block_even_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc_gk0_gn0_gn10_gn11_gk1
,
b_block_even_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc_gk0_gn0_gn10_gn11_gk1
,
b_block_even_buf
);
...
@@ -515,18 +515,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
...
@@ -515,18 +515,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
// even iteration
// even iteration
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_gk0_gm0_gm10_gm11_gk1
,
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_gk0_gm0_gm10_gm11_gk1
,
a_block_slice_copy_step
,
a_block_slice_copy_step
,
AGridMoveSliceWindow
I
te
rator
Hacks
{});
AGridMoveSliceWindow
S
te
p
Hacks
{});
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_gk0_gn0_gn10_gn11_gk1
,
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_gk0_gn0_gn10_gn11_gk1
,
b_block_slice_copy_step
,
b_block_slice_copy_step
,
BGridMoveSliceWindow
I
te
rator
Hacks
{});
BGridMoveSliceWindow
S
te
p
Hacks
{});
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
a_blockwise_copy
.
RunRead
(
a_blockwise_copy
.
RunRead
(
a_grid_desc_gk0_gm0_gm10_gm11_gk1
,
a_global_buf
,
AGrid
I
te
rator
Hacks
{});
a_grid_desc_gk0_gm0_gm10_gm11_gk1
,
a_global_buf
,
AGrid
S
te
p
Hacks
{});
b_blockwise_copy
.
RunRead
(
b_blockwise_copy
.
RunRead
(
b_grid_desc_gk0_gn0_gn10_gn11_gk1
,
b_global_buf
,
BGrid
I
te
rator
Hacks
{});
b_grid_desc_gk0_gn0_gn10_gn11_gk1
,
b_global_buf
,
BGrid
S
te
p
Hacks
{});
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
c_thread_desc_bm0_bm1_bn0_bn1
,
blockwise_gemm
.
Run
(
c_thread_desc_bm0_bm1_bn0_bn1
,
...
@@ -541,18 +541,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
...
@@ -541,18 +541,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
// odd iteration
// odd iteration
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_gk0_gm0_gm10_gm11_gk1
,
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_gk0_gm0_gm10_gm11_gk1
,
a_block_slice_copy_step
,
a_block_slice_copy_step
,
AGridMoveSliceWindow
I
te
rator
Hacks
{});
AGridMoveSliceWindow
S
te
p
Hacks
{});
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_gk0_gn0_gn10_gn11_gk1
,
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_gk0_gn0_gn10_gn11_gk1
,
b_block_slice_copy_step
,
b_block_slice_copy_step
,
BGridMoveSliceWindow
I
te
rator
Hacks
{});
BGridMoveSliceWindow
S
te
p
Hacks
{});
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
a_blockwise_copy
.
RunRead
(
a_blockwise_copy
.
RunRead
(
a_grid_desc_gk0_gm0_gm10_gm11_gk1
,
a_global_buf
,
AGrid
I
te
rator
Hacks
{});
a_grid_desc_gk0_gm0_gm10_gm11_gk1
,
a_global_buf
,
AGrid
S
te
p
Hacks
{});
b_blockwise_copy
.
RunRead
(
b_blockwise_copy
.
RunRead
(
b_grid_desc_gk0_gn0_gn10_gn11_gk1
,
b_global_buf
,
BGrid
I
te
rator
Hacks
{});
b_grid_desc_gk0_gn0_gn10_gn11_gk1
,
b_global_buf
,
BGrid
S
te
p
Hacks
{});
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
blockwise_gemm
.
Run
(
...
@@ -571,18 +571,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
...
@@ -571,18 +571,18 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
{
{
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_gk0_gm0_gm10_gm11_gk1
,
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_gk0_gm0_gm10_gm11_gk1
,
a_block_slice_copy_step
,
a_block_slice_copy_step
,
AGridMoveSliceWindow
I
te
rator
Hacks
{});
AGridMoveSliceWindow
S
te
p
Hacks
{});
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_gk0_gn0_gn10_gn11_gk1
,
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_gk0_gn0_gn10_gn11_gk1
,
b_block_slice_copy_step
,
b_block_slice_copy_step
,
BGridMoveSliceWindow
I
te
rator
Hacks
{});
BGridMoveSliceWindow
S
te
p
Hacks
{});
__syncthreads
();
__syncthreads
();
// LDS double buffer: load last data from device mem
// LDS double buffer: load last data from device mem
a_blockwise_copy
.
RunRead
(
a_blockwise_copy
.
RunRead
(
a_grid_desc_gk0_gm0_gm10_gm11_gk1
,
a_global_buf
,
AGrid
I
te
rator
Hacks
{});
a_grid_desc_gk0_gm0_gm10_gm11_gk1
,
a_global_buf
,
AGrid
S
te
p
Hacks
{});
b_blockwise_copy
.
RunRead
(
b_blockwise_copy
.
RunRead
(
b_grid_desc_gk0_gn0_gn10_gn11_gk1
,
b_global_buf
,
BGrid
I
te
rator
Hacks
{});
b_grid_desc_gk0_gn0_gn10_gn11_gk1
,
b_global_buf
,
BGrid
S
te
p
Hacks
{});
// LDS double buffer: GEMM on 2nd-last data
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm
.
Run
(
blockwise_gemm
.
Run
(
...
@@ -650,7 +650,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
...
@@ -650,7 +650,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
c_thread_buf
,
c_thread_buf
,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1
,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1
,
c_grid_buf
,
c_grid_buf
,
CGrid
I
te
rator
Hacks
{});
CGrid
S
te
p
Hacks
{});
}
}
}
}
};
};
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp
View file @
4e57b30a
...
@@ -145,11 +145,11 @@ template <index_t BlockSize,
...
@@ -145,11 +145,11 @@ template <index_t BlockSize,
typename
CThreadTransferSrcDstAccessOrder
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
,
index_t
CThreadTransferDstScalarPerVector
,
typename
AGrid
I
te
rator
Hacks
,
typename
AGrid
S
te
p
Hacks
,
typename
BGrid
I
te
rator
Hacks
,
typename
BGrid
S
te
p
Hacks
,
typename
CGrid
I
te
rator
Hacks
,
typename
CGrid
S
te
p
Hacks
,
typename
AGridMoveSliceWindow
I
te
rator
Hacks
,
typename
AGridMoveSliceWindow
S
te
p
Hacks
,
typename
BGridMoveSliceWindow
I
te
rator
Hacks
>
typename
BGridMoveSliceWindow
S
te
p
Hacks
>
struct
GridwiseGemmDlops_km_kn_mn_v1r2
struct
GridwiseGemmDlops_km_kn_mn_v1r2
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -475,15 +475,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
...
@@ -475,15 +475,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
,
0
,
0
);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr
auto
a_k_m0_m1_global_
i
te
rator
_hacks
=
AGrid
I
te
rator
Hacks
{};
constexpr
auto
a_k_m0_m1_global_
s
te
p
_hacks
=
AGrid
S
te
p
Hacks
{};
constexpr
auto
b_k_n0_n1_global_
i
te
rator
_hacks
=
BGrid
I
te
rator
Hacks
{};
constexpr
auto
b_k_n0_n1_global_
s
te
p
_hacks
=
BGrid
S
te
p
Hacks
{};
// hack to control index calculation when move slice window for A and B matrix for
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
// threadwise copy
constexpr
auto
a_k_m0_m1_global_move_slice_window_
i
te
rator
_hack
=
constexpr
auto
a_k_m0_m1_global_move_slice_window_
s
te
p
_hack
=
AGridMoveSliceWindow
I
te
rator
Hacks
{};
AGridMoveSliceWindow
S
te
p
Hacks
{};
constexpr
auto
b_k_n0_n1_global_move_slice_window_
i
te
rator
_hack
=
constexpr
auto
b_k_n0_n1_global_move_slice_window_
s
te
p
_hack
=
BGridMoveSliceWindow
I
te
rator
Hacks
{};
BGridMoveSliceWindow
S
te
p
Hacks
{};
auto
a_block_even_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
auto
a_block_even_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
p_a_block_double
,
a_k_m0_m1_block_desc
.
GetElementSpaceSize
());
p_a_block_double
,
a_k_m0_m1_block_desc
.
GetElementSpaceSize
());
...
@@ -500,9 +500,9 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
...
@@ -500,9 +500,9 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
// LDS double buffer: preload data into LDS
// LDS double buffer: preload data into LDS
{
{
a_blockwise_copy
.
RunRead
(
a_blockwise_copy
.
RunRead
(
a_k_m0_m1_grid_desc
,
a_global_buf
,
a_k_m0_m1_global_
i
te
rator
_hacks
);
a_k_m0_m1_grid_desc
,
a_global_buf
,
a_k_m0_m1_global_
s
te
p
_hacks
);
b_blockwise_copy
.
RunRead
(
b_blockwise_copy
.
RunRead
(
b_k_n0_n1_grid_desc
,
b_global_buf
,
b_k_n0_n1_global_
i
te
rator
_hacks
);
b_k_n0_n1_grid_desc
,
b_global_buf
,
b_k_n0_n1_global_
s
te
p
_hacks
);
a_blockwise_copy
.
RunWrite
(
a_k_m0_m1_block_desc
,
a_block_even_buf
);
a_blockwise_copy
.
RunWrite
(
a_k_m0_m1_block_desc
,
a_block_even_buf
);
b_blockwise_copy
.
RunWrite
(
b_k_n0_n1_block_desc
,
b_block_even_buf
);
b_blockwise_copy
.
RunWrite
(
b_k_n0_n1_block_desc
,
b_block_even_buf
);
...
@@ -517,22 +517,20 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
...
@@ -517,22 +517,20 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
do
do
{
{
// even iteration
// even iteration
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k_m0_m1_grid_desc
,
a_k_m0_m1_grid_desc
,
a_block_slice_copy_step
,
a_block_slice_copy_step
,
a_k_m0_m1_global_move_slice_window_step_hack
);
a_k_m0_m1_global_move_slice_window_iterator_hack
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k_n0_n1_grid_desc
,
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_block_slice_copy_step
,
b_k_n0_n1_grid_desc
,
b_k_n0_n1_global_move_slice_window_step_hack
);
b_block_slice_copy_step
,
b_k_n0_n1_global_move_slice_window_iterator_hack
);
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
a_blockwise_copy
.
RunRead
(
a_blockwise_copy
.
RunRead
(
a_k_m0_m1_grid_desc
,
a_global_buf
,
a_k_m0_m1_global_
i
te
rator
_hacks
);
a_k_m0_m1_grid_desc
,
a_global_buf
,
a_k_m0_m1_global_
s
te
p
_hacks
);
b_blockwise_copy
.
RunRead
(
b_blockwise_copy
.
RunRead
(
b_k_n0_n1_grid_desc
,
b_global_buf
,
b_k_n0_n1_global_
i
te
rator
_hacks
);
b_k_n0_n1_grid_desc
,
b_global_buf
,
b_k_n0_n1_global_
s
te
p
_hacks
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
c_m10_m11_n10_n11_thread_desc
,
blockwise_gemm
.
Run
(
c_m10_m11_n10_n11_thread_desc
,
...
@@ -545,22 +543,20 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
...
@@ -545,22 +543,20 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
b_blockwise_copy
.
RunWrite
(
b_k_n0_n1_block_desc
,
b_block_odd_buf
);
b_blockwise_copy
.
RunWrite
(
b_k_n0_n1_block_desc
,
b_block_odd_buf
);
// odd iteration
// odd iteration
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k_m0_m1_grid_desc
,
a_k_m0_m1_grid_desc
,
a_block_slice_copy_step
,
a_block_slice_copy_step
,
a_k_m0_m1_global_move_slice_window_step_hack
);
a_k_m0_m1_global_move_slice_window_iterator_hack
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k_n0_n1_grid_desc
,
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_block_slice_copy_step
,
b_k_n0_n1_grid_desc
,
b_k_n0_n1_global_move_slice_window_step_hack
);
b_block_slice_copy_step
,
b_k_n0_n1_global_move_slice_window_iterator_hack
);
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
a_blockwise_copy
.
RunRead
(
a_blockwise_copy
.
RunRead
(
a_k_m0_m1_grid_desc
,
a_global_buf
,
a_k_m0_m1_global_
i
te
rator
_hacks
);
a_k_m0_m1_grid_desc
,
a_global_buf
,
a_k_m0_m1_global_
s
te
p
_hacks
);
b_blockwise_copy
.
RunRead
(
b_blockwise_copy
.
RunRead
(
b_k_n0_n1_grid_desc
,
b_global_buf
,
b_k_n0_n1_global_
i
te
rator
_hacks
);
b_k_n0_n1_grid_desc
,
b_global_buf
,
b_k_n0_n1_global_
s
te
p
_hacks
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
blockwise_gemm
.
Run
(
...
@@ -579,18 +575,18 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
...
@@ -579,18 +575,18 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
{
{
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k_m0_m1_grid_desc
,
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k_m0_m1_grid_desc
,
a_block_slice_copy_step
,
a_block_slice_copy_step
,
a_k_m0_m1_global_move_slice_window_
i
te
rator
_hack
);
a_k_m0_m1_global_move_slice_window_
s
te
p
_hack
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k_n0_n1_grid_desc
,
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k_n0_n1_grid_desc
,
b_block_slice_copy_step
,
b_block_slice_copy_step
,
b_k_n0_n1_global_move_slice_window_
i
te
rator
_hack
);
b_k_n0_n1_global_move_slice_window_
s
te
p
_hack
);
__syncthreads
();
__syncthreads
();
// LDS double buffer: load last data from device mem
// LDS double buffer: load last data from device mem
a_blockwise_copy
.
RunRead
(
a_blockwise_copy
.
RunRead
(
a_k_m0_m1_grid_desc
,
a_global_buf
,
a_k_m0_m1_global_
i
te
rator
_hacks
);
a_k_m0_m1_grid_desc
,
a_global_buf
,
a_k_m0_m1_global_
s
te
p
_hacks
);
b_blockwise_copy
.
RunRead
(
b_blockwise_copy
.
RunRead
(
b_k_n0_n1_grid_desc
,
b_global_buf
,
b_k_n0_n1_global_
i
te
rator
_hacks
);
b_k_n0_n1_grid_desc
,
b_global_buf
,
b_k_n0_n1_global_
s
te
p
_hacks
);
// LDS double buffer: GEMM on 2nd-last data
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm
.
Run
(
blockwise_gemm
.
Run
(
...
@@ -657,7 +653,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
...
@@ -657,7 +653,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
c_thread_buf
,
c_thread_buf
,
c_m0_m10_m11_n0_n10_n11_grid_desc
,
c_m0_m10_m11_n0_n10_n11_grid_desc
,
c_grid_buf
,
c_grid_buf
,
CGrid
I
te
rator
Hacks
{});
CGrid
S
te
p
Hacks
{});
}
}
}
}
};
};
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp
View file @
4e57b30a
...
@@ -141,11 +141,11 @@ template <index_t BlockSize,
...
@@ -141,11 +141,11 @@ template <index_t BlockSize,
typename
CThreadTransferSrcDstAccessOrder
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
,
index_t
CThreadTransferDstScalarPerVector
,
typename
AGrid
I
te
rator
Hacks
,
typename
AGrid
S
te
p
Hacks
,
typename
BGrid
I
te
rator
Hacks
,
typename
BGrid
S
te
p
Hacks
,
typename
CGrid
I
te
rator
Hacks
,
typename
CGrid
S
te
p
Hacks
,
typename
AGridMoveSliceWindow
I
te
rator
Hacks
,
typename
AGridMoveSliceWindow
S
te
p
Hacks
,
typename
BGridMoveSliceWindow
I
te
rator
Hacks
>
typename
BGridMoveSliceWindow
S
te
p
Hacks
>
struct
GridwiseGemmDlops_km_kn_mn_v1r3
struct
GridwiseGemmDlops_km_kn_mn_v1r3
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -494,8 +494,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
...
@@ -494,8 +494,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// LDS double buffer: preload data into LDS
// LDS double buffer: preload data into LDS
{
{
a_blockwise_copy
.
RunRead
(
a_k0_m0_m1_k1_grid_desc
,
a_global_buf
,
AGrid
I
te
rator
Hacks
{});
a_blockwise_copy
.
RunRead
(
a_k0_m0_m1_k1_grid_desc
,
a_global_buf
,
AGrid
S
te
p
Hacks
{});
b_blockwise_copy
.
RunRead
(
b_k0_n0_n1_k1_grid_desc
,
b_global_buf
,
BGrid
I
te
rator
Hacks
{});
b_blockwise_copy
.
RunRead
(
b_k0_n0_n1_k1_grid_desc
,
b_global_buf
,
BGrid
S
te
p
Hacks
{});
a_blockwise_copy
.
RunWrite
(
a_k0_m0_m1_k1_block_desc
,
a_block_even_buf
);
a_blockwise_copy
.
RunWrite
(
a_k0_m0_m1_k1_block_desc
,
a_block_even_buf
);
b_blockwise_copy
.
RunWrite
(
b_k0_n0_n1_k1_block_desc
,
b_block_even_buf
);
b_blockwise_copy
.
RunWrite
(
b_k0_n0_n1_k1_block_desc
,
b_block_even_buf
);
...
@@ -514,18 +514,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
...
@@ -514,18 +514,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// even iteration
// even iteration
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k0_m0_m1_k1_grid_desc
,
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k0_m0_m1_k1_grid_desc
,
a_block_slice_copy_step
,
a_block_slice_copy_step
,
AGridMoveSliceWindow
I
te
rator
Hacks
{});
AGridMoveSliceWindow
S
te
p
Hacks
{});
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k0_n0_n1_k1_grid_desc
,
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k0_n0_n1_k1_grid_desc
,
b_block_slice_copy_step
,
b_block_slice_copy_step
,
BGridMoveSliceWindow
I
te
rator
Hacks
{});
BGridMoveSliceWindow
S
te
p
Hacks
{});
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
a_blockwise_copy
.
RunRead
(
a_blockwise_copy
.
RunRead
(
a_k0_m0_m1_k1_grid_desc
,
a_global_buf
,
AGridStepHacks
{});
a_k0_m0_m1_k1_grid_desc
,
a_global_buf
,
AGridIteratorHacks
{});
b_blockwise_copy
.
RunRead
(
b_k0_n0_n1_k1_grid_desc
,
b_global_buf
,
BGridStepHacks
{});
b_blockwise_copy
.
RunRead
(
b_k0_n0_n1_k1_grid_desc
,
b_global_buf
,
BGridIteratorHacks
{});
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
c_m10_m11_n10_n11_thread_desc
,
blockwise_gemm
.
Run
(
c_m10_m11_n10_n11_thread_desc
,
...
@@ -540,18 +538,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
...
@@ -540,18 +538,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// odd iteration
// odd iteration
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k0_m0_m1_k1_grid_desc
,
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k0_m0_m1_k1_grid_desc
,
a_block_slice_copy_step
,
a_block_slice_copy_step
,
AGridMoveSliceWindow
I
te
rator
Hacks
{});
AGridMoveSliceWindow
S
te
p
Hacks
{});
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k0_n0_n1_k1_grid_desc
,
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k0_n0_n1_k1_grid_desc
,
b_block_slice_copy_step
,
b_block_slice_copy_step
,
BGridMoveSliceWindow
I
te
rator
Hacks
{});
BGridMoveSliceWindow
S
te
p
Hacks
{});
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
a_blockwise_copy
.
RunRead
(
a_blockwise_copy
.
RunRead
(
a_k0_m0_m1_k1_grid_desc
,
a_global_buf
,
AGridStepHacks
{});
a_k0_m0_m1_k1_grid_desc
,
a_global_buf
,
AGridIteratorHacks
{});
b_blockwise_copy
.
RunRead
(
b_k0_n0_n1_k1_grid_desc
,
b_global_buf
,
BGridStepHacks
{});
b_blockwise_copy
.
RunRead
(
b_k0_n0_n1_k1_grid_desc
,
b_global_buf
,
BGridIteratorHacks
{});
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
blockwise_gemm
.
Run
(
...
@@ -568,18 +564,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
...
@@ -568,18 +564,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// LDS double buffer: tail
// LDS double buffer: tail
if
constexpr
(
HasDoubleTailKBlockLoop
)
// if has 2 iteration left
if
constexpr
(
HasDoubleTailKBlockLoop
)
// if has 2 iteration left
{
{
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k0_m0_m1_k1_grid_desc
,
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_block_slice_copy_step
,
a_k0_m0_m1_k1_grid_desc
,
a_block_slice_copy_step
,
AGridMoveSliceWindowStepHacks
{});
AGridMoveSliceWindowIteratorHacks
{});
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k0_n0_n1_k1_grid_desc
,
b_k0_n0_n1_k1_grid_desc
,
b_block_slice_copy_step
,
BGridMoveSliceWindowStepHacks
{});
b_block_slice_copy_step
,
BGridMoveSliceWindowIteratorHacks
{});
__syncthreads
();
__syncthreads
();
// LDS double buffer: load last data from device mem
// LDS double buffer: load last data from device mem
a_blockwise_copy
.
RunRead
(
a_k0_m0_m1_k1_grid_desc
,
a_global_buf
,
AGrid
I
te
rator
Hacks
{});
a_blockwise_copy
.
RunRead
(
a_k0_m0_m1_k1_grid_desc
,
a_global_buf
,
AGrid
S
te
p
Hacks
{});
b_blockwise_copy
.
RunRead
(
b_k0_n0_n1_k1_grid_desc
,
b_global_buf
,
BGrid
I
te
rator
Hacks
{});
b_blockwise_copy
.
RunRead
(
b_k0_n0_n1_k1_grid_desc
,
b_global_buf
,
BGrid
S
te
p
Hacks
{});
// LDS double buffer: GEMM on 2nd-last data
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm
.
Run
(
blockwise_gemm
.
Run
(
...
@@ -647,7 +641,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
...
@@ -647,7 +641,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
c_thread_buf
,
c_thread_buf
,
c_m0_m10_m11_n0_n10_n11_grid_desc
,
c_m0_m10_m11_n0_n10_n11_grid_desc
,
c_grid_buf
,
c_grid_buf
,
CGrid
I
te
rator
Hacks
{});
CGrid
S
te
p
Hacks
{});
}
}
}
}
};
};
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp
View file @
4e57b30a
...
@@ -42,11 +42,11 @@ template <index_t BlockSize,
...
@@ -42,11 +42,11 @@ template <index_t BlockSize,
typename
CThreadTransferSrcDstAccessOrder
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
,
index_t
CThreadTransferDstScalarPerVector
,
typename
AGlobal
I
te
rator
Hacks
,
typename
AGlobal
S
te
p
Hacks
,
typename
BGlobal
I
te
rator
Hacks
,
typename
BGlobal
S
te
p
Hacks
,
typename
CGlobal
I
te
rator
Hacks
,
typename
CGlobal
S
te
p
Hacks
,
typename
AGlobalMoveSliceWindow
I
te
rator
Hacks
,
typename
AGlobalMoveSliceWindow
S
te
p
Hacks
,
typename
BGlobalMoveSliceWindow
I
te
rator
Hacks
>
typename
BGlobalMoveSliceWindow
S
te
p
Hacks
>
struct
GridwiseGemmDlops_km_kn_mn_v3
struct
GridwiseGemmDlops_km_kn_mn_v3
{
{
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
...
@@ -239,15 +239,14 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -239,15 +239,14 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr
auto
b_thread_slice_copy_step
=
make_multi_index
(
EPerBlock
,
0
,
0
,
0
);
constexpr
auto
b_thread_slice_copy_step
=
make_multi_index
(
EPerBlock
,
0
,
0
,
0
);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr
auto
a_e_k_global_
i
te
rator
_hacks
=
AGlobal
I
te
rator
Hacks
{};
constexpr
auto
a_e_k_global_
s
te
p
_hacks
=
AGlobal
S
te
p
Hacks
{};
constexpr
auto
b_e_n_ho_wo_global_
i
te
rator
_hacks
=
BGlobal
I
te
rator
Hacks
{};
constexpr
auto
b_e_n_ho_wo_global_
s
te
p
_hacks
=
BGlobal
S
te
p
Hacks
{};
// hack to control index calculation when move slice window for A and B matrix for
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
// threadwise copy
constexpr
auto
a_e_k_global_move_slice_window_iterator_hack
=
constexpr
auto
a_e_k_global_move_slice_window_step_hack
=
AGlobalMoveSliceWindowStepHacks
{};
AGlobalMoveSliceWindowIteratorHacks
{};
constexpr
auto
b_e_n_ho_wo_global_move_slice_window_step_hack
=
constexpr
auto
b_e_n_ho_wo_global_move_slice_window_iterator_hack
=
BGlobalMoveSliceWindowStepHacks
{};
BGlobalMoveSliceWindowIteratorHacks
{};
// double regsiter buffer for b
// double regsiter buffer for b
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
...
@@ -257,14 +256,14 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -257,14 +256,14 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// LDS double buffer: preload data
// LDS double buffer: preload data
{
{
a_blockwise_copy
.
RunRead
(
a_e_k_global_desc
,
a_global_buf
,
a_e_k_global_
i
te
rator
_hacks
);
a_blockwise_copy
.
RunRead
(
a_e_k_global_desc
,
a_global_buf
,
a_e_k_global_
s
te
p
_hacks
);
b_threadwise_transfer
.
Run
(
b_e_n_ho_wo_global_desc
,
b_threadwise_transfer
.
Run
(
b_e_n_ho_wo_global_desc
,
b_global_buf
,
b_global_buf
,
b_e_n_ho_wo_thread_desc
,
b_e_n_ho_wo_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
,
b_thread_even_buf
,
b_e_n_ho_wo_global_
i
te
rator
_hacks
);
b_e_n_ho_wo_global_
s
te
p
_hacks
);
a_blockwise_copy
.
RunWrite
(
a_e_k_desc
,
a_block_buf
);
a_blockwise_copy
.
RunWrite
(
a_e_k_desc
,
a_block_buf
);
}
}
...
@@ -288,7 +287,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -288,7 +287,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_thread_desc
,
b_e_n_ho_wo_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_odd_buf
,
b_thread_odd_buf
,
b_e_n_ho_wo_global_
i
te
rator
_hacks
);
b_e_n_ho_wo_global_
s
te
p
_hacks
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
...
@@ -304,7 +303,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -304,7 +303,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_thread_desc
,
b_e_n_ho_wo_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
,
b_thread_even_buf
,
b_e_n_ho_wo_global_
i
te
rator
_hacks
);
b_e_n_ho_wo_global_
s
te
p
_hacks
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_odd_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_odd_buf
,
c_thread_buf
);
...
@@ -327,7 +326,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -327,7 +326,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_thread_desc
,
b_e_n_ho_wo_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_odd_buf
,
b_thread_odd_buf
,
b_e_n_ho_wo_global_
i
te
rator
_hacks
);
b_e_n_ho_wo_global_
s
te
p
_hacks
);
// LDS double buffer: GEMM on 2nd-last data
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
...
@@ -346,7 +345,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -346,7 +345,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// output: register to global memory
// output: register to global memory
{
{
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
constexpr
auto
c_k_n_ho_wo_global_tensor_
i
te
rator
_hacks
=
CGlobal
I
te
rator
Hacks
{};
constexpr
auto
c_k_n_ho_wo_global_tensor_
s
te
p
_hacks
=
CGlobal
S
te
p
Hacks
{};
const
index_t
k_thread_data_on_global
=
const
index_t
k_thread_data_on_global
=
k_block_data_on_global
+
k_thread_id
*
KPerThread
;
k_block_data_on_global
+
k_thread_id
*
KPerThread
;
...
@@ -370,7 +369,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -370,7 +369,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
c_thread_buf
,
c_thread_buf
,
c_k_n_ho_wo_global_desc
,
c_k_n_ho_wo_global_desc
,
c_global_buf
,
c_global_buf
,
c_k_n_ho_wo_global_tensor_
i
te
rator
_hacks
);
c_k_n_ho_wo_global_tensor_
s
te
p
_hacks
);
}
}
}
}
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
View file @
4e57b30a
...
@@ -126,11 +126,11 @@ template <index_t BlockSize,
...
@@ -126,11 +126,11 @@ template <index_t BlockSize,
typename
CThreadTransferSrcDstAccessOrder
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
,
index_t
CThreadTransferDstScalarPerVector
,
typename
AGrid
I
te
rator
Hacks
,
typename
AGrid
S
te
p
Hacks
,
typename
BGrid
I
te
rator
Hacks
,
typename
BGrid
S
te
p
Hacks
,
typename
CGrid
I
te
rator
Hacks
,
typename
CGrid
S
te
p
Hacks
,
typename
AGridMoveSliceWindow
I
te
rator
Hacks
,
typename
AGridMoveSliceWindow
S
te
p
Hacks
,
typename
BGridMoveSliceWindow
I
te
rator
Hacks
,
typename
BGridMoveSliceWindow
S
te
p
Hacks
,
bool
CAccessOrderMRepeatNRepeat
>
bool
CAccessOrderMRepeatNRepeat
>
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
{
...
@@ -416,15 +416,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -416,15 +416,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
,
0
,
0
);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr
auto
a_k0_m_k1_grid_
i
te
rator
_hacks
=
AGrid
I
te
rator
Hacks
{};
constexpr
auto
a_k0_m_k1_grid_
s
te
p
_hacks
=
AGrid
S
te
p
Hacks
{};
constexpr
auto
b_k0_n_k1_grid_
i
te
rator
_hacks
=
BGrid
I
te
rator
Hacks
{};
constexpr
auto
b_k0_n_k1_grid_
s
te
p
_hacks
=
BGrid
S
te
p
Hacks
{};
// hack to control index calculation when move slice window for A and B matrix for
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
// threadwise copy
constexpr
auto
a_k0_m_k1_grid_move_slice_window_iterator_hack
=
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hack
=
AGridMoveSliceWindowStepHacks
{};
AGridMoveSliceWindowIteratorHacks
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hack
=
BGridMoveSliceWindowStepHacks
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_iterator_hack
=
BGridMoveSliceWindowIteratorHacks
{};
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
p_a_block
,
a_k0_m_k1_block_desc
.
GetElementSpaceSize
());
p_a_block
,
a_k0_m_k1_block_desc
.
GetElementSpaceSize
());
...
@@ -433,10 +431,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -433,10 +431,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// preload data into LDS
// preload data into LDS
{
{
a_blockwise_copy
.
RunRead
(
a_blockwise_copy
.
RunRead
(
a_k0_m_k1_grid_desc
,
a_grid_buf
,
a_k0_m_k1_grid_step_hacks
);
a_k0_m_k1_grid_desc
,
a_grid_buf
,
a_k0_m_k1_grid_iterator_hacks
);
b_blockwise_copy
.
RunRead
(
b_k0_n_k1_grid_desc
,
b_grid_buf
,
b_k0_n_k1_grid_step_hacks
);
b_blockwise_copy
.
RunRead
(
b_k0_n_k1_grid_desc
,
b_grid_buf
,
b_k0_n_k1_grid_iterator_hacks
);
a_blockwise_copy
.
RunWrite
(
a_k0_m_k1_block_desc
,
a_block_buf
);
a_blockwise_copy
.
RunWrite
(
a_k0_m_k1_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_k0_n_k1_block_desc
,
b_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_k0_n_k1_block_desc
,
b_block_buf
);
...
@@ -449,18 +445,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -449,18 +445,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
{
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k0_m_k1_grid_desc
,
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k0_m_k1_grid_desc
,
a_block_slice_copy_step
,
a_block_slice_copy_step
,
a_k0_m_k1_grid_move_slice_window_
i
te
rator
_hack
);
a_k0_m_k1_grid_move_slice_window_
s
te
p
_hack
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k0_n_k1_grid_desc
,
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k0_n_k1_grid_desc
,
b_block_slice_copy_step
,
b_block_slice_copy_step
,
b_k0_n_k1_grid_move_slice_window_
i
te
rator
_hack
);
b_k0_n_k1_grid_move_slice_window_
s
te
p
_hack
);
a_blockwise_copy
.
RunRead
(
a_blockwise_copy
.
RunRead
(
a_k0_m_k1_grid_desc
,
a_grid_buf
,
a_k0_m_k1_grid_step_hacks
);
a_k0_m_k1_grid_desc
,
a_grid_buf
,
a_k0_m_k1_grid_iterator_hacks
);
block_sync_lds
();
block_sync_lds
();
b_blockwise_copy
.
RunRead
(
b_blockwise_copy
.
RunRead
(
b_k0_n_k1_grid_desc
,
b_grid_buf
,
b_k0_n_k1_grid_step_hacks
);
b_k0_n_k1_grid_desc
,
b_grid_buf
,
b_k0_n_k1_grid_iterator_hacks
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
...
@@ -526,7 +520,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -526,7 +520,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const index_t n_thread_data_on_grid =
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_
i
te
rator
_hacks = CGrid
I
te
rator
Hacks{};
constexpr auto c_m0_m1_m2_n_grid_tensor_
s
te
p
_hacks = CGrid
S
te
p
Hacks{};
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
...
@@ -557,7 +551,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -557,7 +551,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_blk_buf_,
c_blk_buf_,
c_m0_m1_m2_n_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_
i
te
rator
_hacks);
c_m0_m1_m2_n_grid_tensor_
s
te
p
_hacks);
}
}
#else
#else
{
{
...
@@ -579,7 +573,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -579,7 +573,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const
index_t
n_thread_data_on_grid
=
const
index_t
n_thread_data_on_grid
=
n_block_data_idx_on_grid
+
c_thread_mtx_on_block
[
I1
];
n_block_data_idx_on_grid
+
c_thread_mtx_on_block
[
I1
];
constexpr
auto
c_m0_m1_m2_n_grid_tensor_
i
te
rator
_hacks
=
CGrid
I
te
rator
Hacks
{};
constexpr
auto
c_m0_m1_m2_n_grid_tensor_
s
te
p
_hacks
=
CGrid
S
te
p
Hacks
{};
auto
c_thread_copy
=
auto
c_thread_copy
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatC
,
ThreadwiseTensorSliceTransfer_v1r3
<
FloatC
,
...
@@ -610,7 +604,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -610,7 +604,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_m0_m1_m2_n_grid_desc
,
c_m0_m1_m2_n_grid_desc
,
c_grid_buf
,
c_grid_buf
,
c_m0_m1_m2_n_grid_tensor_
i
te
rator
_hacks
);
c_m0_m1_m2_n_grid_tensor_
s
te
p
_hacks
);
return
c_thread_idx_
;
return
c_thread_idx_
;
};
};
...
@@ -625,7 +619,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -625,7 +619,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_m0_m1_m2_n_grid_desc
,
c_m0_m1_m2_n_grid_desc
,
c_grid_buf
,
c_grid_buf
,
c_m0_m1_m2_n_grid_tensor_
i
te
rator
_hacks
);
c_m0_m1_m2_n_grid_tensor_
s
te
p
_hacks
);
};
};
auto
nrepeat_plus_copy
=
[
&
](
auto
c_thread_idx_
)
{
auto
nrepeat_plus_copy
=
[
&
](
auto
c_thread_idx_
)
{
...
@@ -638,7 +632,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -638,7 +632,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_m0_m1_m2_n_grid_desc
,
c_m0_m1_m2_n_grid_desc
,
c_grid_buf
,
c_grid_buf
,
c_m0_m1_m2_n_grid_tensor_
i
te
rator
_hacks
);
c_m0_m1_m2_n_grid_tensor_
s
te
p
_hacks
);
};
};
auto
mrepeat_minus_copy
=
[
&
](
auto
c_thread_idx_
)
{
auto
mrepeat_minus_copy
=
[
&
](
auto
c_thread_idx_
)
{
...
@@ -651,7 +645,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -651,7 +645,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_m0_m1_m2_n_grid_desc
,
c_m0_m1_m2_n_grid_desc
,
c_grid_buf
,
c_grid_buf
,
c_m0_m1_m2_n_grid_tensor_
i
te
rator
_hacks
);
c_m0_m1_m2_n_grid_tensor_
s
te
p
_hacks
);
};
};
auto
nrepeat_minus_copy
=
[
&
](
auto
c_thread_idx_
)
{
auto
nrepeat_minus_copy
=
[
&
](
auto
c_thread_idx_
)
{
...
@@ -664,7 +658,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -664,7 +658,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_m0_m1_m2_n_grid_desc
,
c_m0_m1_m2_n_grid_desc
,
c_grid_buf
,
c_grid_buf
,
c_m0_m1_m2_n_grid_tensor_
i
te
rator
_hacks
);
c_m0_m1_m2_n_grid_tensor_
s
te
p
_hacks
);
};
};
static_assert
((
MRepeat
==
4
&&
NRepeat
==
4
)
or
(
MRepeat
==
4
&&
NRepeat
==
2
)
or
static_assert
((
MRepeat
==
4
&&
NRepeat
==
4
)
or
(
MRepeat
==
4
&&
NRepeat
==
2
)
or
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp
View file @
4e57b30a
...
@@ -11,7 +11,7 @@ namespace ck {
...
@@ -11,7 +11,7 @@ namespace ck {
// 1. Desc is known at compile-time
// 1. Desc is known at compile-time
// 2. Buffer is StaticBuffer
// 2. Buffer is StaticBuffer
// 3. OriginIdx is known at compile-time
// 3. OriginIdx is known at compile-time
// 4. use #-
i
te
rator
// 4. use #-
s
te
p
template
<
typename
Data
,
template
<
typename
Data
,
typename
Desc
,
typename
Desc
,
typename
SliceLengths
,
typename
SliceLengths
,
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
View file @
4e57b30a
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
View file @
4e57b30a
...
@@ -41,8 +41,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -41,8 +41,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
using
SrcCoord
=
decltype
(
make_tensor_coordinate
(
SrcDesc
{},
Index
{}));
using
SrcCoord
=
decltype
(
make_tensor_coordinate
(
SrcDesc
{},
Index
{}));
using
DstCoord
=
decltype
(
make_tensor_coordinate
(
DstDesc
{},
Index
{}));
using
DstCoord
=
decltype
(
make_tensor_coordinate
(
DstDesc
{},
Index
{}));
using
SrcCoord
I
te
rator
=
decltype
(
make_tensor_coordinate_
i
te
rator
(
SrcDesc
{},
Index
{}));
using
SrcCoord
S
te
p
=
decltype
(
make_tensor_coordinate_
s
te
p
(
SrcDesc
{},
Index
{}));
using
DstCoord
I
te
rator
=
decltype
(
make_tensor_coordinate_
i
te
rator
(
DstDesc
{},
Index
{}));
using
DstCoord
S
te
p
=
decltype
(
make_tensor_coordinate_
s
te
p
(
DstDesc
{},
Index
{}));
__device__
constexpr
ThreadwiseTensorSliceTransfer_v3r1
(
const
SrcDesc
&
src_desc
,
__device__
constexpr
ThreadwiseTensorSliceTransfer_v3r1
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin
,
const
Index
&
src_slice_origin
,
...
@@ -72,10 +72,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -72,10 +72,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
dst_coord_
=
make_tensor_coordinate
(
dst_desc
,
dst_slice_origin_idx
);
dst_coord_
=
make_tensor_coordinate
(
dst_desc
,
dst_slice_origin_idx
);
}
}
template
<
typename
SrcBuffer
,
typename
SrcIteratorHacks
>
template
<
typename
SrcBuffer
,
typename
SrcStepHacks
>
__device__
void
RunRead
(
const
SrcDesc
&
src_desc
,
__device__
void
const
SrcBuffer
&
src_buf
,
RunRead
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
const
SrcStepHacks
&
src_step_hacks
)
const
SrcIteratorHacks
&
src_iterator_hacks
)
{
{
static_assert
(
SrcBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Global
or
static_assert
(
SrcBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Global
or
SrcBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Lds
,
SrcBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Lds
,
...
@@ -108,31 +107,31 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -108,31 +107,31 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr
auto
ordered_src_access_lengths
=
constexpr
auto
ordered_src_access_lengths
=
container_reorder_given_new2old
(
src_access_lengths
,
src_dim_access_order
);
container_reorder_given_new2old
(
src_access_lengths
,
src_dim_access_order
);
// make forward
i
te
rator
s
// make forward
s
te
p
s
const
auto
src_forward_
i
te
rator
s
=
generate_tuple
(
const
auto
src_forward_
s
te
p
s
=
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
Index
forward_step
;
Index
forward_step
_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
forward_step
(
j
)
=
(
i
.
value
==
j
.
value
)
?
src_vector_tensor_lengths
[
i
]
:
0
;
forward_step
_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
src_vector_tensor_lengths
[
i
]
:
0
;
});
});
return
make_tensor_coordinate_
i
te
rator
(
return
make_tensor_coordinate_
s
te
p
(
src_desc
,
forward_step
,
src_
i
te
rator
_hacks
[
I0
][
i
]);
src_desc
,
forward_step
_idx
,
src_
s
te
p
_hacks
[
I0
][
i
]);
},
},
Number
<
nDim
>
{});
Number
<
nDim
>
{});
// make backward
i
te
rator
s
// make backward
s
te
p
s
const
auto
src_backward_
i
te
rator
s
=
generate_tuple
(
const
auto
src_backward_
s
te
p
s
=
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
Index
backward_step
;
Index
backward_step
_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
backward_step
(
j
)
=
(
i
.
value
==
j
.
value
)
?
-
src_vector_tensor_lengths
[
i
]
:
0
;
backward_step
_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
-
src_vector_tensor_lengths
[
i
]
:
0
;
});
});
return
make_tensor_coordinate_
i
te
rator
(
return
make_tensor_coordinate_
s
te
p
(
src_desc
,
backward_step
,
src_
i
te
rator
_hacks
[
I1
][
i
]);
src_desc
,
backward_step
_idx
,
src_
s
te
p
_hacks
[
I1
][
i
]);
},
},
Number
<
nDim
>
{});
Number
<
nDim
>
{});
...
@@ -220,12 +219,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -220,12 +219,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
if
constexpr
(
forward_sweep
[
i
])
if
constexpr
(
forward_sweep
[
i
])
{
{
move_tensor_coordinate
(
move_tensor_coordinate
(
src_desc
,
src_coord_
,
src_forward_
i
te
rator
s
[
src_dim_access_order
[
i
]]);
src_desc
,
src_coord_
,
src_forward_
s
te
p
s
[
src_dim_access_order
[
i
]]);
}
}
else
else
{
{
move_tensor_coordinate
(
move_tensor_coordinate
(
src_desc
,
src_coord_
,
src_backward_
i
te
rator
s
[
src_dim_access_order
[
i
]]);
src_desc
,
src_coord_
,
src_backward_
s
te
p
s
[
src_dim_access_order
[
i
]]);
}
}
}
}
});
});
...
@@ -234,17 +233,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -234,17 +233,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// move src coordinate back to slice origin (or not)
// move src coordinate back to slice origin (or not)
if
constexpr
(
SrcResetCoordinateAfterRun
)
if
constexpr
(
SrcResetCoordinateAfterRun
)
{
{
const
auto
src_reset_
i
te
rator
=
const
auto
src_reset_
s
te
p
=
make_tensor_coordinate_
i
te
rator
(
src_desc
,
GetSrcCoordinateResetStep
());
make_tensor_coordinate_
s
te
p
(
src_desc
,
GetSrcCoordinateResetStep
());
move_tensor_coordinate
(
src_desc
,
src_coord_
,
src_reset_
i
te
rator
);
move_tensor_coordinate
(
src_desc
,
src_coord_
,
src_reset_
s
te
p
);
}
}
}
}
template
<
typename
DstBuffer
,
typename
DstIteratorHacks
>
template
<
typename
DstBuffer
,
typename
DstStepHacks
>
__device__
void
RunWrite
(
const
DstDesc
&
dst_desc
,
__device__
void
DstBuffer
&
dst_buf
,
RunWrite
(
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
,
const
DstStepHacks
&
dst_step_hacks
)
const
DstIteratorHacks
&
dst_iterator_hacks
)
{
{
static_assert
(
DstBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Global
or
static_assert
(
DstBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Global
or
DstBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Lds
,
DstBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Lds
,
...
@@ -277,35 +275,31 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -277,35 +275,31 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr
auto
ordered_dst_access_lengths
=
constexpr
auto
ordered_dst_access_lengths
=
container_reorder_given_new2old
(
dst_access_lengths
,
dst_dim_access_order
);
container_reorder_given_new2old
(
dst_access_lengths
,
dst_dim_access_order
);
// make forward
i
te
rator
s
// make forward
s
te
p
s
const
auto
dst_forward_
i
te
rator
s
=
generate_tuple
(
const
auto
dst_forward_
s
te
p
s
=
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
Index
forward_step
;
Index
forward_step
_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
forward_step
(
j
)
=
(
i
.
value
==
j
.
value
)
?
dst_vector_tensor_lengths
[
i
]
:
0
;
forward_step
_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
dst_vector_tensor_lengths
[
i
]
:
0
;
});
});
const
auto
forward_iterator
=
make_tensor_coordinate_iterator
(
return
make_tensor_coordinate_step
(
dst_desc
,
forward_step
,
dst_iterator_hacks
[
I0
][
i
]);
dst_desc
,
forward_step_idx
,
dst_step_hacks
[
I0
][
i
]);
return
forward_iterator
;
},
},
Number
<
nDim
>
{});
Number
<
nDim
>
{});
// make backward
i
te
rator
s
// make backward
s
te
p
s
const
auto
dst_backward_
i
te
rator
s
=
generate_tuple
(
const
auto
dst_backward_
s
te
p
s
=
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
Index
backward_step
;
Index
backward_step
_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
backward_step
(
j
)
=
(
i
.
value
==
j
.
value
)
?
-
dst_vector_tensor_lengths
[
i
]
:
0
;
backward_step
_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
-
dst_vector_tensor_lengths
[
i
]
:
0
;
});
});
const
auto
backward_iterator
=
make_tensor_coordinate_iterator
(
return
make_tensor_coordinate_step
(
dst_desc
,
backward_step
,
dst_iterator_hacks
[
I1
][
i
]);
dst_desc
,
backward_step_idx
,
dst_step_hacks
[
I1
][
i
]);
return
backward_iterator
;
},
},
Number
<
nDim
>
{});
Number
<
nDim
>
{});
...
@@ -395,12 +389,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -395,12 +389,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
if
constexpr
(
forward_sweep
[
i
])
if
constexpr
(
forward_sweep
[
i
])
{
{
move_tensor_coordinate
(
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
dst_forward_
i
te
rator
s
[
dst_dim_access_order
[
i
]]);
dst_desc
,
dst_coord_
,
dst_forward_
s
te
p
s
[
dst_dim_access_order
[
i
]]);
}
}
else
else
{
{
move_tensor_coordinate
(
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
dst_backward_
i
te
rator
s
[
dst_dim_access_order
[
i
]]);
dst_desc
,
dst_coord_
,
dst_backward_
s
te
p
s
[
dst_dim_access_order
[
i
]]);
}
}
}
}
});
});
...
@@ -409,10 +403,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -409,10 +403,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// move dst coordinate back to slice origin (or not)
// move dst coordinate back to slice origin (or not)
if
constexpr
(
DstResetCoordinateAfterRun
)
if
constexpr
(
DstResetCoordinateAfterRun
)
{
{
const
auto
dst_reset_
i
te
rator
=
const
auto
dst_reset_
s
te
p
=
make_tensor_coordinate_
i
te
rator
(
dst_desc
,
GetDstCoordinateResetStep
());
make_tensor_coordinate_
s
te
p
(
dst_desc
,
GetDstCoordinateResetStep
());
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
dst_reset_
i
te
rator
);
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
dst_reset_
s
te
p
);
}
}
}
}
...
@@ -423,11 +417,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -423,11 +417,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr
auto
zeros
=
typename
uniform_sequence_gen
<
ntransform_src
,
0
>::
type
{};
constexpr
auto
zeros
=
typename
uniform_sequence_gen
<
ntransform_src
,
0
>::
type
{};
constexpr
auto
src_
i
te
rator
_hacks
=
constexpr
auto
src_
s
te
p
_hacks
=
make_tuple
(
generate_tuple
([
&
](
auto
)
{
return
zeros
;
},
Number
<
nDim
>
{}),
make_tuple
(
generate_tuple
([
&
](
auto
)
{
return
zeros
;
},
Number
<
nDim
>
{}),
generate_tuple
([
&
](
auto
)
{
return
zeros
;
},
Number
<
nDim
>
{}));
generate_tuple
([
&
](
auto
)
{
return
zeros
;
},
Number
<
nDim
>
{}));
RunRead
(
src_desc
,
src_buf
,
src_
i
te
rator
_hacks
);
RunRead
(
src_desc
,
src_buf
,
src_
s
te
p
_hacks
);
}
}
template
<
typename
DstBuffer
>
template
<
typename
DstBuffer
>
...
@@ -437,11 +431,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -437,11 +431,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr
auto
zeros
=
typename
uniform_sequence_gen
<
ntransform_dst
,
0
>::
type
{};
constexpr
auto
zeros
=
typename
uniform_sequence_gen
<
ntransform_dst
,
0
>::
type
{};
constexpr
auto
dst_
i
te
rator
_hacks
=
constexpr
auto
dst_
s
te
p
_hacks
=
make_tuple
(
generate_tuple
([
&
](
auto
)
{
return
zeros
;
},
Number
<
nDim
>
{}),
make_tuple
(
generate_tuple
([
&
](
auto
)
{
return
zeros
;
},
Number
<
nDim
>
{}),
generate_tuple
([
&
](
auto
)
{
return
zeros
;
},
Number
<
nDim
>
{}));
generate_tuple
([
&
](
auto
)
{
return
zeros
;
},
Number
<
nDim
>
{}));
RunWrite
(
dst_desc
,
dst_buf
,
dst_
i
te
rator
_hacks
);
RunWrite
(
dst_desc
,
dst_buf
,
dst_
s
te
p
_hacks
);
}
}
__device__
static
constexpr
auto
GetSrcCoordinateResetStep
()
__device__
static
constexpr
auto
GetSrcCoordinateResetStep
()
...
@@ -564,17 +558,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -564,17 +558,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
:
src_slice_origin_step_idx
+
GetSrcCoordinateResetStep
();
:
src_slice_origin_step_idx
+
GetSrcCoordinateResetStep
();
// is it OK to construct a new step every time?
// is it OK to construct a new step every time?
const
auto
adjusted_step
=
make_tensor_coordinate_
i
te
rator
(
src_desc
,
adjusted_step_idx
);
const
auto
adjusted_step
=
make_tensor_coordinate_
s
te
p
(
src_desc
,
adjusted_step_idx
);
move_tensor_coordinate
(
src_desc
,
src_coord_
,
adjusted_step
);
move_tensor_coordinate
(
src_desc
,
src_coord_
,
adjusted_step
);
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
template
<
typename
SrcMoveSliceWindow
I
te
rator
Hack
>
template
<
typename
SrcMoveSliceWindow
S
te
p
Hack
>
__device__
void
__device__
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_step_idx
,
const
Index
&
src_slice_origin_step_idx
,
const
SrcMoveSliceWindow
I
te
rator
Hack
&
src_move_slice_window_
i
te
rator
_hack
)
const
SrcMoveSliceWindow
S
te
p
Hack
&
src_move_slice_window_
s
te
p
_hack
)
{
{
// if src coord was not reset by RunRead(), then need to adjust the step here
// if src coord was not reset by RunRead(), then need to adjust the step here
const
auto
adjusted_step_idx
=
const
auto
adjusted_step_idx
=
...
@@ -582,8 +576,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -582,8 +576,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
:
src_slice_origin_step_idx
+
GetSrcCoordinateResetStep
();
:
src_slice_origin_step_idx
+
GetSrcCoordinateResetStep
();
// is it OK to construct a new step every time?
// is it OK to construct a new step every time?
const
auto
adjusted_step
=
make_tensor_coordinate_
i
te
rator
(
const
auto
adjusted_step
=
make_tensor_coordinate_
s
te
p
(
src_desc
,
adjusted_step_idx
,
src_move_slice_window_
i
te
rator
_hack
);
src_desc
,
adjusted_step_idx
,
src_move_slice_window_
s
te
p
_hack
);
move_tensor_coordinate
(
src_desc
,
src_coord_
,
adjusted_step
);
move_tensor_coordinate
(
src_desc
,
src_coord_
,
adjusted_step
);
}
}
...
@@ -597,7 +591,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -597,7 +591,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
:
dst_slice_origin_step_idx
+
GetDstCoordinateResetStep
();
:
dst_slice_origin_step_idx
+
GetDstCoordinateResetStep
();
// is it OK to construct a new step every time?
// is it OK to construct a new step every time?
const
auto
adjusted_step
=
make_tensor_coordinate_
i
te
rator
(
dst_desc
,
adjusted_step_idx
);
const
auto
adjusted_step
=
make_tensor_coordinate_
s
te
p
(
dst_desc
,
adjusted_step_idx
);
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
adjusted_step
);
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
adjusted_step
);
}
}
...
@@ -620,7 +614,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -620,7 +614,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// 2. SrcBuffer is DynamicBuffer
// 2. SrcBuffer is DynamicBuffer
// 3. src_ref_idx is known at run-time
// 3. src_ref_idx is known at run-time
// 4. SrcRefToOriginDisplacement is known at compile-time
// 4. SrcRefToOriginDisplacement is known at compile-time
// 5. use #-
i
te
rator
// 5. use #-
s
te
p
// 2. dst:
// 2. dst:
// 1. DstDesc is known at compile-time
// 1. DstDesc is known at compile-time
// 2. DstBuffer is StaticBuffer
// 2. DstBuffer is StaticBuffer
...
@@ -649,7 +643,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
...
@@ -649,7 +643,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
using
SrcCoord
=
decltype
(
make_tensor_coordinate
(
SrcDesc
{},
Index
{}));
using
SrcCoord
=
decltype
(
make_tensor_coordinate
(
SrcDesc
{},
Index
{}));
using
SrcCoord
I
te
rator
=
decltype
(
make_tensor_coordinate_
i
te
rator
(
SrcDesc
{},
Index
{}));
using
SrcCoord
S
te
p
=
decltype
(
make_tensor_coordinate_
s
te
p
(
SrcDesc
{},
Index
{}));
__device__
constexpr
ThreadwiseTensorSliceTransfer_v4r1
(
const
Index
&
src_ref_idx
)
__device__
constexpr
ThreadwiseTensorSliceTransfer_v4r1
(
const
Index
&
src_ref_idx
)
:
src_ref_coord_
(
make_tensor_coordinate
(
SrcDesc
{},
src_ref_idx
))
:
src_ref_coord_
(
make_tensor_coordinate
(
SrcDesc
{},
src_ref_idx
))
...
@@ -732,12 +726,12 @@ struct ThreadwiseTensorSliceTransfer_v4r1
...
@@ -732,12 +726,12 @@ struct ThreadwiseTensorSliceTransfer_v4r1
constexpr
auto
src_ref_to_data_disp_idx
=
constexpr
auto
src_ref_to_data_disp_idx
=
src_ref_to_origin_disp_idx
+
data_to_origin_disp_idx
;
src_ref_to_origin_disp_idx
+
data_to_origin_disp_idx
;
constexpr
auto
src_ref_to_data_disp_coord_
i
te
rator
=
constexpr
auto
src_ref_to_data_disp_coord_
s
te
p
=
make_tensor_coordinate_
i
te
rator
(
src_desc
,
src_ref_to_data_disp_idx
);
make_tensor_coordinate_
s
te
p
(
src_desc
,
src_ref_to_data_disp_idx
);
auto
src_data_coord
=
src_ref_coord_
;
auto
src_data_coord
=
src_ref_coord_
;
move_tensor_coordinate
(
src_desc
,
src_data_coord
,
src_ref_to_data_disp_coord_
i
te
rator
);
move_tensor_coordinate
(
src_desc
,
src_data_coord
,
src_ref_to_data_disp_coord_
s
te
p
);
vector_type_maker_t
<
SrcData
,
src_vector_desc
.
GetElementSpaceSize
()
>
src_vector
;
vector_type_maker_t
<
SrcData
,
src_vector_desc
.
GetElementSpaceSize
()
>
src_vector
;
...
@@ -773,7 +767,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
...
@@ -773,7 +767,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
constexpr
auto
src_desc
=
SrcDesc
{};
constexpr
auto
src_desc
=
SrcDesc
{};
const
auto
src_slice_move_step_iter
=
const
auto
src_slice_move_step_iter
=
make_tensor_coordinate_
i
te
rator
(
src_desc
,
to_multi_index
(
src_slice_move_step_idx
));
make_tensor_coordinate_
s
te
p
(
src_desc
,
to_multi_index
(
src_slice_move_step_idx
));
move_tensor_coordinate
(
SrcDesc
{},
src_ref_coord_
,
src_slice_move_step_iter
);
move_tensor_coordinate
(
SrcDesc
{},
src_ref_coord_
,
src_slice_move_step_iter
);
}
}
...
...
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp
View file @
4e57b30a
...
@@ -113,16 +113,16 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
...
@@ -113,16 +113,16 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
using
BKNGridDesc
=
decltype
(
b_k_n_grid_desc
);
using
BKNGridDesc
=
decltype
(
b_k_n_grid_desc
);
using
CMNGridDesc
=
decltype
(
c_m_n_grid_desc
);
using
CMNGridDesc
=
decltype
(
c_m_n_grid_desc
);
using
AGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
using
AGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
0
,
0
,
0
>
{})));
using
BGrid
I
te
rator
Hacks
=
using
BGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{}),
...
@@ -130,21 +130,21 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
...
@@ -130,21 +130,21 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{})));
using
CGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
using
CGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{})));
using
AGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
;
using
AGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
;
using
BGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
;
using
BGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
;
using
GridwiseGemm
=
using
GridwiseGemm
=
GridwiseGemmDlops_km_kn_mn_v1r2
<
BlockSize
,
GridwiseGemmDlops_km_kn_mn_v1r2
<
BlockSize
,
...
@@ -184,11 +184,11 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
...
@@ -184,11 +184,11 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
AGrid
I
te
rator
Hacks
,
AGrid
S
te
p
Hacks
,
BGrid
I
te
rator
Hacks
,
BGrid
S
te
p
Hacks
,
CGrid
I
te
rator
Hacks
,
CGrid
S
te
p
Hacks
,
AGridMoveSliceWindow
I
te
rator
Hacks
,
AGridMoveSliceWindow
S
te
p
Hacks
,
BGridMoveSliceWindow
I
te
rator
Hacks
>
;
BGridMoveSliceWindow
S
te
p
Hacks
>
;
auto
a_k_m0_m1_grid_desc
=
GridwiseGemm
::
MakeAKM0M1GridDescriptor
(
a_k_m_grid_desc
);
auto
a_k_m0_m1_grid_desc
=
GridwiseGemm
::
MakeAKM0M1GridDescriptor
(
a_k_m_grid_desc
);
auto
b_k_n0_n1_grid_desc
=
GridwiseGemm
::
MakeBKN0N1GridDescriptor
(
b_k_n_grid_desc
);
auto
b_k_n0_n1_grid_desc
=
GridwiseGemm
::
MakeBKN0N1GridDescriptor
(
b_k_n_grid_desc
);
...
@@ -249,16 +249,16 @@ extern "C" __global__ void
...
@@ -249,16 +249,16 @@ extern "C" __global__ void
using
BKNGridDesc
=
decltype
(
b_k_n_grid_desc
);
using
BKNGridDesc
=
decltype
(
b_k_n_grid_desc
);
using
CMNGridDesc
=
decltype
(
c_m_n_grid_desc
);
using
CMNGridDesc
=
decltype
(
c_m_n_grid_desc
);
using
AGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
using
AGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
0
,
0
,
0
>
{})));
using
BGrid
I
te
rator
Hacks
=
using
BGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{}),
...
@@ -266,21 +266,21 @@ extern "C" __global__ void
...
@@ -266,21 +266,21 @@ extern "C" __global__ void
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{})));
using
CGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
using
CGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{})));
using
AGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
;
using
AGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
;
using
BGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
;
using
BGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
;
using
GridwiseGemm
=
using
GridwiseGemm
=
GridwiseGemmDlops_km_kn_mn_v1r2
<
BlockSize
,
GridwiseGemmDlops_km_kn_mn_v1r2
<
BlockSize
,
...
@@ -320,11 +320,11 @@ extern "C" __global__ void
...
@@ -320,11 +320,11 @@ extern "C" __global__ void
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
AGrid
I
te
rator
Hacks
,
AGrid
S
te
p
Hacks
,
BGrid
I
te
rator
Hacks
,
BGrid
S
te
p
Hacks
,
CGrid
I
te
rator
Hacks
,
CGrid
S
te
p
Hacks
,
AGridMoveSliceWindow
I
te
rator
Hacks
,
AGridMoveSliceWindow
S
te
p
Hacks
,
BGridMoveSliceWindow
I
te
rator
Hacks
>
;
BGridMoveSliceWindow
S
te
p
Hacks
>
;
constexpr
auto
a_k_m0_m1_grid_desc_tmp
=
constexpr
auto
a_k_m0_m1_grid_desc_tmp
=
GridwiseGemm
::
MakeAKM0M1GridDescriptor
(
a_k_m_grid_desc
);
GridwiseGemm
::
MakeAKM0M1GridDescriptor
(
a_k_m_grid_desc
);
...
...
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp
View file @
4e57b30a
...
@@ -110,12 +110,12 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
...
@@ -110,12 +110,12 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
using
BK0NK1GridDesc
=
decltype
(
b_k0_n_k1_grid_desc
);
using
BK0NK1GridDesc
=
decltype
(
b_k0_n_k1_grid_desc
);
using
CMNGridDesc
=
decltype
(
c_m_n_grid_desc
);
using
CMNGridDesc
=
decltype
(
c_m_n_grid_desc
);
using
AGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
using
AGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{})));
using
BGrid
I
te
rator
Hacks
=
using
BGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
...
@@ -123,25 +123,25 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
...
@@ -123,25 +123,25 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{})));
using
CGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
using
CGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{})));
using
AGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
;
using
AGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
;
using
BGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
;
using
BGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
;
using
GridwiseGemm
=
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
...
@@ -179,11 +179,11 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
...
@@ -179,11 +179,11 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
AGrid
I
te
rator
Hacks
,
AGrid
S
te
p
Hacks
,
BGrid
I
te
rator
Hacks
,
BGrid
S
te
p
Hacks
,
CGrid
I
te
rator
Hacks
,
CGrid
S
te
p
Hacks
,
AGridMoveSliceWindow
I
te
rator
Hacks
,
AGridMoveSliceWindow
S
te
p
Hacks
,
BGridMoveSliceWindow
I
te
rator
Hacks
,
BGridMoveSliceWindow
S
te
p
Hacks
,
false
>
;
false
>
;
auto
c_m0_m1_m2_n_grid_desc
=
GridwiseGemm
::
MakeCM0M1M2NGridDescriptor
(
c_m_n_grid_desc
);
auto
c_m0_m1_m2_n_grid_desc
=
GridwiseGemm
::
MakeCM0M1M2NGridDescriptor
(
c_m_n_grid_desc
);
...
@@ -243,12 +243,12 @@ extern "C" __global__ void
...
@@ -243,12 +243,12 @@ extern "C" __global__ void
constexpr
auto
b_k0_n_k1_grid_desc_tmp
=
descs
[
I1
];
constexpr
auto
b_k0_n_k1_grid_desc_tmp
=
descs
[
I1
];
constexpr
auto
c_m_n_grid_desc
=
descs
[
I2
];
constexpr
auto
c_m_n_grid_desc
=
descs
[
I2
];
using
AGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
using
AGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{})));
using
BGrid
I
te
rator
Hacks
=
using
BGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
...
@@ -256,25 +256,25 @@ extern "C" __global__ void
...
@@ -256,25 +256,25 @@ extern "C" __global__ void
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{})));
using
CGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
using
CGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{})));
using
AGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
;
using
AGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
;
using
BGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
;
using
BGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
;
using
AK0MK1GridDesc
=
decltype
(
a_k0_m_k1_grid_desc_tmp
);
using
AK0MK1GridDesc
=
decltype
(
a_k0_m_k1_grid_desc_tmp
);
using
BK0NK1GridDesc
=
decltype
(
b_k0_n_k1_grid_desc_tmp
);
using
BK0NK1GridDesc
=
decltype
(
b_k0_n_k1_grid_desc_tmp
);
...
@@ -316,11 +316,11 @@ extern "C" __global__ void
...
@@ -316,11 +316,11 @@ extern "C" __global__ void
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
AGrid
I
te
rator
Hacks
,
AGrid
S
te
p
Hacks
,
BGrid
I
te
rator
Hacks
,
BGrid
S
te
p
Hacks
,
CGrid
I
te
rator
Hacks
,
CGrid
S
te
p
Hacks
,
AGridMoveSliceWindow
I
te
rator
Hacks
,
AGridMoveSliceWindow
S
te
p
Hacks
,
BGridMoveSliceWindow
I
te
rator
Hacks
,
BGridMoveSliceWindow
S
te
p
Hacks
,
false
>
;
false
>
;
constexpr
auto
c_m0_m1_m2_n_grid_desc_tmp
=
constexpr
auto
c_m0_m1_m2_n_grid_desc_tmp
=
...
...
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp
View file @
4e57b30a
...
@@ -110,12 +110,12 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
...
@@ -110,12 +110,12 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
using
BK0NK1GridDesc
=
decltype
(
b_k0_n_k1_grid_desc
);
using
BK0NK1GridDesc
=
decltype
(
b_k0_n_k1_grid_desc
);
using
CMNGridDesc
=
decltype
(
c_m_n_grid_desc
);
using
CMNGridDesc
=
decltype
(
c_m_n_grid_desc
);
using
BGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
using
BGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{})));
using
AGrid
I
te
rator
Hacks
=
using
AGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
...
@@ -123,25 +123,25 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
...
@@ -123,25 +123,25 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{})));
using
CGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
using
CGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{})));
using
AGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
;
using
AGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
;
using
BGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
;
using
BGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
;
using
GridwiseGemm
=
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
...
@@ -179,11 +179,11 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
...
@@ -179,11 +179,11 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
AGrid
I
te
rator
Hacks
,
AGrid
S
te
p
Hacks
,
BGrid
I
te
rator
Hacks
,
BGrid
S
te
p
Hacks
,
CGrid
I
te
rator
Hacks
,
CGrid
S
te
p
Hacks
,
AGridMoveSliceWindow
I
te
rator
Hacks
,
AGridMoveSliceWindow
S
te
p
Hacks
,
BGridMoveSliceWindow
I
te
rator
Hacks
,
BGridMoveSliceWindow
S
te
p
Hacks
,
false
>
;
false
>
;
auto
c_m0_m1_m2_n_grid_desc
=
GridwiseGemm
::
MakeCM0M1M2NGridDescriptor
(
c_m_n_grid_desc
);
auto
c_m0_m1_m2_n_grid_desc
=
GridwiseGemm
::
MakeCM0M1M2NGridDescriptor
(
c_m_n_grid_desc
);
...
@@ -247,12 +247,12 @@ extern "C" __global__ void
...
@@ -247,12 +247,12 @@ extern "C" __global__ void
using
BK0NK1GridDesc
=
decltype
(
b_k0_n_k1_grid_desc_tmp
);
using
BK0NK1GridDesc
=
decltype
(
b_k0_n_k1_grid_desc_tmp
);
using
CMNGridDesc
=
decltype
(
c_m_n_grid_desc
);
using
CMNGridDesc
=
decltype
(
c_m_n_grid_desc
);
using
BGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
using
BGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{})));
using
AGrid
I
te
rator
Hacks
=
using
AGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
...
@@ -260,25 +260,25 @@ extern "C" __global__ void
...
@@ -260,25 +260,25 @@ extern "C" __global__ void
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{})));
using
CGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
using
CGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{})));
using
AGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
;
using
AGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
;
using
BGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
;
using
BGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
;
using
GridwiseGemm
=
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
...
@@ -316,11 +316,11 @@ extern "C" __global__ void
...
@@ -316,11 +316,11 @@ extern "C" __global__ void
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
AGrid
I
te
rator
Hacks
,
AGrid
S
te
p
Hacks
,
BGrid
I
te
rator
Hacks
,
BGrid
S
te
p
Hacks
,
CGrid
I
te
rator
Hacks
,
CGrid
S
te
p
Hacks
,
AGridMoveSliceWindow
I
te
rator
Hacks
,
AGridMoveSliceWindow
S
te
p
Hacks
,
BGridMoveSliceWindow
I
te
rator
Hacks
,
BGridMoveSliceWindow
S
te
p
Hacks
,
false
>
;
false
>
;
constexpr
auto
c_m0_m1_m2_n_grid_desc_tmp
=
constexpr
auto
c_m0_m1_m2_n_grid_desc_tmp
=
GridwiseGemm
::
MakeCM0M1M2NGridDescriptor
(
c_m_n_grid_desc
);
GridwiseGemm
::
MakeCM0M1M2NGridDescriptor
(
c_m_n_grid_desc
);
...
...
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp
View file @
4e57b30a
...
@@ -111,7 +111,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
...
@@ -111,7 +111,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
using
BGridDesc_GK0_GN0_GN1_GK1
=
decltype
(
b_grid_desc_gk0_gn0_gn1_gk1
);
using
BGridDesc_GK0_GN0_GN1_GK1
=
decltype
(
b_grid_desc_gk0_gn0_gn1_gk1
);
using
CGridDesc_GM0_GM1_GN0_GN1
=
decltype
(
c_grid_desc_gm0_gm1_gn0_gn1
);
using
CGridDesc_GM0_GM1_GN0_GN1
=
decltype
(
c_grid_desc_gm0_gm1_gn0_gn1
);
using
AGrid
I
te
rator
Hacks
=
using
AGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GM10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GM10
...
@@ -123,7 +123,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
...
@@ -123,7 +123,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GM11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GM11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{})));
// 4-: GK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{})));
// 4-: GK1
using
BGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
using
BGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GN10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GN10
...
@@ -135,7 +135,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
...
@@ -135,7 +135,7 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GN11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GN11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{})));
// 4-: GK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{})));
// 4-: GK1
using
CGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
using
CGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GM10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GM10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: BM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: BM0
...
@@ -151,9 +151,9 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
...
@@ -151,9 +151,9 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{},
// 4-: BN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{},
// 4-: BN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{})));
// 5-: GN1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{})));
// 5-: GN1
using
AGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
;
using
AGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
;
using
BGridMoveSliceWindow
I
te
rator
Hacks
=
using
BGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
;
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
;
using
GridwiseContraction
=
using
GridwiseContraction
=
...
@@ -191,11 +191,11 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
...
@@ -191,11 +191,11 @@ convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
AGrid
I
te
rator
Hacks
,
AGrid
S
te
p
Hacks
,
BGrid
I
te
rator
Hacks
,
BGrid
S
te
p
Hacks
,
CGrid
I
te
rator
Hacks
,
CGrid
S
te
p
Hacks
,
AGridMoveSliceWindow
I
te
rator
Hacks
,
AGridMoveSliceWindow
S
te
p
Hacks
,
BGridMoveSliceWindow
I
te
rator
Hacks
>
;
BGridMoveSliceWindow
S
te
p
Hacks
>
;
if
(
get_block_1d_id
()
==
0
&&
get_thread_local_1d_id
()
==
0
)
if
(
get_block_1d_id
()
==
0
&&
get_thread_local_1d_id
()
==
0
)
{
{
...
@@ -254,7 +254,7 @@ extern "C" __global__ void
...
@@ -254,7 +254,7 @@ extern "C" __global__ void
using
BGridDesc_GK0_GN0_GN1_GK1
=
decltype
(
b_grid_desc_gk0_gn0_gn1_gk1
);
using
BGridDesc_GK0_GN0_GN1_GK1
=
decltype
(
b_grid_desc_gk0_gn0_gn1_gk1
);
using
CGridDesc_GM0_GM1_GN0_GN1
=
decltype
(
c_grid_desc_gm0_gm1_gn0_gn1
);
using
CGridDesc_GM0_GM1_GN0_GN1
=
decltype
(
c_grid_desc_gm0_gm1_gn0_gn1
);
using
AGrid
I
te
rator
Hacks
=
using
AGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GM10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GM10
...
@@ -266,7 +266,7 @@ extern "C" __global__ void
...
@@ -266,7 +266,7 @@ extern "C" __global__ void
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GM11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GM11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{})));
// 4-: GK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{})));
// 4-: GK1
using
BGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
using
BGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GN10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GN10
...
@@ -278,7 +278,7 @@ extern "C" __global__ void
...
@@ -278,7 +278,7 @@ extern "C" __global__ void
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GN11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GN11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{})));
// 4-: GK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{})));
// 4-: GK1
using
CGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
using
CGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GM10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GM10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: BM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: BM0
...
@@ -294,9 +294,9 @@ extern "C" __global__ void
...
@@ -294,9 +294,9 @@ extern "C" __global__ void
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{},
// 4-: BN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{},
// 4-: BN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{})));
// 5-: GN1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{})));
// 5-: GN1
using
AGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
;
using
AGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
;
using
BGridMoveSliceWindow
I
te
rator
Hacks
=
using
BGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
;
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
;
using
GridwiseContraction
=
using
GridwiseContraction
=
...
@@ -334,11 +334,11 @@ extern "C" __global__ void
...
@@ -334,11 +334,11 @@ extern "C" __global__ void
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
AGrid
I
te
rator
Hacks
,
AGrid
S
te
p
Hacks
,
BGrid
I
te
rator
Hacks
,
BGrid
S
te
p
Hacks
,
CGrid
I
te
rator
Hacks
,
CGrid
S
te
p
Hacks
,
AGridMoveSliceWindow
I
te
rator
Hacks
,
AGridMoveSliceWindow
S
te
p
Hacks
,
BGridMoveSliceWindow
I
te
rator
Hacks
>
;
BGridMoveSliceWindow
S
te
p
Hacks
>
;
using
AGridDesc_GK0_GM0_GM10_GM11_GK1
=
using
AGridDesc_GK0_GM0_GM10_GM11_GK1
=
decltype
(
GridwiseContraction
::
MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1
(
decltype
(
GridwiseContraction
::
MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1
(
...
...
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp
View file @
4e57b30a
...
@@ -207,7 +207,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
...
@@ -207,7 +207,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
const
auto
in_gemmm_gemmn_grid_desc
=
descs
[
I2
];
const
auto
in_gemmm_gemmn_grid_desc
=
descs
[
I2
];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: gemmk0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: gemmk0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: gemmm
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: gemmm
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: gemmk1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: gemmk1
...
@@ -215,7 +215,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
...
@@ -215,7 +215,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: Gemmm
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: Gemmm
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: Gemmk1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: Gemmk1
constexpr
auto
out_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
=
make_tuple
(
constexpr
auto
out_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: gemmk0
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: gemmk0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
// 1+: gemmn
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
// 1+: gemmn
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: gemmk1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: gemmk1
...
@@ -223,7 +223,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
...
@@ -223,7 +223,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
// 1-: gemmn
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
// 1-: gemmn
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: gemmk1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: gemmk1
constexpr
auto
in_m0_m1_m2_n_grid_
i
te
rator
_hacks
=
make_tuple
(
constexpr
auto
in_m0_m1_m2_n_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: MRepeat
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: MRepeat
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 1+: NRepeat
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 1+: NRepeat
...
@@ -243,10 +243,10 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
...
@@ -243,10 +243,10 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{}));
// 7-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{}));
// 7-: N1
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{};
constexpr
auto
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
...
@@ -287,11 +287,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
...
@@ -287,11 +287,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
Sequence
<
1
,
3
,
7
,
0
,
2
,
4
,
5
,
6
>
,
Sequence
<
1
,
3
,
7
,
0
,
2
,
4
,
5
,
6
>
,
6
,
6
,
GemmCThreadTransferDstScalarPerVector
,
GemmCThreadTransferDstScalarPerVector
,
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
out_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
out_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
in_m0_m1_m2_n_grid_
i
te
rator
_hacks
),
decltype
(
in_m0_m1_m2_n_grid_
s
te
p
_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
decltype
(
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
false
// CAccessOrderMRepeatNRepeat
false
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
>
(
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
...
@@ -299,11 +299,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
...
@@ -299,11 +299,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
wei_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmm_gemmk1_grid_desc
,
out_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmk0_gemmn_gemmk1_grid_desc
,
in_gemmm_gemmn_grid_desc
,
in_gemmm_gemmn_grid_desc
,
wei_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
,
out_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
,
out_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
,
in_m0_m1_m2_n_grid_
i
te
rator
_hacks
,
in_m0_m1_m2_n_grid_
s
te
p
_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
nrepeat
);
nrepeat
);
{
{
...
...
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp
View file @
4e57b30a
...
@@ -179,7 +179,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
...
@@ -179,7 +179,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
const
auto
in_gemmm_gemmn_grid_desc
=
descs
[
I2
];
const
auto
in_gemmm_gemmn_grid_desc
=
descs
[
I2
];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
out_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
=
make_tuple
(
constexpr
auto
out_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: gemmk0
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: gemmk0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
// 1+: gemmm
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
// 1+: gemmm
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: gemmk1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: gemmk1
...
@@ -187,7 +187,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
...
@@ -187,7 +187,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
// 1-: gemmm
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
// 1-: gemmm
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: gemmk1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: gemmk1
constexpr
auto
wei_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: gemmk0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: gemmk0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: gemmn
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: gemmn
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: gemmk1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: gemmk1
...
@@ -195,7 +195,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
...
@@ -195,7 +195,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: Gemmn
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: Gemmn
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: Gemmk1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: Gemmk1
constexpr
auto
in_m0_m1_m2_n_grid_
i
te
rator
_hacks
=
make_tuple
(
constexpr
auto
in_m0_m1_m2_n_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
// 0+: MRepeat
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
// 0+: MRepeat
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: NRepeat
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: NRepeat
...
@@ -215,10 +215,10 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
...
@@ -215,10 +215,10 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{},
// 6-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{},
// 6-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N1
constexpr
auto
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
>
{};
constexpr
auto
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
...
@@ -263,11 +263,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
...
@@ -263,11 +263,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
#endif
#endif
7
,
7
,
GemmCThreadTransferDstScalarPerVector
,
GemmCThreadTransferDstScalarPerVector
,
decltype
(
out_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
out_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
in_m0_m1_m2_n_grid_
i
te
rator
_hacks
),
decltype
(
in_m0_m1_m2_n_grid_
s
te
p
_hacks
),
decltype
(
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
true
// CAccessOrderMRepeatNRepeat
true
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
>
(
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
...
@@ -275,11 +275,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
...
@@ -275,11 +275,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
out_gemmk0_gemmm_gemmk1_grid_desc
,
out_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
in_gemmm_gemmn_grid_desc
,
in_gemmm_gemmn_grid_desc
,
out_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
,
out_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
,
in_m0_m1_m2_n_grid_
i
te
rator
_hacks
,
in_m0_m1_m2_n_grid_
s
te
p
_hacks
,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
nrepeat
);
nrepeat
);
{
{
...
...
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp
View file @
4e57b30a
...
@@ -89,7 +89,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
...
@@ -89,7 +89,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
in_right_pads
);
in_right_pads
);
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
wei_gemmk_gemmm0_gemmn1_grid_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk_gemmm0_gemmn1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
...
@@ -99,7 +99,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
...
@@ -99,7 +99,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
constexpr
auto
in_gemmk_gemmn0_gemmn1_grid_
i
te
rator
_hacks
=
constexpr
auto
in_gemmk_gemmn0_gemmn1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{}),
...
@@ -107,7 +107,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
...
@@ -107,7 +107,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{}));
constexpr
auto
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
i
te
rator
_hacks
=
constexpr
auto
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
...
@@ -121,10 +121,10 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
...
@@ -121,10 +121,10 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
constexpr
auto
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
constexpr
auto
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
{};
const
auto
wei_gemmk_gemmm_grid_desc
=
descs
[
I0
];
const
auto
wei_gemmk_gemmm_grid_desc
=
descs
[
I0
];
...
@@ -171,22 +171,22 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
...
@@ -171,22 +171,22 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
Sequence
<
3
,
4
,
5
,
0
,
1
,
2
>
,
// CThreadTransferSrcDstAccessOrder
Sequence
<
3
,
4
,
5
,
0
,
1
,
2
>
,
// CThreadTransferSrcDstAccessOrder
5
,
// CThreadTransferSrcDstVectorDim
5
,
// CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_N11
,
GemmCThreadTransferDstScalarPerVector_N11
,
decltype
(
wei_gemmk_gemmm0_gemmn1_grid_
i
te
rator
_hacks
),
decltype
(
wei_gemmk_gemmm0_gemmn1_grid_
s
te
p
_hacks
),
decltype
(
in_gemmk_gemmn0_gemmn1_grid_
i
te
rator
_hacks
),
decltype
(
in_gemmk_gemmn0_gemmn1_grid_
s
te
p
_hacks
),
decltype
(
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
i
te
rator
_hacks
),
decltype
(
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
s
te
p
_hacks
),
decltype
(
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_
s
te
p
_hacks
),
decltype
(
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_
i
te
rator
_hacks
)
>
(
decltype
(
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_
s
te
p
_hacks
)
>
(
static_cast
<
TInWei
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
wei_gemmk_gemmm_grid_desc
,
wei_gemmk_gemmm_grid_desc
,
in_gemmk_gemmn_grid_desc
,
in_gemmk_gemmn_grid_desc
,
out_gemmm_gemmn_grid_desc
,
out_gemmm_gemmn_grid_desc
,
wei_gemmk_gemmm0_gemmn1_grid_
i
te
rator
_hacks
,
wei_gemmk_gemmm0_gemmn1_grid_
s
te
p
_hacks
,
in_gemmk_gemmn0_gemmn1_grid_
i
te
rator
_hacks
,
in_gemmk_gemmn0_gemmn1_grid_
s
te
p
_hacks
,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
i
te
rator
_hacks
,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
s
te
p
_hacks
,
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_
i
te
rator
_hacks
,
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_
s
te
p
_hacks
,
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_
i
te
rator
_hacks
,
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_
s
te
p
_hacks
,
nrepeat
);
nrepeat
);
float
perf
=
static_cast
<
float
>
(
calculate_convolution_flops
(
float
perf
=
static_cast
<
float
>
(
calculate_convolution_flops
(
...
...
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp
View file @
4e57b30a
...
@@ -155,7 +155,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
...
@@ -155,7 +155,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_
i
te
rator
_hacks
=
make_tuple
(
constexpr
auto
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GemmM1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GemmM1
...
@@ -165,7 +165,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
...
@@ -165,7 +165,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GemmM1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GemmM1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 3-: GemmK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 3-: GemmK1
constexpr
auto
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GemmN1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GemmN1
...
@@ -175,7 +175,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
...
@@ -175,7 +175,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: GemmN1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: GemmN1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 3-: GemmK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 3-: GemmK1
constexpr
auto
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
i
te
rator
_hacks
=
constexpr
auto
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmM0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmM0
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmM10
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmM10
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GemmM11
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GemmM11
...
@@ -189,10 +189,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
...
@@ -189,10 +189,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 4-: GemmN10
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 4-: GemmN10
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
// 5-: GemmN11
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
// 5-: GemmN11
constexpr
auto
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
,
0
,
0
,
0
>
{};
constexpr
auto
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
...
@@ -231,22 +231,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
...
@@ -231,22 +231,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// CThreadTransferSrcDstAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// CThreadTransferSrcDstAccessOrder
5
,
// CThreadTransferSrcDstVectorDim
5
,
// CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_N11
,
GemmCThreadTransferDstScalarPerVector_N11
,
decltype
(
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
i
te
rator
_hacks
),
decltype
(
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
s
te
p
_hacks
),
decltype
(
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
decltype
(
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
)
>
(
decltype
(
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_
s
te
p
_hacks
)
>
(
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
in_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmm_gemmn_grid_desc
,
out_gemmm_gemmn_grid_desc
,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_
i
te
rator
_hacks
,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_
s
te
p
_hacks
,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_
i
te
rator
_hacks
,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_
s
te
p
_hacks
,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
i
te
rator
_hacks
,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
s
te
p
_hacks
,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
nrepeat
);
nrepeat
);
{
{
...
...
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
View file @
4e57b30a
...
@@ -92,12 +92,12 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
...
@@ -92,12 +92,12 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
=
make_tuple
(
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
=
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
...
@@ -105,7 +105,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
...
@@ -105,7 +105,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{}));
constexpr
auto
out_m0_m1_m2_n_grid_
i
te
rator
_hacks
=
constexpr
auto
out_m0_m1_m2_n_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
...
@@ -123,10 +123,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
...
@@ -123,10 +123,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
...
@@ -167,22 +167,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
...
@@ -167,22 +167,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
Sequence
<
3
,
0
,
1
,
2
,
7
,
5
,
4
,
6
>
,
Sequence
<
3
,
0
,
1
,
2
,
7
,
5
,
4
,
6
>
,
7
,
7
,
GemmCThreadTransferDstScalarPerVector
,
GemmCThreadTransferDstScalarPerVector
,
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
out_m0_m1_m2_n_grid_
i
te
rator
_hacks
),
decltype
(
out_m0_m1_m2_n_grid_
s
te
p
_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
false
>
(
static_cast
<
TInWei
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
false
>
(
static_cast
<
TInWei
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
wei_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmk0_gemmn_gemmk1_grid_desc
,
in_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmm_gemmn_grid_desc
,
out_gemmm_gemmn_grid_desc
,
wei_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
,
in_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
,
in_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
,
out_m0_m1_m2_n_grid_
i
te
rator
_hacks
,
out_m0_m1_m2_n_grid_
s
te
p
_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
nrepeat
);
nrepeat
);
float
perf
=
static_cast
<
float
>
(
calculate_convolution_flops
(
float
perf
=
static_cast
<
float
>
(
calculate_convolution_flops
(
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment